diff --git a/cmd/config/main.go b/cmd/config/main.go index 80358e383f..2c6a463a25 100644 --- a/cmd/config/main.go +++ b/cmd/config/main.go @@ -10,7 +10,7 @@ import ( "github.com/spf13/cobra" ) -var ConfigAvailableKeys = []string{"token", "offline"} +var ConfigAvailableKeys = []string{"token", "offline", "policy_config", "schema_locations"} type Messager interface { LoadVersionMessages(cliVersion string) chan *messager.VersionMessage diff --git a/cmd/test/main.go b/cmd/test/main.go index 3705804506..4fb36d3582 100644 --- a/cmd/test/main.go +++ b/cmd/test/main.go @@ -278,6 +278,11 @@ func (flags *TestCommandFlags) AddFlags(cmd *cobra.Command) { cmd.Flags().BoolVarP(&flags.Quiet, "quiet", "", false, "Don't print skipped rules messages") } +const ( + DatreePolicyConfig = "DATREE_POLICY_CONFIG" + DatreeSchemaLocations = "DATREE_SCHEMA_LOCATION" +) + func GenerateTestCommandData(testCommandFlags *TestCommandFlags, localConfigContent *localConfig.LocalConfig, evaluationPrerunDataResp *cliClient.EvaluationPrerunDataResponse) (*TestCommandData, error) { k8sVersion := testCommandFlags.K8sVersion if k8sVersion == "" { @@ -294,12 +299,21 @@ func GenerateTestCommandData(testCommandFlags *TestCommandFlags, localConfigCont var policies *defaultPolicies.EvaluationPrerunPolicies var err error + var policyConfig string if testCommandFlags.PolicyConfig != "" { + policyConfig = testCommandFlags.PolicyConfig + } else if policyConfigEnv, ok := os.LookupEnv(DatreePolicyConfig); ok { + policyConfig = policyConfigEnv + } else if localConfigContent.PolicyConfig != "" { + policyConfig = localConfigContent.PolicyConfig + } + + if policyConfig != "" { if localConfigContent.Offline != "local" && !evaluationPrerunDataResp.IsPolicyAsCodeMode { - return nil, fmt.Errorf("to use --policy-config flag you must first enable policy-as-code mode: https://hub.datree.io/policy-as-code") + return nil, fmt.Errorf("to use custom policy-config you must first enable policy-as-code mode: https://hub.datree.io/policy-as-code") } - policies, err = policy.GetPoliciesFileFromPath(testCommandFlags.PolicyConfig) + policies, err = policy.GetPoliciesFileFromPath(policyConfig) if err != nil { return nil, err } @@ -317,6 +331,15 @@ func GenerateTestCommandData(testCommandFlags *TestCommandFlags, localConfigCont return nil, err } + var schemaLocations []string + if len(testCommandFlags.SchemaLocations) != 0 { + schemaLocations = testCommandFlags.SchemaLocations + } else if schemaLocationsEnv, ok := os.LookupEnv(DatreeSchemaLocations); ok { + schemaLocations = strings.Split(schemaLocationsEnv, ",") + } else if len(localConfigContent.SchemaLocations) != 0 { + schemaLocations = localConfigContent.SchemaLocations + } + testCommandOptions := &TestCommandData{Output: testCommandFlags.Output, K8sVersion: k8sVersion, IgnoreMissingSchemas: testCommandFlags.IgnoreMissingSchemas, @@ -324,7 +347,7 @@ func GenerateTestCommandData(testCommandFlags *TestCommandFlags, localConfigCont Verbose: testCommandFlags.Verbose, NoRecord: testCommandFlags.NoRecord, Policy: policy, - SchemaLocations: testCommandFlags.SchemaLocations, + SchemaLocations: schemaLocations, Token: localConfigContent.Token, ClientId: localConfigContent.ClientId, RegistrationURL: evaluationPrerunDataResp.RegistrationURL, diff --git a/pkg/localConfig/localConfig.go b/pkg/localConfig/localConfig.go index de1ed6146a..bf8bf32bf4 100644 --- a/pkg/localConfig/localConfig.go +++ b/pkg/localConfig/localConfig.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "github.com/datreeio/datree/pkg/networkValidator" @@ -13,10 +14,12 @@ import ( ) type LocalConfig struct { - Token string - ClientId string - SchemaVersion string - Offline string + Token string + ClientId string + SchemaVersion string + Offline string + PolicyConfig string + SchemaLocations []string } type TokenClient interface { @@ -36,10 +39,12 @@ func NewLocalConfigClient(t TokenClient, nv *networkValidator.NetworkValidator) } const ( - clientIdKey = "client_id" - tokenKey = "token" - schemaVersionKey = "schema_version" - offlineKey = "offline" + clientIdKey = "client_id" + tokenKey = "token" + schemaVersionKey = "schema_version" + offlineKey = "offline" + policyConfigKey = "policy_config" + schemaLocationsKey = "schema_locations" ) func (lc *LocalConfigClient) GetLocalConfiguration() (*LocalConfig, error) { @@ -55,6 +60,8 @@ func (lc *LocalConfigClient) GetLocalConfiguration() (*LocalConfig, error) { clientId := viper.GetString(clientIdKey) schemaVersion := viper.GetString(schemaVersionKey) offline := viper.GetString(offlineKey) + policyConfig := viper.GetString(policyConfigKey) + schemaLocations := viper.GetStringSlice(schemaLocationsKey) if offline == "" { offline = "fail" @@ -87,7 +94,7 @@ func (lc *LocalConfigClient) GetLocalConfiguration() (*LocalConfig, error) { } } - return &LocalConfig{Token: token, ClientId: clientId, SchemaVersion: schemaVersion, Offline: offline}, nil + return &LocalConfig{Token: token, ClientId: clientId, SchemaVersion: schemaVersion, Offline: offline, PolicyConfig: policyConfig, SchemaLocations: schemaLocations}, nil } func (lc *LocalConfigClient) Set(key string, value string) error { @@ -101,7 +108,15 @@ func (lc *LocalConfigClient) Set(key string, value string) error { return err } - viper.Set(key, value) + if key == policyConfigKey { + absPath, _ := filepath.Abs(value) + viper.Set(policyConfigKey, absPath) + } else if key == schemaLocationsKey { + viper.Set(schemaLocationsKey, strings.Split(value, ",")) + } else { + viper.Set(key, value) + } + writeClientIdErr := viper.WriteConfig() if writeClientIdErr != nil { return writeClientIdErr