Skip to content

Commit

Permalink
🐛 fix provider flags with ConfigEntry="-"
Browse files Browse the repository at this point in the history
When a provider defines a flag with `ConfigEntry = "-"`, then we do not
Bind the Flag to the `viper` config. For those flags, we will continue
to fetch the value directly from the flag, that is, from `cobra`.

Signed-off-by: Salim Afiune Maya <[email protected]>
  • Loading branch information
afiune committed Nov 15, 2024
1 parent b5fd146 commit 2ad9b65
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 3 deletions.
47 changes: 44 additions & 3 deletions cli/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ func attachFlags(flagset *pflag.FlagSet, flags []plugin.Flag) {
}
}

func getFlagValue(flag plugin.Flag) *llx.Primitive {
func getFlagValueFromConfig(flag plugin.Flag) *llx.Primitive {
switch flag.Type {
case plugin.FlagType_Bool:
return llx.BoolPrimitive(viper.GetBool(flag.Long))
Expand All @@ -336,6 +336,38 @@ func getFlagValue(flag plugin.Flag) *llx.Primitive {
}
}

func getFlagValueFromCobra(flag plugin.Flag, cmd *cobra.Command) *llx.Primitive {
var err error
switch flag.Type {
case plugin.FlagType_Bool:
if v, err := cmd.Flags().GetBool(flag.Long); err == nil {
return llx.BoolPrimitive(v)
}
case plugin.FlagType_Int:
if v, err := cmd.Flags().GetInt(flag.Long); err == nil {
return llx.IntPrimitive(int64(v))
}
case plugin.FlagType_String:
if v, err := cmd.Flags().GetString(flag.Long); err == nil {
return llx.StringPrimitive(v)
}
case plugin.FlagType_List:
if v, err := cmd.Flags().GetStringSlice(flag.Long); err == nil {
return llx.ArrayPrimitiveT(v, llx.StringPrimitive, types.String)
}
case plugin.FlagType_KeyValue:
if v, err := cmd.Flags().GetStringToString(flag.Long); err == nil {
return llx.MapPrimitiveT(v, llx.StringPrimitive, types.String)
}
default:
log.Warn().Msg("unknown flag type for " + flag.Long)
return nil
}

log.Warn().Err(err).Msg("failed to get flag " + flag.Long)
return nil
}

func setConnector(provider *plugin.Provider, connector *plugin.Connector, run func(*cobra.Command, *providers.Runtime, *plugin.ParseCLIRes), cmd *cobra.Command) {
oldRun := cmd.Run
oldPreRun := cmd.PreRun
Expand All @@ -353,6 +385,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
for i := range allFlags {
flag := allFlags[i]
if flag.ConfigEntry == "-" {
log.Debug().Msg("skipping config binding for " + flag.Long)
continue
}

Expand Down Expand Up @@ -408,8 +441,16 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
continue
}

if v := getFlagValue(flag); v != nil {
flagVals[flag.Long] = v
// if the provider flag was configured to avoid using the config,
// we should instead fetch the flag value from `cobra` directly.
if flag.ConfigEntry == "-" {
if v := getFlagValueFromCobra(flag, cmd); v != nil {
flagVals[flag.Long] = v
}
} else {
if v := getFlagValueFromConfig(flag); v != nil {
flagVals[flag.Long] = v
}
}
}

Expand Down
27 changes: 27 additions & 0 deletions test/providers/os_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,31 @@ func TestProvidersEnvVarsLoading(t *testing.T) {
assert.True(t, len(x.Packages) > 0)
}
})

t.Run("command with flags set to not bind to config (ConfigEntry=\"-\")", func(t *testing.T) {
t.Run("should work via direct flag", func(t *testing.T) {
r := test.NewCliTestRunner("./cnquery", "run", "ssh", "localhost", "-c", "ls", "-p", "test", "-v")
err := r.Run()
require.NoError(t, err)
assert.Equal(t, 0, r.ExitCode())
assert.NotNil(t, r.Stdout())
if assert.NotNil(t, r.Stderr()) {
assert.Contains(t, string(r.Stderr()), "skipping config binding for password")
assert.Contains(t, string(r.Stderr()), "enabled ssh password authentication")
}
})
t.Run("should NOT work via config/env-vars", func(t *testing.T) {
os.Setenv("MONDOO_PASSWORD", "test")
defer os.Unsetenv("MONDOO_PASSWORD")
r := test.NewCliTestRunner("./cnquery", "run", "ssh", "localhost", "-c", "ls", "-v")
err := r.Run()
require.NoError(t, err)
assert.Equal(t, 0, r.ExitCode())
assert.NotNil(t, r.Stdout())
if assert.NotNil(t, r.Stderr()) {
assert.Contains(t, string(r.Stderr()), "skipping config binding for password")
assert.NotContains(t, string(r.Stderr()), "enabled ssh password authentication")
}
})
})
}

0 comments on commit 2ad9b65

Please sign in to comment.