From 6107f2b4e3a94d5e0b6d4a68c6ba0f71c2783599 Mon Sep 17 00:00:00 2001 From: Dominik Roos Date: Sun, 27 Sep 2020 01:20:04 +0200 Subject: [PATCH] config: add flag values automatically (#17) Add mechanism to automatically add flag values based on the config struct. The default value is inferred from the provided struct. Currently, no help text is added to the flags. --- go.mod | 2 +- pkg/boa/config.go | 93 +++++++++++++++++++++++++++++++++++++++++ pkg/boa/config_test.go | 33 +++++++++++++-- pkg/boa/flag/flag.go | 80 +++++++++++++++++++++++++++++++++++ pkg/boa/hooks.go | 7 ++-- sample/config/config.go | 40 +++++++++++------- 6 files changed, 232 insertions(+), 23 deletions(-) create mode 100644 pkg/boa/flag/flag.go diff --git a/go.mod b/go.mod index b13bbfb..ac4d82b 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/golang/mock v1.4.4 github.com/mitchellh/mapstructure v1.1.2 github.com/spf13/cobra v1.0.0 - github.com/spf13/pflag v1.0.5 // indirect + github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.4.0 github.com/stretchr/testify v1.3.0 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 diff --git a/pkg/boa/config.go b/pkg/boa/config.go index f08ebcc..d86cd06 100644 --- a/pkg/boa/config.go +++ b/pkg/boa/config.go @@ -14,9 +14,14 @@ package boa import ( + "fmt" + "net" + "reflect" "strings" + "time" "github.com/mitchellh/mapstructure" + "github.com/spf13/pflag" ) // ConfigRegistry is an abstraction of viper.Viper. @@ -72,6 +77,94 @@ func bindEnv(r ConfigRegistry, p path, config map[string]interface{}) error { return nil } +// AddFlags adds flags to the provided flag set based on the config struct. +// Default are set according to the values present in the config struct. +func AddFlags(r *pflag.FlagSet, config interface{}) error { + m := map[string]interface{}{} + if err := mapstructure.Decode(config, &m); err != nil { + return err + } + return addFlags(r, nil, m) +} + +// nolint: gocyclo +func addFlags(r *pflag.FlagSet, p path, config map[string]interface{}) error { + for key, value := range config { + keyPath := p.Extend(key) + if m, ok := value.(map[string]interface{}); ok { + if err := addFlags(r, keyPath, m); err != nil { + return err + } + continue + } + + if v, ok := value.(pflag.Value); ok { + r.Var(v, keyPath.String(), "") + continue + } + + t, v := reflect.TypeOf(value), reflect.ValueOf(value) + if t.Kind() == reflect.Ptr || t.Kind() == reflect.Interface { + t, v = t.Elem(), reflect.Indirect(v) + if v.Kind() == reflect.Invalid { + v = reflect.Zero(t) + } + } + switch { + case t == reflect.TypeOf(time.Duration(1)): + r.Duration(keyPath.String(), time.Duration(v.Int()), "") + case t == reflect.TypeOf(net.IP{}): + r.IP(keyPath.String(), v.Interface().(net.IP), "") + case t.Kind() == reflect.Bool: + r.Bool(keyPath.String(), v.Bool(), "") + case t.Kind() == reflect.Float32: + r.Float32(keyPath.String(), float32(v.Float()), "") + case t.Kind() == reflect.Float64: + r.Float64(keyPath.String(), v.Float(), "") + case t.Kind() == reflect.Int8: + r.Int8(keyPath.String(), int8(v.Int()), "") + case t.Kind() == reflect.Int16: + r.Int16(keyPath.String(), int16(v.Int()), "") + case t.Kind() == reflect.Int32: + r.Int32(keyPath.String(), int32(v.Int()), "") + case t.Kind() == reflect.Int64: + r.Int64(keyPath.String(), v.Int(), "") + case t.Kind() == reflect.Int: + r.Int(keyPath.String(), int(v.Int()), "") + case t.Kind() == reflect.Uint8: + r.Uint8(keyPath.String(), uint8(v.Uint()), "") + case t.Kind() == reflect.Uint16: + r.Uint16(keyPath.String(), uint16(v.Uint()), "") + case t.Kind() == reflect.Uint32: + r.Uint32(keyPath.String(), uint32(v.Uint()), "") + case t.Kind() == reflect.Uint64: + r.Uint64(keyPath.String(), v.Uint(), "") + case t.Kind() == reflect.Uint: + r.Uint(keyPath.String(), uint(v.Uint()), "") + case t.Kind() == reflect.String: + r.String(keyPath.String(), v.String(), "") + case t.Kind() == reflect.Slice: + switch t.Elem().Kind() { + case reflect.Bool: + r.BoolSlice(keyPath.String(), v.Interface().([]bool), "") + case reflect.Int32: + r.Int32Slice(keyPath.String(), v.Interface().([]int32), "") + case reflect.Int64: + r.Int64Slice(keyPath.String(), v.Interface().([]int64), "") + case reflect.Int: + r.IntSlice(keyPath.String(), v.Interface().([]int), "") + case reflect.Uint: + r.UintSlice(keyPath.String(), v.Interface().([]uint), "") + case reflect.String: + r.StringSlice(keyPath.String(), v.Interface().([]string), "") + } + default: + return fmt.Errorf("unsupported value: %s (%T)", keyPath, value) + } + } + return nil +} + type path []string func (p path) Extend(key string) path { diff --git a/pkg/boa/config_test.go b/pkg/boa/config_test.go index f41dd92..1208dff 100644 --- a/pkg/boa/config_test.go +++ b/pkg/boa/config_test.go @@ -21,11 +21,13 @@ import ( "github.com/golang/mock/gomock" "github.com/mitchellh/mapstructure" + "github.com/spf13/pflag" "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/oncilla/boa/pkg/boa" + "github.com/oncilla/boa/pkg/boa/flag" "github.com/oncilla/boa/pkg/boa/mock_boa" ) @@ -34,7 +36,7 @@ type Config struct { User string `mapstructure:"user"` Password string `mapstructure:"password"` } `mapstructure:"db"` - Addr *net.TCPAddr `mapstructure:"addr"` + Addr *flag.TCPAddr `mapstructure:"addr"` Token `mapstructure:",squash"` } @@ -64,13 +66,13 @@ func TestSetDefault(t *testing.T) { var config Config config.DB.User = "oncilla" config.DB.Password = "password" - config.Addr = &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080} + config.Addr = &flag.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080} config.Token.Token = "token" r := mock_boa.NewMockConfigRegistry(ctrl) r.EXPECT().SetDefault("db.user", "oncilla") r.EXPECT().SetDefault("db.password", "password") - r.EXPECT().SetDefault("addr", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}) + r.EXPECT().SetDefault("addr", &flag.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}) r.EXPECT().SetDefault("token", "token") err := boa.SetDefaults(r, &config) @@ -103,3 +105,28 @@ func TestViperBindEnv(t *testing.T) { assert.Equal(t, "127.0.0.1:8080", config.Addr.String()) assert.Equal(t, "token", config.Token.Token) } + +func TestAddFlags(t *testing.T) { + var config Config + config.DB.User = "oncilla" + config.DB.Password = "password" + config.Addr = &flag.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080} + config.Token.Token = "token" + + s := pflag.NewFlagSet("", pflag.ContinueOnError) + err := boa.AddFlags(s, &config) + require.NoError(t, err) + + user, err := s.GetString("db.user") + assert.NoError(t, err) + assert.Equal(t, "oncilla", user) + password, err := s.GetString("db.password") + assert.NoError(t, err) + assert.Equal(t, "password", password) + addr := s.Lookup("addr") + assert.NotNil(t, addr) + assert.Equal(t, "127.0.0.1:8080", addr.Value.String()) + token, err := s.GetString("token") + assert.NoError(t, err) + assert.Equal(t, "token", token) +} diff --git a/pkg/boa/flag/flag.go b/pkg/boa/flag/flag.go new file mode 100644 index 0000000..7982406 --- /dev/null +++ b/pkg/boa/flag/flag.go @@ -0,0 +1,80 @@ +// Copyright 2020 oncilla +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package flag + +import ( + "encoding" + "net" + + "github.com/spf13/pflag" +) + +var _ pflag.Value = (*TCPAddr)(nil) +var _ encoding.TextMarshaler = (*TCPAddr)(nil) + +// TCPAddr implements pflags.Value +type TCPAddr net.TCPAddr + +func (addr *TCPAddr) Set(input string) error { + p, err := net.ResolveTCPAddr("tcp", input) + if err != nil { + return err + } + *addr = TCPAddr(*p) + return nil +} + +func (addr *TCPAddr) UnmarshalText(b []byte) error { + return addr.Set(string(b)) +} + +func (addr *TCPAddr) Type() string { + return "tcp-addr" +} + +func (addr *TCPAddr) MarshalText() ([]byte, error) { + return []byte(addr.String()), nil +} + +func (addr *TCPAddr) String() string { + return (*net.TCPAddr)(addr).String() +} + +// UDPAddr implements pflags.Value +type UDPAddr net.UDPAddr + +func (addr *UDPAddr) Set(input string) error { + p, err := net.ResolveUDPAddr("tcp", input) + if err != nil { + return err + } + *addr = UDPAddr(*p) + return nil +} + +func (addr *UDPAddr) UnmarshalText(b []byte) error { + return addr.Set(string(b)) +} + +func (addr *UDPAddr) Type() string { + return "tcp-addr" +} + +func (addr *UDPAddr) MarshalText() ([]byte, error) { + return []byte(addr.String()), nil +} + +func (addr *UDPAddr) String() string { + return (*net.UDPAddr)(addr).String() +} diff --git a/pkg/boa/hooks.go b/pkg/boa/hooks.go index c1a0a4a..3cf2d45 100644 --- a/pkg/boa/hooks.go +++ b/pkg/boa/hooks.go @@ -18,6 +18,7 @@ import ( "reflect" "github.com/mitchellh/mapstructure" + "github.com/oncilla/boa/pkg/boa/flag" ) // DefaultDecodeHooks returns a list of useful decoding hooks. @@ -41,10 +42,9 @@ func StringToTCPAddrHookFunc() mapstructure.DecodeHookFunc { if f.Kind() != reflect.String { return data, nil } - if t != reflect.TypeOf(net.TCPAddr{}) { + if t != reflect.TypeOf(net.TCPAddr{}) && t != reflect.TypeOf(flag.TCPAddr{}) { return data, nil } - addr, err := net.ResolveTCPAddr("tcp", data.(string)) if err != nil { return nil, err @@ -63,10 +63,9 @@ func StringToUDPAddrHookFunc() mapstructure.DecodeHookFunc { if f.Kind() != reflect.String { return data, nil } - if t != reflect.TypeOf(net.UDPAddr{}) { + if t != reflect.TypeOf(net.UDPAddr{}) && t != reflect.TypeOf(flag.UDPAddr{}) { return data, nil } - addr, err := net.ResolveUDPAddr("udp", data.(string)) if err != nil { return nil, err diff --git a/sample/config/config.go b/sample/config/config.go index 2a0b273..7d9b0e5 100644 --- a/sample/config/config.go +++ b/sample/config/config.go @@ -25,9 +25,22 @@ import ( "gopkg.in/yaml.v3" "github.com/oncilla/boa/pkg/boa" + "github.com/oncilla/boa/pkg/boa/flag" ) +func defaultConfig() *Config { + return &Config{ + DB: DB{ + User: "user", + Password: "password", + }, + Addr: &flag.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, + } +} + func main() { + v := viper.New() + cmd := &cobra.Command{ Use: "config ", Short: "A sample application with config parsing", @@ -35,9 +48,10 @@ func main() { This application loads the configuration based on the following precedence: -1. Environment variable -2. Configuration file -3. Default values +1. Command line flag +2. Environment variable +3. Configuration file +4. Default value The default configuration is: @@ -54,18 +68,10 @@ SAMPLE_DB_USER=secure `, SilenceErrors: true, RunE: func(cmd *cobra.Command, args []string) error { - v := viper.New() + cfg := defaultConfig() v.SetEnvPrefix("sample") v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) - - // Define the configuration struct. - cfg := &Config{ - DB: DB{ - User: "user", - Password: "password", - }, - Addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, - } + v.BindPFlags(cmd.Flags()) // Set the default values for the configuration. if err := boa.SetDefaults(v, cfg); err != nil { @@ -102,6 +108,10 @@ SAMPLE_DB_USER=secure return enc.Encode(out) }, } + if err := boa.AddFlags(cmd.Flags(), defaultConfig()); err != nil { + fmt.Fprintf(os.Stderr, "Error: %s\n", err) + os.Exit(1) + } if err := cmd.Execute(); err != nil { fmt.Fprintf(os.Stderr, "Error: %s\n", err) os.Exit(1) @@ -109,8 +119,8 @@ SAMPLE_DB_USER=secure } type Config struct { - DB DB `mapstructure:"db"` - Addr *net.TCPAddr `mapstructure:"addr"` + DB DB `mapstructure:"db"` + Addr *flag.TCPAddr `mapstructure:"addr"` } type DB struct {