Skip to content

Commit

Permalink
config: add flag values automatically (#17)
Browse files Browse the repository at this point in the history
Add mechanism to automatically add flag values based on the config struct.
The default value is inferred from the provided struct.

Currently, no help text is added to the flags.
  • Loading branch information
oncilla authored Sep 26, 2020
1 parent 2a2ff23 commit 6107f2b
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 23 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/golang/mock v1.4.4
github.com/mitchellh/mapstructure v1.1.2
github.com/spf13/cobra v1.0.0
github.com/spf13/pflag v1.0.5 // indirect
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.4.0
github.com/stretchr/testify v1.3.0
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543
Expand Down
93 changes: 93 additions & 0 deletions pkg/boa/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@
package boa

import (
"fmt"
"net"
"reflect"
"strings"
"time"

"github.com/mitchellh/mapstructure"
"github.com/spf13/pflag"
)

// ConfigRegistry is an abstraction of viper.Viper.
Expand Down Expand Up @@ -72,6 +77,94 @@ func bindEnv(r ConfigRegistry, p path, config map[string]interface{}) error {
return nil
}

// AddFlags adds flags to the provided flag set based on the config struct.
// Default are set according to the values present in the config struct.
func AddFlags(r *pflag.FlagSet, config interface{}) error {
m := map[string]interface{}{}
if err := mapstructure.Decode(config, &m); err != nil {
return err
}
return addFlags(r, nil, m)
}

// nolint: gocyclo
func addFlags(r *pflag.FlagSet, p path, config map[string]interface{}) error {
for key, value := range config {
keyPath := p.Extend(key)
if m, ok := value.(map[string]interface{}); ok {
if err := addFlags(r, keyPath, m); err != nil {
return err
}
continue
}

if v, ok := value.(pflag.Value); ok {
r.Var(v, keyPath.String(), "")
continue
}

t, v := reflect.TypeOf(value), reflect.ValueOf(value)
if t.Kind() == reflect.Ptr || t.Kind() == reflect.Interface {
t, v = t.Elem(), reflect.Indirect(v)
if v.Kind() == reflect.Invalid {
v = reflect.Zero(t)
}
}
switch {
case t == reflect.TypeOf(time.Duration(1)):
r.Duration(keyPath.String(), time.Duration(v.Int()), "")
case t == reflect.TypeOf(net.IP{}):
r.IP(keyPath.String(), v.Interface().(net.IP), "")
case t.Kind() == reflect.Bool:
r.Bool(keyPath.String(), v.Bool(), "")
case t.Kind() == reflect.Float32:
r.Float32(keyPath.String(), float32(v.Float()), "")
case t.Kind() == reflect.Float64:
r.Float64(keyPath.String(), v.Float(), "")
case t.Kind() == reflect.Int8:
r.Int8(keyPath.String(), int8(v.Int()), "")
case t.Kind() == reflect.Int16:
r.Int16(keyPath.String(), int16(v.Int()), "")
case t.Kind() == reflect.Int32:
r.Int32(keyPath.String(), int32(v.Int()), "")
case t.Kind() == reflect.Int64:
r.Int64(keyPath.String(), v.Int(), "")
case t.Kind() == reflect.Int:
r.Int(keyPath.String(), int(v.Int()), "")
case t.Kind() == reflect.Uint8:
r.Uint8(keyPath.String(), uint8(v.Uint()), "")
case t.Kind() == reflect.Uint16:
r.Uint16(keyPath.String(), uint16(v.Uint()), "")
case t.Kind() == reflect.Uint32:
r.Uint32(keyPath.String(), uint32(v.Uint()), "")
case t.Kind() == reflect.Uint64:
r.Uint64(keyPath.String(), v.Uint(), "")
case t.Kind() == reflect.Uint:
r.Uint(keyPath.String(), uint(v.Uint()), "")
case t.Kind() == reflect.String:
r.String(keyPath.String(), v.String(), "")
case t.Kind() == reflect.Slice:
switch t.Elem().Kind() {
case reflect.Bool:
r.BoolSlice(keyPath.String(), v.Interface().([]bool), "")
case reflect.Int32:
r.Int32Slice(keyPath.String(), v.Interface().([]int32), "")
case reflect.Int64:
r.Int64Slice(keyPath.String(), v.Interface().([]int64), "")
case reflect.Int:
r.IntSlice(keyPath.String(), v.Interface().([]int), "")
case reflect.Uint:
r.UintSlice(keyPath.String(), v.Interface().([]uint), "")
case reflect.String:
r.StringSlice(keyPath.String(), v.Interface().([]string), "")
}
default:
return fmt.Errorf("unsupported value: %s (%T)", keyPath, value)
}
}
return nil
}

type path []string

func (p path) Extend(key string) path {
Expand Down
33 changes: 30 additions & 3 deletions pkg/boa/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ import (

"github.com/golang/mock/gomock"
"github.com/mitchellh/mapstructure"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/oncilla/boa/pkg/boa"
"github.com/oncilla/boa/pkg/boa/flag"
"github.com/oncilla/boa/pkg/boa/mock_boa"
)

Expand All @@ -34,7 +36,7 @@ type Config struct {
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
} `mapstructure:"db"`
Addr *net.TCPAddr `mapstructure:"addr"`
Addr *flag.TCPAddr `mapstructure:"addr"`
Token `mapstructure:",squash"`
}

Expand Down Expand Up @@ -64,13 +66,13 @@ func TestSetDefault(t *testing.T) {
var config Config
config.DB.User = "oncilla"
config.DB.Password = "password"
config.Addr = &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}
config.Addr = &flag.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}
config.Token.Token = "token"

r := mock_boa.NewMockConfigRegistry(ctrl)
r.EXPECT().SetDefault("db.user", "oncilla")
r.EXPECT().SetDefault("db.password", "password")
r.EXPECT().SetDefault("addr", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080})
r.EXPECT().SetDefault("addr", &flag.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080})
r.EXPECT().SetDefault("token", "token")

err := boa.SetDefaults(r, &config)
Expand Down Expand Up @@ -103,3 +105,28 @@ func TestViperBindEnv(t *testing.T) {
assert.Equal(t, "127.0.0.1:8080", config.Addr.String())
assert.Equal(t, "token", config.Token.Token)
}

func TestAddFlags(t *testing.T) {
var config Config
config.DB.User = "oncilla"
config.DB.Password = "password"
config.Addr = &flag.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}
config.Token.Token = "token"

s := pflag.NewFlagSet("", pflag.ContinueOnError)
err := boa.AddFlags(s, &config)
require.NoError(t, err)

user, err := s.GetString("db.user")
assert.NoError(t, err)
assert.Equal(t, "oncilla", user)
password, err := s.GetString("db.password")
assert.NoError(t, err)
assert.Equal(t, "password", password)
addr := s.Lookup("addr")
assert.NotNil(t, addr)
assert.Equal(t, "127.0.0.1:8080", addr.Value.String())
token, err := s.GetString("token")
assert.NoError(t, err)
assert.Equal(t, "token", token)
}
80 changes: 80 additions & 0 deletions pkg/boa/flag/flag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright 2020 oncilla
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package flag

import (
"encoding"
"net"

"github.com/spf13/pflag"
)

var _ pflag.Value = (*TCPAddr)(nil)
var _ encoding.TextMarshaler = (*TCPAddr)(nil)

// TCPAddr implements pflags.Value
type TCPAddr net.TCPAddr

func (addr *TCPAddr) Set(input string) error {
p, err := net.ResolveTCPAddr("tcp", input)
if err != nil {
return err
}
*addr = TCPAddr(*p)
return nil
}

func (addr *TCPAddr) UnmarshalText(b []byte) error {
return addr.Set(string(b))
}

func (addr *TCPAddr) Type() string {
return "tcp-addr"
}

func (addr *TCPAddr) MarshalText() ([]byte, error) {
return []byte(addr.String()), nil
}

func (addr *TCPAddr) String() string {
return (*net.TCPAddr)(addr).String()
}

// UDPAddr implements pflags.Value
type UDPAddr net.UDPAddr

func (addr *UDPAddr) Set(input string) error {
p, err := net.ResolveUDPAddr("tcp", input)
if err != nil {
return err
}
*addr = UDPAddr(*p)
return nil
}

func (addr *UDPAddr) UnmarshalText(b []byte) error {
return addr.Set(string(b))
}

func (addr *UDPAddr) Type() string {
return "tcp-addr"
}

func (addr *UDPAddr) MarshalText() ([]byte, error) {
return []byte(addr.String()), nil
}

func (addr *UDPAddr) String() string {
return (*net.UDPAddr)(addr).String()
}
7 changes: 3 additions & 4 deletions pkg/boa/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"reflect"

"github.com/mitchellh/mapstructure"
"github.com/oncilla/boa/pkg/boa/flag"
)

// DefaultDecodeHooks returns a list of useful decoding hooks.
Expand All @@ -41,10 +42,9 @@ func StringToTCPAddrHookFunc() mapstructure.DecodeHookFunc {
if f.Kind() != reflect.String {
return data, nil
}
if t != reflect.TypeOf(net.TCPAddr{}) {
if t != reflect.TypeOf(net.TCPAddr{}) && t != reflect.TypeOf(flag.TCPAddr{}) {
return data, nil
}

addr, err := net.ResolveTCPAddr("tcp", data.(string))
if err != nil {
return nil, err
Expand All @@ -63,10 +63,9 @@ func StringToUDPAddrHookFunc() mapstructure.DecodeHookFunc {
if f.Kind() != reflect.String {
return data, nil
}
if t != reflect.TypeOf(net.UDPAddr{}) {
if t != reflect.TypeOf(net.UDPAddr{}) && t != reflect.TypeOf(flag.UDPAddr{}) {
return data, nil
}

addr, err := net.ResolveUDPAddr("udp", data.(string))
if err != nil {
return nil, err
Expand Down
40 changes: 25 additions & 15 deletions sample/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,33 @@ import (
"gopkg.in/yaml.v3"

"github.com/oncilla/boa/pkg/boa"
"github.com/oncilla/boa/pkg/boa/flag"
)

func defaultConfig() *Config {
return &Config{
DB: DB{
User: "user",
Password: "password",
},
Addr: &flag.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
}
}

func main() {
v := viper.New()

cmd := &cobra.Command{
Use: "config <config-file>",
Short: "A sample application with config parsing",
Long: `This is a sample application that showcases config parsign with the help of boa.
This application loads the configuration based on the following precedence:
1. Environment variable
2. Configuration file
3. Default values
1. Command line flag
2. Environment variable
3. Configuration file
4. Default value
The default configuration is:
Expand All @@ -54,18 +68,10 @@ SAMPLE_DB_USER=secure
`,
SilenceErrors: true,
RunE: func(cmd *cobra.Command, args []string) error {
v := viper.New()
cfg := defaultConfig()
v.SetEnvPrefix("sample")
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))

// Define the configuration struct.
cfg := &Config{
DB: DB{
User: "user",
Password: "password",
},
Addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
}
v.BindPFlags(cmd.Flags())

// Set the default values for the configuration.
if err := boa.SetDefaults(v, cfg); err != nil {
Expand Down Expand Up @@ -102,15 +108,19 @@ SAMPLE_DB_USER=secure
return enc.Encode(out)
},
}
if err := boa.AddFlags(cmd.Flags(), defaultConfig()); err != nil {
fmt.Fprintf(os.Stderr, "Error: %s\n", err)
os.Exit(1)
}
if err := cmd.Execute(); err != nil {
fmt.Fprintf(os.Stderr, "Error: %s\n", err)
os.Exit(1)
}
}

type Config struct {
DB DB `mapstructure:"db"`
Addr *net.TCPAddr `mapstructure:"addr"`
DB DB `mapstructure:"db"`
Addr *flag.TCPAddr `mapstructure:"addr"`
}

type DB struct {
Expand Down

0 comments on commit 6107f2b

Please sign in to comment.