From b5632573818e4cfe361c78925a2ea66f9a4a370f Mon Sep 17 00:00:00 2001 From: Joseph Kavanagh Date: Sat, 6 Apr 2024 03:53:34 +0100 Subject: [PATCH] test(db): fix races with logfrom --- cmd/argus/main.go | 3 +- config/help_test.go | 3 +- config/ordering_test.go | 2 +- config/settings.go | 6 ++-- config/settings_test.go | 12 +++---- db/handlers.go | 18 +++++----- db/handlers_test.go | 9 +++-- db/help_test.go | 8 +++-- db/init.go | 55 +++++++++++++++---------------- db/init_test.go | 14 ++++---- db/types.go | 6 ++++ web/api/types/argus.go | 4 +-- web/api/v1/http-api-flags.go | 2 +- web/api/v1/http-api-flags_test.go | 2 +- web/help_test.go | 14 +++----- 15 files changed, 82 insertions(+), 76 deletions(-) diff --git a/cmd/argus/main.go b/cmd/argus/main.go index 87dfd7fc..b4a8fa7a 100644 --- a/cmd/argus/main.go +++ b/cmd/argus/main.go @@ -77,7 +77,8 @@ func main() { } } - go db.Run(&config, &jLog) + db.LogInit(&jLog, config.Settings.DataDatabaseFile()) + go db.Run(&config) // Track all targets for changes in version and act on any found changes. go (&config).Service.Track(&config.Order, &config.OrderMutex) diff --git a/config/help_test.go b/config/help_test.go index 6774ac7b..5d1c4c18 100644 --- a/config/help_test.go +++ b/config/help_test.go @@ -106,7 +106,7 @@ func testLoad(file string, t *testing.T) (config *Config) { loadMutex.Lock() defer loadMutex.Unlock() config.Load(file, &flags, log) - t.Cleanup(func() { os.Remove(*config.Settings.DataDatabaseFile()) }) + t.Cleanup(func() { os.Remove(config.Settings.DataDatabaseFile()) }) return } @@ -143,6 +143,7 @@ func testLoadBasic(file string, t *testing.T) (config *Config) { service.ID = name } config.CheckValues() + t.Log("Loaded", file) return } diff --git a/config/ordering_test.go b/config/ordering_test.go index ab0af2d0..44cb4ac4 100644 --- a/config/ordering_test.go +++ b/config/ordering_test.go @@ -53,7 +53,7 @@ func TestConfig_LoadOrdering(t *testing.T) { lock.Lock() config.Load(file, &flags, log) lock.Unlock() - defer os.Remove(*config.Settings.DataDatabaseFile()) + defer os.Remove(config.Settings.DataDatabaseFile()) // THEN it gets the ordering correctly gotOrder := config.Order diff --git a/config/settings.go b/config/settings.go index 1bc5c185..8adb646a 100644 --- a/config/settings.go +++ b/config/settings.go @@ -274,11 +274,11 @@ func (s *Settings) LogLevel() string { } // DataDatabaseFile. -func (s *Settings) DataDatabaseFile() *string { - return util.FirstNonNilPtr( +func (s *Settings) DataDatabaseFile() string { + return util.DefaultIfNil(util.FirstNonNilPtr( s.FromFlags.Data.DatabaseFile, s.Data.DatabaseFile, - s.HardDefaults.Data.DatabaseFile) + s.HardDefaults.Data.DatabaseFile)) } // WebListenHost. diff --git a/config/settings_test.go b/config/settings_test.go index 4e50510a..c44a2030 100644 --- a/config/settings_test.go +++ b/config/settings_test.go @@ -197,17 +197,17 @@ func TestSettings_GetString(t *testing.T) { want: "ERROR", }, "data.database-file hard default": { - getFuncPtr: settings.DataDatabaseFile, - flag: &DataDatabaseFile, want: "data/argus.db", + getFunc: settings.DataDatabaseFile, + flag: &DataDatabaseFile, want: "data/argus.db", nilConfig: true, configPtr: &settings.Data.DatabaseFile, }, "data.database-file config": { - getFuncPtr: settings.DataDatabaseFile, - flag: &DataDatabaseFile, want: "somewhere.db", + getFunc: settings.DataDatabaseFile, + flag: &DataDatabaseFile, want: "somewhere.db", }, "data.database-file flag": { - getFuncPtr: settings.DataDatabaseFile, - flag: &DataDatabaseFile, flagVal: stringPtr("ERROR"), + getFunc: settings.DataDatabaseFile, + flag: &DataDatabaseFile, flagVal: stringPtr("ERROR"), want: "ERROR", }, "web.listen-host hard default": { diff --git a/db/handlers.go b/db/handlers.go index 9fe5a651..4f3b9004 100644 --- a/db/handlers.go +++ b/db/handlers.go @@ -66,14 +66,14 @@ func (api *api) updateRow(serviceID string, cells []dbtype.Cell) { if jLog.IsLevel("DEBUG") { jLog.Debug( fmt.Sprintf("%s, %v", sqlStmt, params), - *logFrom, true) + logFrom, true) } res, err := api.db.Exec(sqlStmt, params...) // Query failed if err != nil { jLog.Error( fmt.Sprintf("updateRow UPDATE: %q %v, %s", sqlStmt, params, util.ErrorToString(err)), - *logFrom, true) + logFrom, true) return } @@ -105,12 +105,13 @@ func (api *api) updateRow(serviceID string, cells []dbtype.Cell) { if jLog.IsLevel("DEBUG") { jLog.Debug( fmt.Sprintf("%s, %v", sqlStmt, params), - *logFrom, true) + logFrom, true) } _, err = api.db.Exec(sqlStmt, params...) jLog.Error( - fmt.Sprintf("updateRow INSERT: %q %v, %s", sqlStmt, params, util.ErrorToString(err)), - *logFrom, + fmt.Sprintf("updateRow INSERT: %q %v, %s", + sqlStmt, params, util.ErrorToString(err)), + logFrom, err != nil) } } @@ -123,11 +124,12 @@ func (api *api) deleteRow(serviceID string) { if jLog.IsLevel("DEBUG") { jLog.Debug( fmt.Sprintf("%s, %v", sqlStmt, serviceID), - *logFrom, true) + logFrom, true) } _, err := api.db.Exec(sqlStmt, serviceID) jLog.Error( - fmt.Sprintf("deleteRow: %q with %q, %s", sqlStmt, serviceID, util.ErrorToString(err)), - *logFrom, + fmt.Sprintf("deleteRow: %q with %q, %s", + sqlStmt, serviceID, util.ErrorToString(err)), + logFrom, err != nil) } diff --git a/db/handlers_test.go b/db/handlers_test.go index 80aa18ce..13576346 100644 --- a/db/handlers_test.go +++ b/db/handlers_test.go @@ -75,7 +75,7 @@ func TestAPI_UpdateRow(t *testing.T) { cfg := testConfig() testAPI := api{config: cfg} *testAPI.config.Settings.Data.DatabaseFile = fmt.Sprintf("%s.db", strings.ReplaceAll(name, " ", "_")) - defer os.Remove(*testAPI.config.Settings.Data.DatabaseFile) + defer os.Remove(*cfg.Settings.Data.DatabaseFile) testAPI.initialise() // WHEN updateRow is called targeting single/multiple cells @@ -134,8 +134,11 @@ func TestAPI_DeleteRow(t *testing.T) { // Ensure the row exists if tc.exists if tc.exists { - testAPI.updateRow(tc.serviceID, []dbtype.Cell{ - {Column: "latest_version", Value: "9.9.9"}, {Column: "deployed_version", Value: "8.8.8"}}) + testAPI.updateRow( + tc.serviceID, + []dbtype.Cell{ + {Column: "latest_version", Value: "9.9.9"}, {Column: "deployed_version", Value: "8.8.8"}}, + ) time.Sleep(100 * time.Millisecond) } // Check the row existance before the test diff --git a/db/help_test.go b/db/help_test.go index 7da243f8..64f89da5 100644 --- a/db/help_test.go +++ b/db/help_test.go @@ -33,13 +33,15 @@ var cfg *config.Config func TestMain(m *testing.M) { log := util.NewJLog("DEBUG", false) - logFrom = &util.LogFrom{} log.Testing = true + databaseFile := "TestRun.db" + LogInit(log, databaseFile) cfg = testConfig() - *cfg.Settings.Data.DatabaseFile = "TestRun.db" + *cfg.Settings.Data.DatabaseFile = databaseFile defer os.Remove(*cfg.Settings.Data.DatabaseFile) - go Run(cfg, log) + go Run(cfg) + time.Sleep(250 * time.Millisecond) // Time for db to start os.Exit(m.Run()) } diff --git a/db/init.go b/db/init.go index e5a3435b..b7ecda8c 100644 --- a/db/init.go +++ b/db/init.go @@ -25,10 +25,11 @@ import ( "github.com/release-argus/Argus/util" ) -var ( - jLog *util.JLog - logFrom *util.LogFrom -) +// LogInit for this package. +func LogInit(log *util.JLog, databaseFile string) { + jLog = log + logFrom = util.LogFrom{Primary: "db", Secondary: databaseFile} +} func checkFile(path string) { file := filepath.Base(path) @@ -40,34 +41,30 @@ func checkFile(path string) { // create the dir if os.IsNotExist(err) { err = os.MkdirAll(dir, 0755) - jLog.Fatal(util.ErrorToString(err), *logFrom, err != nil) + jLog.Fatal(util.ErrorToString(err), logFrom, err != nil) } else { // other error - jLog.Fatal(util.ErrorToString(err), *logFrom, true) + jLog.Fatal(util.ErrorToString(err), logFrom, true) } // directory exists but is not a directory } else if fileInfo == nil || !fileInfo.IsDir() { - jLog.Fatal(fmt.Sprintf("path %q (for %q) is not a directory", dir, file), *logFrom, true) + jLog.Fatal(fmt.Sprintf("path %q (for %q) is not a directory", dir, file), logFrom, true) } // Check that the file exists fileInfo, err = os.Stat(path) if err != nil { // file doesn't exist - jLog.Fatal(util.ErrorToString(err), *logFrom, os.IsExist(err)) + jLog.Fatal(util.ErrorToString(err), logFrom, os.IsExist(err)) // item exists but is a directory } else if fileInfo != nil && fileInfo.IsDir() { - jLog.Fatal(fmt.Sprintf("path %q (for %q) is a directory, not a file", path, file), *logFrom, true) + jLog.Fatal(fmt.Sprintf("path %q (for %q) is a directory, not a file", path, file), logFrom, true) } } -func Run(cfg *config.Config, log *util.JLog) { - jLog = log - databaseFile := cfg.Settings.DataDatabaseFile() - logFrom = &util.LogFrom{Primary: "db", Secondary: *databaseFile} - +func Run(cfg *config.Config) { api := api{config: cfg} api.initialise() defer api.db.Close() @@ -81,9 +78,9 @@ func Run(cfg *config.Config, log *util.JLog) { func (api *api) initialise() { databaseFile := api.config.Settings.DataDatabaseFile() - checkFile(*databaseFile) - db, err := sql.Open("sqlite", *databaseFile) - jLog.Fatal(err, *logFrom, err != nil) + checkFile(databaseFile) + db, err := sql.Open("sqlite", databaseFile) + jLog.Fatal(err, logFrom, err != nil) // Create the table sqlStmt := ` @@ -96,7 +93,7 @@ func (api *api) initialise() { approved_version TEXT DEFAULT '' );` _, err = db.Exec(sqlStmt) - jLog.Fatal(util.ErrorToString(err), *logFrom, err != nil) + jLog.Fatal(util.ErrorToString(err), logFrom, err != nil) updateTable(db) @@ -123,7 +120,7 @@ func (api *api) removeUnknownServices() { _, err := api.db.Exec(sqlStmt, params...) jLog.Fatal( fmt.Sprintf("removeUnknownServices: %s", util.ErrorToString(err)), - *logFrom, + logFrom, err != nil) } @@ -139,7 +136,7 @@ func (api *api) extractServiceStatus() { deployed_version_timestamp, approved_version FROM status;`) - jLog.Fatal(err, *logFrom, err != nil) + jLog.Fatal(err, logFrom, err != nil) defer rows.Close() api.config.OrderMutex.RLock() @@ -156,7 +153,7 @@ func (api *api) extractServiceStatus() { err = rows.Scan(&id, &lv, &lvt, &dv, &dvt, &av) jLog.Fatal( fmt.Sprintf("extractServiceStatus row: %s", util.ErrorToString(err)), - *logFrom, + logFrom, err != nil) api.config.Service[id].Status.SetLatestVersion(lv, false) api.config.Service[id].Status.SetLatestVersionTimestamp(lvt) @@ -167,7 +164,7 @@ func (api *api) extractServiceStatus() { err = rows.Err() jLog.Fatal( fmt.Sprintf("extractServiceStatus: %s", util.ErrorToString(err)), - *logFrom, + logFrom, err != nil) } @@ -176,12 +173,12 @@ func updateTable(db *sql.DB) { // Get the type of the *_version columns var columnType string err := db.QueryRow("SELECT type FROM pragma_table_info('status') WHERE name = 'latest_version'").Scan(&columnType) - jLog.Fatal(fmt.Sprintf("updateTable: %s", util.ErrorToString(err)), *logFrom, err != nil) + jLog.Fatal(fmt.Sprintf("updateTable: %s", util.ErrorToString(err)), logFrom, err != nil) // Update if the column type is not TEXT if columnType != "TEXT" { - jLog.Verbose("Updating column types", *logFrom, true) + jLog.Verbose("Updating column types", logFrom, true) updateColumnTypes(db) - jLog.Verbose("Finished updating column types", *logFrom, true) + jLog.Verbose("Finished updating column types", logFrom, true) } } @@ -198,17 +195,17 @@ func updateColumnTypes(db *sql.DB) { approved_version TEXT DEFAULT '' );` _, err := db.Exec(sqlStmt) - jLog.Fatal(fmt.Sprintf("updateColumnTypes - create: %s", util.ErrorToString(err)), *logFrom, err != nil) + jLog.Fatal(fmt.Sprintf("updateColumnTypes - create: %s", util.ErrorToString(err)), logFrom, err != nil) // Copy the data from the old table to the new table _, err = db.Exec(`INSERT INTO status_backup SELECT * FROM status;`) - jLog.Fatal(fmt.Sprintf("updateColumnTypes - copy: %s", util.ErrorToString(err)), *logFrom, err != nil) + jLog.Fatal(fmt.Sprintf("updateColumnTypes - copy: %s", util.ErrorToString(err)), logFrom, err != nil) // Drop the table _, err = db.Exec("DROP TABLE status;") - jLog.Fatal(fmt.Sprintf("updateColumnTypes - drop: %s", util.ErrorToString(err)), *logFrom, err != nil) + jLog.Fatal(fmt.Sprintf("updateColumnTypes - drop: %s", util.ErrorToString(err)), logFrom, err != nil) // Rename the new table to the old table _, err = db.Exec("ALTER TABLE status_backup RENAME TO status;") - jLog.Fatal(fmt.Sprintf("updateColumnTypes - rename: %s", util.ErrorToString(err)), *logFrom, err != nil) + jLog.Fatal(fmt.Sprintf("updateColumnTypes - rename: %s", util.ErrorToString(err)), logFrom, err != nil) } diff --git a/db/init_test.go b/db/init_test.go index db8b610c..cb7b66da 100644 --- a/db/init_test.go +++ b/db/init_test.go @@ -169,7 +169,8 @@ func TestDBQueryService(t *testing.T) { {Column: "latest_version_timestamp", Value: (*svc).Status.LatestVersionTimestamp()}, {Column: "deployed_version", Value: (*svc).Status.DeployedVersion()}, {Column: "deployed_version_timestamp", Value: (*svc).Status.DeployedVersionTimestamp()}, - {Column: "approved_version", Value: (*svc).Status.ApprovedVersion()}}) + {Column: "approved_version", Value: (*svc).Status.ApprovedVersion()}}, + ) // THEN that data can be queried got := queryRow(t, testAPI.db, serviceName) @@ -322,10 +323,8 @@ func TestAPI_extractServiceStatus(t *testing.T) { *cfg.Settings.Data.DatabaseFile = "TestAPI_extractServiceStatus.db" defer os.Remove(*cfg.Settings.Data.DatabaseFile) testAPI := api{config: cfg} - go func() { - testAPI.initialise() - testAPI.handler() - }() + testAPI.initialise() + go testAPI.handler() wantStatus := make([]svcstatus.Status, len(cfg.Service)) // push a random Status for each Service to the DB index := 0 @@ -401,8 +400,9 @@ func Test_UpdateColumnTypes(t *testing.T) { r, w, _ := os.Pipe() os.Stdout = w - db, err := sql.Open("sqlite", "Test_UpdateColumnTypes.db") - defer os.Remove("Test_UpdateColumnTypes.db") + databaseFile := "Test_UpdateColumnTypes.db" + db, err := sql.Open("sqlite", databaseFile) + defer os.Remove(databaseFile) if err != nil { t.Fatal(err) } diff --git a/db/types.go b/db/types.go index 9f2a5061..4b166ab7 100644 --- a/db/types.go +++ b/db/types.go @@ -18,9 +18,15 @@ import ( "database/sql" "github.com/release-argus/Argus/config" + "github.com/release-argus/Argus/util" ) type api struct { config *config.Config db *sql.DB } + +var ( + jLog *util.JLog + logFrom util.LogFrom +) diff --git a/web/api/types/argus.go b/web/api/types/argus.go index c027ec19..c820c84f 100644 --- a/web/api/types/argus.go +++ b/web/api/types/argus.go @@ -173,10 +173,10 @@ type RuntimeInfo struct { // Flags is the runtime flags. type Flags struct { - ConfigFile *string `json:"config.file,omitempty" yaml:"config.file,omitempty"` + ConfigFile string `json:"config.file,omitempty" yaml:"config.file,omitempty"` LogLevel string `json:"log.level,omitempty" yaml:"log.level,omitempty"` LogTimestamps *bool `json:"log.timestamps,omitempty" yaml:"log.timestamps,omitempty"` - DataDatabaseFile *string `json:"data.database-file,omitempty" yaml:"data.database-file,omitempty"` + DataDatabaseFile string `json:"data.database-file,omitempty" yaml:"data.database-file,omitempty"` WebListenHost string `json:"web.listen-host,omitempty" yaml:"web.listen-host,omitempty"` WebListenPort string `json:"web.listen-port,omitempty" yaml:"web.listen-port,omitempty"` WebCertFile *string `json:"web.cert-file" yaml:"web.cert-file"` diff --git a/web/api/v1/http-api-flags.go b/web/api/v1/http-api-flags.go index 7fb4a99f..7488b39e 100644 --- a/web/api/v1/http-api-flags.go +++ b/web/api/v1/http-api-flags.go @@ -29,7 +29,7 @@ func (api *API) httpFlags(w http.ResponseWriter, r *http.Request) { // Create and send status page data msg := api_type.Flags{ - ConfigFile: &api.Config.File, + ConfigFile: api.Config.File, LogLevel: api.Config.Settings.LogLevel(), LogTimestamps: api.Config.Settings.LogTimestamps(), DataDatabaseFile: api.Config.Settings.DataDatabaseFile(), diff --git a/web/api/v1/http-api-flags_test.go b/web/api/v1/http-api-flags_test.go index 48905a95..0c3193f0 100644 --- a/web/api/v1/http-api-flags_test.go +++ b/web/api/v1/http-api-flags_test.go @@ -44,7 +44,7 @@ func TestHTTP_httpFlags(t *testing.T) { "config.file":"` + file + `", "log.level":"` + fmt.Sprintf(api.Config.Settings.LogLevel()) + `", "log.timestamps":` + fmt.Sprint(*api.Config.Settings.LogTimestamps()) + `, - "data.database-file":"` + *api.Config.Settings.DataDatabaseFile() + `", + "data.database-file":"` + api.Config.Settings.DataDatabaseFile() + `", "web.listen-host":"` + api.Config.Settings.WebListenHost() + `", "web.listen-port":"[0-9]{1,5}", "web.cert-file":null, diff --git a/web/help_test.go b/web/help_test.go index fb871647..20d1596b 100644 --- a/web/help_test.go +++ b/web/help_test.go @@ -23,7 +23,6 @@ import ( "crypto/x509/pkix" "encoding/pem" "fmt" - "io/ioutil" "math/big" "net" "net/http" @@ -100,7 +99,7 @@ func testConfig(path string, t *testing.T) (cfg *config.Config) { &map[string]bool{}, jLog) if t != nil { - t.Cleanup(func() { os.Remove(*cfg.Settings.DataDatabaseFile()) }) + t.Cleanup(func() { os.Remove(cfg.Settings.DataDatabaseFile()) }) } cfg.Settings.NilUndefinedFlags(&map[string]bool{}) @@ -172,9 +171,6 @@ func getFreePort() (int, error) { return 0, err } ln.Close() - if err != nil { - return 0, err - } return ln.Addr().(*net.TCPAddr).Port, nil } @@ -286,6 +282,7 @@ func testWebHook(failing bool, id string) *webhook.WebHook { &webhook.WebHookDefaults{}, &webhook.WebHookDefaults{}, &webhook.WebHookDefaults{}) + wh.ID = id if failing { wh.Secret = "notArgus" } @@ -373,15 +370,12 @@ func generateCertFiles(certFile, keyFile string) error { // Convert the certificate and private key to PEM format certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) - if err != nil { - return err - } // Write the certificate and private key to files - if err := ioutil.WriteFile(certFile, certPEM, 0644); err != nil { + if err := os.WriteFile(certFile, certPEM, 0644); err != nil { return err } - if err := ioutil.WriteFile(keyFile, keyPEM, 0600); err != nil { + if err := os.WriteFile(keyFile, keyPEM, 0600); err != nil { return err }