Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
stdlibConfig strictMode should not fail if there are existing flags (#89
Browse files Browse the repository at this point in the history
)

* Do not error on existing flags

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Add unit test

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Update unit test

Signed-off-by: Haytham Abuelfutuh <[email protected]>
  • Loading branch information
EngHabu authored Jul 9, 2021
1 parent bb36111 commit f64c747
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 11 deletions.
19 changes: 19 additions & 0 deletions config/tests/accessor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,25 @@ func TestStrictAccessor(t *testing.T) {
assert.Error(t, v.UpdateConfig(context.TODO()))
})

t.Run(fmt.Sprintf("[%v] flags defined outside", provider(config.Options{}).ID()), func(t *testing.T) {
reg := config.NewRootSection()
_, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{})
assert.NoError(t, err)

_, err = reg.RegisterSection(OtherComponentSectionKey, &OtherComponentConfig{})
assert.NoError(t, err)
v := provider(config.Options{
StrictMode: true,
SearchPaths: []string{filepath.Join("testdata", "bad_config.yaml")},
RootSection: reg,
})

set := pflag.NewFlagSet("test", pflag.ExitOnError)
set.StringP("unknown-key", "u", "", "")
v.InitializePflags(set)
assert.NoError(t, v.UpdateConfig(context.TODO()))
})

t.Run(fmt.Sprintf("[%v] Set through env", provider(config.Options{}).ID()), func(t *testing.T) {
reg := config.NewRootSection()
_, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{})
Expand Down
8 changes: 4 additions & 4 deletions config/tests/testdata/bad_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ other-component:
int-val: 4
string-value: Hey there!
strings:
- hello
- world
- '!'
- hello
- world
- '!'
url-value: http://something.com
unknown-key: "something"
unknown-key: "something"

32 changes: 25 additions & 7 deletions config/viper/viper.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"strings"
"sync"

"k8s.io/apimachinery/pkg/util/sets"

"github.com/pkg/errors"

stdLibErrs "github.com/flyteorg/flytestdlib/errors"
Expand Down Expand Up @@ -44,6 +46,7 @@ type viperAccessor struct {
rootConfig config.Section
// Ensures we initialize the file Watcher once.
watcherInitializer *sync.Once
existingFlagKeys sets.String
}

func (viperAccessor) ID() string {
Expand All @@ -54,7 +57,17 @@ func (viperAccessor) InitializeFlags(cmdFlags *flag.FlagSet) {
// TODO: Implement?
}

func (v viperAccessor) InitializePflags(cmdFlags *pflag.FlagSet) {
func (v *viperAccessor) InitializePflags(cmdFlags *pflag.FlagSet) {
existingFlagKeys := sets.NewString()
cmdFlags.VisitAll(func(f *pflag.Flag) {
existingFlagKeys.Insert(f.Name)
if len(f.Shorthand) > 0 {
existingFlagKeys.Insert(f.Shorthand)
}
})

v.existingFlagKeys = existingFlagKeys

err := v.addSectionsPFlags(cmdFlags)
if err != nil {
panic(errors.Wrap(err, "error adding config PFlags to flag set"))
Expand Down Expand Up @@ -260,12 +273,14 @@ func (v viperAccessor) parseViperConfigRecursive(root config.Section, settings i
errs := stdLibErrs.ErrorCollection{}
var mine interface{}
myKeysCount := 0
discoveredKeys := sets.NewString()
if asMap, casted := settings.(map[string]interface{}); casted {
myMap := map[string]interface{}{}
for childKey, childValue := range asMap {
if childSection, found := root.GetSections()[childKey]; found {
errs.Append(v.parseViperConfigRecursive(childSection, childValue))
} else {
discoveredKeys.Insert(childKey)
myMap[childKey] = childValue
}
}
Expand All @@ -276,6 +291,7 @@ func (v viperAccessor) parseViperConfigRecursive(root config.Section, settings i
mine = settings
myKeysCount = len(asSlice)
} else {
discoveredKeys.Insert(fmt.Sprintf("%v", mine))
mine = settings
if settings != nil {
myKeysCount = 1
Expand All @@ -296,10 +312,12 @@ func (v viperAccessor) parseViperConfigRecursive(root config.Section, settings i
} else if myKeysCount > 0 {
// There are keys set that are meant to be decoded but no config to receive them. Fail if strict mode is on.
if v.strictMode {
errs.Append(errors.Wrap(
config.ErrStrictModeValidation,
fmt.Sprintf("strict mode is on but received keys [%+v] to decode with no config assigned to"+
" receive them", mine)))
if newKeys := discoveredKeys.Difference(v.existingFlagKeys); newKeys.Len() > 0 {
errs.Append(errors.Wrap(
config.ErrStrictModeValidation,
fmt.Sprintf("strict mode is on but received keys [%+v] to decode with no config assigned to"+
" receive them", newKeys)))
}
}
}

Expand Down Expand Up @@ -396,7 +414,7 @@ func NewAccessor(opts config.Options) config.Accessor {
return newAccessor(opts)
}

func newAccessor(opts config.Options) viperAccessor {
func newAccessor(opts config.Options) *viperAccessor {
vipers := make([]Viper, 0, 1)
configFiles := files.FindConfigFiles(opts.SearchPaths)
for _, configFile := range configFiles {
Expand All @@ -417,7 +435,7 @@ func newAccessor(opts config.Options) viperAccessor {
r = config.GetRootSection()
}

return viperAccessor{
return &viperAccessor{
strictMode: opts.StrictMode,
rootConfig: r,
viper: &CollectionProxy{underlying: vipers},
Expand Down

0 comments on commit f64c747

Please sign in to comment.