Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Make config fields private #91

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions backend/repository/result_repository_json_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestGetResults(t *testing.T) {

t.Run("No filter", func(t *testing.T) {
mu := internal_test.NewMockUtils()
mu.On("ReadFile", config.GetInstance().ResultFilePath).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil)
mu.On("ReadFile", config.GetInstance().GetResultFilePath()).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil)

repo := repository.NewResultRepositoryJsonImpl(mu)
results, err := repo.GetResults(nil)
Expand All @@ -57,7 +57,7 @@ func TestGetResults(t *testing.T) {

t.Run("With filter", func(t *testing.T) {
mu := internal_test.NewMockUtils()
mu.On("ReadFile", config.GetInstance().ResultFilePath).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil)
mu.On("ReadFile", config.GetInstance().GetResultFilePath()).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil)

filter := mocks.MockResultFilter{}
filter.EXPECT().IsValid(mock.Anything).Return(true)
Expand All @@ -74,7 +74,7 @@ func TestGetResults(t *testing.T) {

t.Run("Read file error", func(t *testing.T) {
mu := internal_test.NewMockUtils()
mu.On("ReadFile", config.GetInstance().ResultFilePath).Return([]byte{}, entities.ErrReadingResultFile)
mu.On("ReadFile", config.GetInstance().GetResultFilePath()).Return([]byte{}, entities.ErrReadingResultFile)

repo := repository.NewResultRepositoryJsonImpl(mu)
results, err := repo.GetResults(nil)
Expand All @@ -86,7 +86,7 @@ func TestGetResults(t *testing.T) {

t.Run("Invalid json", func(t *testing.T) {
mu := internal_test.NewMockUtils()
mu.On("ReadFile", config.GetInstance().ResultFilePath).Return([]byte(`invalid json`), nil)
mu.On("ReadFile", config.GetInstance().GetResultFilePath()).Return([]byte(`invalid json`), nil)

repo := repository.NewResultRepositoryJsonImpl(mu)
results, err := repo.GetResults(nil)
Expand All @@ -97,7 +97,7 @@ func TestGetResults(t *testing.T) {

t.Run("Filter no match", func(t *testing.T) {
mu := internal_test.NewMockUtils()
mu.On("ReadFile", config.GetInstance().ResultFilePath).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil)
mu.On("ReadFile", config.GetInstance().GetResultFilePath()).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil)

filter := mocks.MockResultFilter{}
filter.EXPECT().IsValid(mock.Anything).Return(false)
Expand Down
2 changes: 1 addition & 1 deletion backend/service/result_service_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestGetResults(t *testing.T) {
defer cleanup()

mu := internal_test.NewMockUtils()
mu.On("ReadFile", config.GetInstance().ResultFilePath).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil)
mu.On("ReadFile", config.GetInstance().GetResultFilePath()).Return([]byte(`{"path/to/file": [{"ID": "file", "Purl": ["pkg:example/package"]}]}`), nil)

mockRepo := repoMocks.NewMockResultRepository(t)
resultMapper := mapperMocks.NewMockResultMapper(t)
Expand Down
4 changes: 2 additions & 2 deletions cmd/configure.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ var configureCmd = &cobra.Command{
}
}

log.Info().Msgf("API URL: %s", cfg.ApiUrl)
log.Info().Msgf("KEY: %s", strings.Repeat("*", len(cfg.ApiToken)))
log.Info().Msgf("API URL: %s", cfg.GetApiUrl())
log.Info().Msgf("KEY: %s", strings.Repeat("*", len(cfg.GetApiToken())))
log.Info().Msg("Configuration saved successfully!")
},
}
Expand Down
163 changes: 119 additions & 44 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
package config

import (
"encoding/json"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -53,14 +54,48 @@ const (
)

type Config struct {
apiToken string
apiUrl string
resultFilePath string
scanRoot string
scanSettingsFilePath string
debug bool
mu sync.RWMutex
listeners []func(*Config)
}

type ConfigDTO struct {
ApiToken string `json:"apitoken"`
ApiUrl string `json:"apiurl"`
ResultFilePath string `json:"resultfilepath,omitempty"`
ScanRoot string `json:"scanroot,omitempty"`
ScanSettingsFilePath string `json:"scansettingsfilepath,omitempty"`
Debug bool `json:"debug,omitempty"`
mu sync.RWMutex
listeners []func(*Config)
}

func (c *Config) MarshalJSON() ([]byte, error) {
return json.Marshal(ConfigDTO{
ApiToken: c.apiToken,
ApiUrl: c.apiUrl,
ResultFilePath: c.resultFilePath,
ScanRoot: c.scanRoot,
ScanSettingsFilePath: c.scanSettingsFilePath,
Debug: c.debug,
})
}

func (c *Config) UnmarshalJSON(data []byte) error {
var j ConfigDTO
if err := json.Unmarshal(data, &j); err != nil {
return err
}
c.apiToken = j.ApiToken
c.apiUrl = j.ApiUrl
c.resultFilePath = j.ResultFilePath
c.scanRoot = j.ScanRoot
c.scanSettingsFilePath = j.ScanSettingsFilePath
c.debug = j.Debug
return nil
}

var instance *Config
Expand Down Expand Up @@ -102,36 +137,42 @@ func (c *Config) getDefaultScanSettingsFilePath(scanRoot string) string {
func (c *Config) GetApiToken() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ApiToken
return c.apiToken
}

func (c *Config) GetApiUrl() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ApiUrl
return c.apiUrl
}

func (c *Config) GetResultFilePath() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ResultFilePath
return c.resultFilePath
}

func (c *Config) GetScanRoot() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ScanRoot
return c.scanRoot
}

func (c *Config) GetScanSettingsFilePath() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ScanSettingsFilePath
return c.scanSettingsFilePath
}

func (c *Config) GetDebug() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.debug
}

func (c *Config) SetApiToken(token string) error {
c.mu.Lock()
c.ApiToken = token
c.apiToken = token
viper.Set("apitoken", token)
c.mu.Unlock()
c.notifyListeners()
Expand All @@ -140,7 +181,7 @@ func (c *Config) SetApiToken(token string) error {

func (c *Config) SetApiUrl(url string) error {
c.mu.Lock()
c.ApiUrl = url
c.apiUrl = url
viper.Set("apiurl", url)
c.mu.Unlock()
c.notifyListeners()
Expand All @@ -149,23 +190,30 @@ func (c *Config) SetApiUrl(url string) error {

func (c *Config) SetResultFilePath(path string) {
c.mu.Lock()
c.ResultFilePath = path
c.resultFilePath = path
c.mu.Unlock()
c.notifyListeners()
}

func (c *Config) SetScanRoot(path string) {
c.mu.Lock()
c.ScanRoot = path
c.ResultFilePath = c.getDefaultResultFilePath(path)
c.ScanSettingsFilePath = c.getDefaultScanSettingsFilePath(path)
c.scanRoot = path
c.resultFilePath = c.getDefaultResultFilePath(path)
c.scanSettingsFilePath = c.getDefaultScanSettingsFilePath(path)
c.mu.Unlock()
c.notifyListeners()
}

func (c *Config) SetScanSettingsFilePath(path string) {
c.mu.Lock()
c.ScanSettingsFilePath = path
c.scanSettingsFilePath = path
c.mu.Unlock()
c.notifyListeners()
}

func (c *Config) SetDebug(debug bool) {
c.mu.Lock()
c.debug = debug
c.mu.Unlock()
c.notifyListeners()
}
Expand Down Expand Up @@ -206,11 +254,7 @@ func (c *Config) setupLogger(debug bool) error {
return nil
}

func (c *Config) InitializeConfig(cfgFile, scanRoot, apiKey, apiUrl, inputFile, scanossSettingsFilePath string, originalWorkDir string, debug bool) error {
if err := c.setupLogger(debug); err != nil {
return fmt.Errorf("error setting up logger: %w", err)
}

func (c *Config) initializeConfigFile(cfgFile string) error {
if cfgFile != "" {
absCfgFile, _ := filepath.Abs(cfgFile)
log.Debug().Msgf("Using config file: %s", absCfgFile)
Expand All @@ -219,36 +263,38 @@ func (c *Config) InitializeConfig(cfgFile, scanRoot, apiKey, apiUrl, inputFile,
if err := viper.ReadInConfig(); err != nil {
log.Fatal().Err(err).Msgf("Error reading config file %v", err.Error())
}
} else {
viper.SetConfigName(DEFAULT_CONFIG_FILE_NAME)
viper.SetConfigType(DEFAULT_CONFIG_FILE_TYPE)
viper.AddConfigPath(c.GetDefaultConfigFolder())
return nil
}

// Default values
viper.SetDefault("apiurl", DEFAULT_API_URL)
viper.SetDefault("apitoken", "")
viper.SetConfigName(DEFAULT_CONFIG_FILE_NAME)
viper.SetConfigType(DEFAULT_CONFIG_FILE_TYPE)
viper.AddConfigPath(c.GetDefaultConfigFolder())

if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
defaultConfigDir := c.GetDefaultConfigFolder()
if err := os.MkdirAll(defaultConfigDir, os.ModePerm); err != nil {
return fmt.Errorf("error creating config directory: %w", err)
}
if err := viper.SafeWriteConfig(); err != nil {
return fmt.Errorf("error creating config file: %w", err)
}
log.Debug().Msgf("Created default config file: %s", viper.ConfigFileUsed())
} else {
return fmt.Errorf("error reading config file: %w", err)
// Default values
viper.SetDefault("apiurl", DEFAULT_API_URL)
viper.SetDefault("apitoken", "")

if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
defaultConfigDir := c.GetDefaultConfigFolder()
if err := os.MkdirAll(defaultConfigDir, os.ModePerm); err != nil {
return fmt.Errorf("error creating config directory: %w", err)
}
if err := viper.SafeWriteConfig(); err != nil {
return fmt.Errorf("error creating config file: %w", err)
}
log.Debug().Msgf("Created default config file: %s", viper.ConfigFileUsed())
} else {
return fmt.Errorf("error reading config file: %w", err)
}
}
return nil
}

func (c *Config) initializeApiConfig(apiKey, apiUrl string) error {
c.SetApiToken(viper.GetString("apitoken"))
c.SetApiUrl(viper.GetString("apiurl"))

c.Debug = debug

// Override with command line flags
if apiKey != "" {
if err := c.SetApiToken(apiKey); err != nil {
Expand All @@ -260,6 +306,10 @@ func (c *Config) InitializeConfig(cfgFile, scanRoot, apiKey, apiUrl, inputFile,
return fmt.Errorf("error saving API URL: %w", err)
}
}
return nil
}

func (c *Config) initializePathConfig(scanRoot, inputFile, scanossSettingsFilePath, originalWorkDir string) error {
if scanRoot != "" {
c.SetScanRoot(scanRoot)
}
Expand All @@ -271,7 +321,7 @@ func (c *Config) InitializeConfig(cfgFile, scanRoot, apiKey, apiUrl, inputFile,
}

// Set default values if not set via config file or command line args
if c.ScanRoot == "" {
if c.GetScanRoot() == "" {
if originalWorkDir != "" {
c.SetScanRoot(originalWorkDir)
} else {
Expand All @@ -282,12 +332,37 @@ func (c *Config) InitializeConfig(cfgFile, scanRoot, apiKey, apiUrl, inputFile,
c.SetScanRoot(wd)
}
}
if c.ResultFilePath == "" {
c.SetResultFilePath(c.getDefaultResultFilePath(originalWorkDir))
if c.GetResultFilePath() == "" {
defaultPath := c.getDefaultResultFilePath(originalWorkDir)
if !filepath.IsAbs(defaultPath) {
defaultPath = filepath.Join(c.GetScanRoot(), defaultPath)
}
c.SetResultFilePath(defaultPath)
}
if c.ScanSettingsFilePath == "" {
if c.GetScanSettingsFilePath() == "" {
c.SetScanSettingsFilePath(c.getDefaultScanSettingsFilePath(originalWorkDir))
}
return nil
}

func (c *Config) InitializeConfig(cfgFile, scanRoot, apiKey, apiUrl, inputFile, scanossSettingsFilePath string, originalWorkDir string, debug bool) error {
if err := c.setupLogger(debug); err != nil {
return fmt.Errorf("error setting up logger: %w", err)
}

if err := c.initializeConfigFile(cfgFile); err != nil {
return err
}

if err := c.initializeApiConfig(apiKey, apiUrl); err != nil {
return err
}

c.SetDebug(debug)

if err := c.initializePathConfig(scanRoot, inputFile, scanossSettingsFilePath, originalWorkDir); err != nil {
return err
}

return nil
}
2 changes: 1 addition & 1 deletion internal/testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func InitializeTestEnvironment(t *testing.T) func() {
InitValidatorForTests()

cfg := config.GetInstance()
cfg.ScanRoot = t.TempDir()
cfg.SetScanRoot(t.TempDir())

return func() {
cfg = nil
Expand Down
Loading