From 9c0d61d888d13b9da82a08123d78d477d9ffa10d Mon Sep 17 00:00:00 2001 From: Matias Daloia Date: Tue, 11 Feb 2025 11:07:06 +0100 Subject: [PATCH] fix: Make config fields private fix: Update tests fix: Update tests + couple code enhancements --- .../result_repository_json_impl_test.go | 10 +- backend/service/result_service_impl_test.go | 2 +- cmd/configure.go | 4 +- internal/config/config.go | 163 +++++++++++++----- internal/testutils.go | 2 +- 5 files changed, 128 insertions(+), 53 deletions(-) diff --git a/backend/repository/result_repository_json_impl_test.go b/backend/repository/result_repository_json_impl_test.go index 6d8e3aa..5a01641 100644 --- a/backend/repository/result_repository_json_impl_test.go +++ b/backend/repository/result_repository_json_impl_test.go @@ -43,7 +43,7 @@ func TestGetResults(t *testing.T) { t.Run("No filter", func(t *testing.T) { mu := internal_test.NewMockUtils() - mu.On("ReadFile", config.GetInstance().ResultFilePath).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil) + mu.On("ReadFile", config.GetInstance().GetResultFilePath()).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil) repo := repository.NewResultRepositoryJsonImpl(mu) results, err := repo.GetResults(nil) @@ -57,7 +57,7 @@ func TestGetResults(t *testing.T) { t.Run("With filter", func(t *testing.T) { mu := internal_test.NewMockUtils() - mu.On("ReadFile", config.GetInstance().ResultFilePath).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil) + mu.On("ReadFile", config.GetInstance().GetResultFilePath()).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil) filter := mocks.MockResultFilter{} filter.EXPECT().IsValid(mock.Anything).Return(true) @@ -74,7 +74,7 @@ func TestGetResults(t *testing.T) { t.Run("Read file error", func(t *testing.T) { mu := internal_test.NewMockUtils() - mu.On("ReadFile", config.GetInstance().ResultFilePath).Return([]byte{}, entities.ErrReadingResultFile) + mu.On("ReadFile", config.GetInstance().GetResultFilePath()).Return([]byte{}, entities.ErrReadingResultFile) repo := repository.NewResultRepositoryJsonImpl(mu) results, err := repo.GetResults(nil) @@ -86,7 +86,7 @@ func TestGetResults(t *testing.T) { t.Run("Invalid json", func(t *testing.T) { mu := internal_test.NewMockUtils() - mu.On("ReadFile", config.GetInstance().ResultFilePath).Return([]byte(`invalid json`), nil) + mu.On("ReadFile", config.GetInstance().GetResultFilePath()).Return([]byte(`invalid json`), nil) repo := repository.NewResultRepositoryJsonImpl(mu) results, err := repo.GetResults(nil) @@ -97,7 +97,7 @@ func TestGetResults(t *testing.T) { t.Run("Filter no match", func(t *testing.T) { mu := internal_test.NewMockUtils() - mu.On("ReadFile", config.GetInstance().ResultFilePath).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil) + mu.On("ReadFile", config.GetInstance().GetResultFilePath()).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil) filter := mocks.MockResultFilter{} filter.EXPECT().IsValid(mock.Anything).Return(false) diff --git a/backend/service/result_service_impl_test.go b/backend/service/result_service_impl_test.go index 5dba86d..a2f98c0 100644 --- a/backend/service/result_service_impl_test.go +++ b/backend/service/result_service_impl_test.go @@ -43,7 +43,7 @@ func TestGetResults(t *testing.T) { defer cleanup() mu := internal_test.NewMockUtils() - mu.On("ReadFile", config.GetInstance().ResultFilePath).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil) + mu.On("ReadFile", config.GetInstance().GetResultFilePath()).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil) mockRepo := repoMocks.NewMockResultRepository(t) resultMapper := mapperMocks.NewMockResultMapper(t) diff --git a/cmd/configure.go b/cmd/configure.go index 3d96f6e..50ea619 100644 --- a/cmd/configure.go +++ b/cmd/configure.go @@ -59,8 +59,8 @@ var configureCmd = &cobra.Command{ } } - log.Info().Msgf("API URL: %s", cfg.ApiUrl) - log.Info().Msgf("KEY: %s", strings.Repeat("*", len(cfg.ApiToken))) + log.Info().Msgf("API URL: %s", cfg.GetApiUrl()) + log.Info().Msgf("KEY: %s", strings.Repeat("*", len(cfg.GetApiToken()))) log.Info().Msg("Configuration saved successfully!") }, } diff --git a/internal/config/config.go b/internal/config/config.go index 3980830..5c9ca06 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -24,6 +24,7 @@ package config import ( + "encoding/json" "errors" "fmt" "os" @@ -53,14 +54,48 @@ const ( ) type Config struct { + apiToken string + apiUrl string + resultFilePath string + scanRoot string + scanSettingsFilePath string + debug bool + mu sync.RWMutex + listeners []func(*Config) +} + +type ConfigDTO struct { ApiToken string `json:"apitoken"` ApiUrl string `json:"apiurl"` ResultFilePath string `json:"resultfilepath,omitempty"` ScanRoot string `json:"scanroot,omitempty"` ScanSettingsFilePath string `json:"scansettingsfilepath,omitempty"` Debug bool `json:"debug,omitempty"` - mu sync.RWMutex - listeners []func(*Config) +} + +func (c *Config) MarshalJSON() ([]byte, error) { + return json.Marshal(ConfigDTO{ + ApiToken: c.apiToken, + ApiUrl: c.apiUrl, + ResultFilePath: c.resultFilePath, + ScanRoot: c.scanRoot, + ScanSettingsFilePath: c.scanSettingsFilePath, + Debug: c.debug, + }) +} + +func (c *Config) UnmarshalJSON(data []byte) error { + var j ConfigDTO + if err := json.Unmarshal(data, &j); err != nil { + return err + } + c.apiToken = j.ApiToken + c.apiUrl = j.ApiUrl + c.resultFilePath = j.ResultFilePath + c.scanRoot = j.ScanRoot + c.scanSettingsFilePath = j.ScanSettingsFilePath + c.debug = j.Debug + return nil } var instance *Config @@ -102,36 +137,42 @@ func (c *Config) getDefaultScanSettingsFilePath(scanRoot string) string { func (c *Config) GetApiToken() string { c.mu.RLock() defer c.mu.RUnlock() - return c.ApiToken + return c.apiToken } func (c *Config) GetApiUrl() string { c.mu.RLock() defer c.mu.RUnlock() - return c.ApiUrl + return c.apiUrl } func (c *Config) GetResultFilePath() string { c.mu.RLock() defer c.mu.RUnlock() - return c.ResultFilePath + return c.resultFilePath } func (c *Config) GetScanRoot() string { c.mu.RLock() defer c.mu.RUnlock() - return c.ScanRoot + return c.scanRoot } func (c *Config) GetScanSettingsFilePath() string { c.mu.RLock() defer c.mu.RUnlock() - return c.ScanSettingsFilePath + return c.scanSettingsFilePath +} + +func (c *Config) GetDebug() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.debug } func (c *Config) SetApiToken(token string) error { c.mu.Lock() - c.ApiToken = token + c.apiToken = token viper.Set("apitoken", token) c.mu.Unlock() c.notifyListeners() @@ -140,7 +181,7 @@ func (c *Config) SetApiToken(token string) error { func (c *Config) SetApiUrl(url string) error { c.mu.Lock() - c.ApiUrl = url + c.apiUrl = url viper.Set("apiurl", url) c.mu.Unlock() c.notifyListeners() @@ -149,23 +190,30 @@ func (c *Config) SetApiUrl(url string) error { func (c *Config) SetResultFilePath(path string) { c.mu.Lock() - c.ResultFilePath = path + c.resultFilePath = path c.mu.Unlock() c.notifyListeners() } func (c *Config) SetScanRoot(path string) { c.mu.Lock() - c.ScanRoot = path - c.ResultFilePath = c.getDefaultResultFilePath(path) - c.ScanSettingsFilePath = c.getDefaultScanSettingsFilePath(path) + c.scanRoot = path + c.resultFilePath = c.getDefaultResultFilePath(path) + c.scanSettingsFilePath = c.getDefaultScanSettingsFilePath(path) c.mu.Unlock() c.notifyListeners() } func (c *Config) SetScanSettingsFilePath(path string) { c.mu.Lock() - c.ScanSettingsFilePath = path + c.scanSettingsFilePath = path + c.mu.Unlock() + c.notifyListeners() +} + +func (c *Config) SetDebug(debug bool) { + c.mu.Lock() + c.debug = debug c.mu.Unlock() c.notifyListeners() } @@ -206,11 +254,7 @@ func (c *Config) setupLogger(debug bool) error { return nil } -func (c *Config) InitializeConfig(cfgFile, scanRoot, apiKey, apiUrl, inputFile, scanossSettingsFilePath string, originalWorkDir string, debug bool) error { - if err := c.setupLogger(debug); err != nil { - return fmt.Errorf("error setting up logger: %w", err) - } - +func (c *Config) initializeConfigFile(cfgFile string) error { if cfgFile != "" { absCfgFile, _ := filepath.Abs(cfgFile) log.Debug().Msgf("Using config file: %s", absCfgFile) @@ -219,36 +263,38 @@ func (c *Config) InitializeConfig(cfgFile, scanRoot, apiKey, apiUrl, inputFile, if err := viper.ReadInConfig(); err != nil { log.Fatal().Err(err).Msgf("Error reading config file %v", err.Error()) } - } else { - viper.SetConfigName(DEFAULT_CONFIG_FILE_NAME) - viper.SetConfigType(DEFAULT_CONFIG_FILE_TYPE) - viper.AddConfigPath(c.GetDefaultConfigFolder()) + return nil + } - // Default values - viper.SetDefault("apiurl", DEFAULT_API_URL) - viper.SetDefault("apitoken", "") + viper.SetConfigName(DEFAULT_CONFIG_FILE_NAME) + viper.SetConfigType(DEFAULT_CONFIG_FILE_TYPE) + viper.AddConfigPath(c.GetDefaultConfigFolder()) - if err := viper.ReadInConfig(); err != nil { - if _, ok := err.(viper.ConfigFileNotFoundError); ok { - defaultConfigDir := c.GetDefaultConfigFolder() - if err := os.MkdirAll(defaultConfigDir, os.ModePerm); err != nil { - return fmt.Errorf("error creating config directory: %w", err) - } - if err := viper.SafeWriteConfig(); err != nil { - return fmt.Errorf("error creating config file: %w", err) - } - log.Debug().Msgf("Created default config file: %s", viper.ConfigFileUsed()) - } else { - return fmt.Errorf("error reading config file: %w", err) + // Default values + viper.SetDefault("apiurl", DEFAULT_API_URL) + viper.SetDefault("apitoken", "") + + if err := viper.ReadInConfig(); err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); ok { + defaultConfigDir := c.GetDefaultConfigFolder() + if err := os.MkdirAll(defaultConfigDir, os.ModePerm); err != nil { + return fmt.Errorf("error creating config directory: %w", err) } + if err := viper.SafeWriteConfig(); err != nil { + return fmt.Errorf("error creating config file: %w", err) + } + log.Debug().Msgf("Created default config file: %s", viper.ConfigFileUsed()) + } else { + return fmt.Errorf("error reading config file: %w", err) } } + return nil +} +func (c *Config) initializeApiConfig(apiKey, apiUrl string) error { c.SetApiToken(viper.GetString("apitoken")) c.SetApiUrl(viper.GetString("apiurl")) - c.Debug = debug - // Override with command line flags if apiKey != "" { if err := c.SetApiToken(apiKey); err != nil { @@ -260,6 +306,10 @@ func (c *Config) InitializeConfig(cfgFile, scanRoot, apiKey, apiUrl, inputFile, return fmt.Errorf("error saving API URL: %w", err) } } + return nil +} + +func (c *Config) initializePathConfig(scanRoot, inputFile, scanossSettingsFilePath, originalWorkDir string) error { if scanRoot != "" { c.SetScanRoot(scanRoot) } @@ -271,7 +321,7 @@ func (c *Config) InitializeConfig(cfgFile, scanRoot, apiKey, apiUrl, inputFile, } // Set default values if not set via config file or command line args - if c.ScanRoot == "" { + if c.GetScanRoot() == "" { if originalWorkDir != "" { c.SetScanRoot(originalWorkDir) } else { @@ -282,12 +332,37 @@ func (c *Config) InitializeConfig(cfgFile, scanRoot, apiKey, apiUrl, inputFile, c.SetScanRoot(wd) } } - if c.ResultFilePath == "" { - c.SetResultFilePath(c.getDefaultResultFilePath(originalWorkDir)) + if c.GetResultFilePath() == "" { + defaultPath := c.getDefaultResultFilePath(originalWorkDir) + if !filepath.IsAbs(defaultPath) { + defaultPath = filepath.Join(c.GetScanRoot(), defaultPath) + } + c.SetResultFilePath(defaultPath) } - if c.ScanSettingsFilePath == "" { + if c.GetScanSettingsFilePath() == "" { c.SetScanSettingsFilePath(c.getDefaultScanSettingsFilePath(originalWorkDir)) } + return nil +} + +func (c *Config) InitializeConfig(cfgFile, scanRoot, apiKey, apiUrl, inputFile, scanossSettingsFilePath string, originalWorkDir string, debug bool) error { + if err := c.setupLogger(debug); err != nil { + return fmt.Errorf("error setting up logger: %w", err) + } + + if err := c.initializeConfigFile(cfgFile); err != nil { + return err + } + + if err := c.initializeApiConfig(apiKey, apiUrl); err != nil { + return err + } + + c.SetDebug(debug) + + if err := c.initializePathConfig(scanRoot, inputFile, scanossSettingsFilePath, originalWorkDir); err != nil { + return err + } return nil } diff --git a/internal/testutils.go b/internal/testutils.go index e52df32..ce82586 100644 --- a/internal/testutils.go +++ b/internal/testutils.go @@ -44,7 +44,7 @@ func InitializeTestEnvironment(t *testing.T) func() { InitValidatorForTests() cfg := config.GetInstance() - cfg.ScanRoot = t.TempDir() + cfg.SetScanRoot(t.TempDir()) return func() { cfg = nil