diff --git a/cli/providers/providers.go b/cli/providers/providers.go index 5e93066ac6..fb32d2000f 100644 --- a/cli/providers/providers.go +++ b/cli/providers/providers.go @@ -5,8 +5,6 @@ package providers import ( "encoding/json" - "go.mondoo.com/cnquery/v11/utils/piped" - "go.mondoo.com/ranger-rpc/status" "os" "strings" @@ -21,6 +19,8 @@ import ( "go.mondoo.com/cnquery/v11/providers-sdk/v1/plugin" "go.mondoo.com/cnquery/v11/providers-sdk/v1/recording" "go.mondoo.com/cnquery/v11/types" + "go.mondoo.com/cnquery/v11/utils/piped" + "go.mondoo.com/ranger-rpc/status" ) type Command struct { @@ -318,35 +318,22 @@ func attachFlags(flagset *pflag.FlagSet, flags []plugin.Flag) { } } -func getFlagValue(flag plugin.Flag, cmd *cobra.Command) *llx.Primitive { +func getFlagValue(flag plugin.Flag) *llx.Primitive { switch flag.Type { case plugin.FlagType_Bool: - v, err := cmd.Flags().GetBool(flag.Long) - if err == nil { - return llx.BoolPrimitive(v) - } - log.Warn().Err(err).Msg("failed to get flag " + flag.Long) + return llx.BoolPrimitive(viper.GetBool(flag.Long)) case plugin.FlagType_Int: - if v, err := cmd.Flags().GetInt(flag.Long); err == nil { - return llx.IntPrimitive(int64(v)) - } + return llx.IntPrimitive(viper.GetInt64(flag.Long)) case plugin.FlagType_String: - if v, err := cmd.Flags().GetString(flag.Long); err == nil { - return llx.StringPrimitive(v) - } + return llx.StringPrimitive(viper.GetString(flag.Long)) case plugin.FlagType_List: - if v, err := cmd.Flags().GetStringSlice(flag.Long); err == nil { - return llx.ArrayPrimitiveT(v, llx.StringPrimitive, types.String) - } + return llx.ArrayPrimitiveT(viper.GetStringSlice(flag.Long), llx.StringPrimitive, types.String) case plugin.FlagType_KeyValue: - if v, err := cmd.Flags().GetStringToString(flag.Long); err == nil { - return llx.MapPrimitiveT(v, llx.StringPrimitive, types.String) - } + return llx.MapPrimitiveT(viper.GetStringMapString(flag.Long), llx.StringPrimitive, types.String) default: log.Warn().Msg("unknown flag type for " + flag.Long) return nil } - return nil } func setConnector(provider *plugin.Provider, connector *plugin.Connector, run func(*cobra.Command, *providers.Runtime, *plugin.ParseCLIRes), cmd *cobra.Command) { @@ -421,7 +408,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu continue } - if v := getFlagValue(flag, cmd); v != nil { + if v := getFlagValue(flag); v != nil { flagVals[flag.Long] = v } } diff --git a/test/providers/os_test.go b/test/providers/os_test.go index cc029498c6..05d69b1722 100644 --- a/test/providers/os_test.go +++ b/test/providers/os_test.go @@ -4,15 +4,16 @@ package providers import ( - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.mondoo.com/cnquery/v11/test" "log" "os" "os/exec" "path/filepath" "sync" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mondoo.com/cnquery/v11/test" ) var once sync.Once @@ -185,3 +186,43 @@ func TestOsProviderSharedTests(t *testing.T) { } } } + +func TestProvidersEnvVarsLoading(t *testing.T) { + t.Run("command WITHOUT path should not find any package", func(t *testing.T) { + r := test.NewCliTestRunner("./cnquery", "run", "fs", "-c", mqlPackagesQuery, "-j") + err := r.Run() + require.NoError(t, err) + assert.Equal(t, 0, r.ExitCode()) + assert.NotNil(t, r.Stdout()) + assert.NotNil(t, r.Stderr()) + + var c mqlPackages + err = r.Json(&c) + assert.NoError(t, err) + + // No packages + assert.Empty(t, c) + }) + t.Run("command WITH path should find packages", func(t *testing.T) { + os.Setenv("MONDOO_PATH", "./testdata/fs") + defer os.Unsetenv("MONDOO_PATH") + // Note we are not passing the flag "--path ./testdata/fs" + r := test.NewCliTestRunner("./cnquery", "run", "fs", "-c", mqlPackagesQuery, "-j") + err := r.Run() + require.NoError(t, err) + assert.Equal(t, 0, r.ExitCode()) + assert.NotNil(t, r.Stdout()) + assert.NotNil(t, r.Stderr()) + + var c mqlPackages + err = r.Json(&c) + assert.NoError(t, err) + + // Should have packages + if assert.NotEmpty(t, c) { + x := c[0] + assert.NotNil(t, x.Packages) + assert.True(t, len(x.Packages) > 0) + } + }) +}