From 93f216b3578f52b8cad41e2c8c2e309f2d24ce41 Mon Sep 17 00:00:00 2001 From: Joseph Kavanagh Date: Sat, 20 Apr 2024 03:28:28 +0100 Subject: [PATCH 1/4] test: fix race conditions * stop setting the JLog for packages multiple times, somtimes causing a RACE * consistently pass logFrom as a pointer --- cmd/argus/main.go | 5 +- cmd/argus/main_test.go | 4 +- commands/commands.go | 6 +- config/defaults.go | 2 +- config/edit.go | 10 +- config/help_test.go | 6 +- config/init.go | 26 ++--- config/save.go | 14 +-- config/settings.go | 11 +-- config/verify.go | 2 +- db/handlers_test.go | 62 +++++------- db/help_test.go | 32 +++++-- db/init.go | 12 ++- db/init_test.go | 95 ++++++++----------- db/types.go | 2 +- notifiers/shoutrrr/shoutrrr.go | 6 +- service/deployed_version/query.go | 18 ++-- service/deployed_version/refresh.go | 12 +-- service/handlers.go | 6 +- service/latest_version/filter/docker.go | 6 +- service/latest_version/filter/docker_test.go | 2 +- service/latest_version/filter/regex.go | 6 +- service/latest_version/filter/require.go | 2 +- service/latest_version/filter/urlcommand.go | 20 ++-- .../latest_version/filter/urlcommand_test.go | 2 +- service/latest_version/github.go | 12 +-- service/latest_version/query.go | 38 ++++---- service/latest_version/refresh.go | 8 +- service/new.go | 6 +- service/track.go | 2 +- test/config.go | 20 +++- test/config_test.go | 2 +- testing/commands.go | 4 +- testing/service.go | 6 +- testing/shoutrrr.go | 8 +- util/log.go | 16 ++-- util/log_test.go | 14 +-- util/util_test.go | 12 +-- web/api/v1/api.go | 11 ++- web/api/v1/help_test.go | 42 +++++--- web/api/v1/http-api-actions.go | 35 ++++--- web/api/v1/http-api-actions_test.go | 2 +- web/api/v1/http-api-config.go | 6 +- web/api/v1/http-api-edit.go | 56 +++++------ web/api/v1/http-api-flags.go | 6 +- web/api/v1/http-api-service.go | 14 +-- web/api/v1/http-api-status.go | 6 +- web/api/v1/http.go | 6 +- web/api/v1/http_test.go | 9 +- web/api/v1/util.go | 12 +++ web/api/v1/websocket-client.go | 42 ++++---- web/api/v1/websocket-hub.go | 11 ++- web/api/v1/websocket-hub_test.go | 41 ++++---- web/help_test.go | 10 +- web/web.go | 18 ++-- web/web_test.go | 6 +- webhook/send.go | 11 ++- 57 files changed, 459 insertions(+), 402 deletions(-) diff --git a/cmd/argus/main.go b/cmd/argus/main.go index b4a8fa7a..d68f452e 100644 --- a/cmd/argus/main.go +++ b/cmd/argus/main.go @@ -68,7 +68,7 @@ func main() { } } msg := fmt.Sprintf("Found %d services to monitor:", serviceCount) - jLog.Info(msg, util.LogFrom{}, true) + jLog.Info(msg, &util.LogFrom{}, true) for _, key := range config.Order { if config.Service[key].Options.GetActive() { @@ -77,8 +77,7 @@ func main() { } } - db.LogInit(&jLog, config.Settings.DataDatabaseFile()) - go db.Run(&config) + go db.Run(&config, &jLog) // Track all targets for changes in version and act on any found changes. go (&config).Service.Track(&config.Order, &config.OrderMutex) diff --git a/cmd/argus/main_test.go b/cmd/argus/main_test.go index 09518304..1ede9455 100644 --- a/cmd/argus/main_test.go +++ b/cmd/argus/main_test.go @@ -90,9 +90,7 @@ func TestTheMain(t *testing.T) { os.Setenv("ARGUS_SERVICE_LATEST_VERSION_ACCESS_TOKEN", accessToken) // WHEN Main is called - go func() { - main() - }() + go main() time.Sleep(3 * time.Second) // THEN the program will have printed everything expected diff --git a/commands/commands.go b/commands/commands.go index 037ac2b3..0cb455ee 100644 --- a/commands/commands.go +++ b/commands/commands.go @@ -94,11 +94,11 @@ func (c *Controller) ExecIndex(logFrom *util.LogFrom, index int) (err error) { // Exec this Command and return any errors encountered. func (c *Command) Exec(logFrom *util.LogFrom) error { - jLog.Info(fmt.Sprintf("Executing '%s'", c), *logFrom, true) + jLog.Info(fmt.Sprintf("Executing '%s'", c), logFrom, true) out, err := exec.Command((*c)[0], (*c)[1:]...).Output() - jLog.Error(util.ErrorToString(err), *logFrom, err != nil) - jLog.Info(string(out), *logFrom, err == nil && string(out) != "") + jLog.Error(util.ErrorToString(err), logFrom, err != nil) + jLog.Info(string(out), logFrom, err == nil && string(out) != "") //nolint:wrapcheck return err diff --git a/config/defaults.go b/config/defaults.go index 2383bd0e..b3a7b8f7 100644 --- a/config/defaults.go +++ b/config/defaults.go @@ -71,7 +71,7 @@ func (d *Defaults) MapEnvToStruct() { jLog.Fatal( "One or more 'ARGUS_' environment variables are incorrect:\n"+ strings.ReplaceAll(util.ErrorToString(err), "\\", "\n"), - util.LogFrom{}, true) + &util.LogFrom{}, true) } } diff --git a/config/edit.go b/config/edit.go index 5b10f927..678e7357 100644 --- a/config/edit.go +++ b/config/edit.go @@ -31,7 +31,7 @@ func (c *Config) AddService(oldServiceID string, newService *service.Service) (e // Check the service doesn't already exist if the name is changing if oldServiceID != newService.ID && c.Service[newService.ID] != nil { err = fmt.Errorf("service %q already exists", newService.ID) - jLog.Error(err, logFrom, true) + jLog.Error(err, &logFrom, true) return } @@ -46,7 +46,7 @@ func (c *Config) AddService(oldServiceID string, newService *service.Service) (e c.Service[oldServiceID].Status.DeployedVersion() != newService.Status.DeployedVersion() // New service if oldServiceID == "" { - jLog.Info("Adding service", logFrom, true) + jLog.Info("Adding service", &logFrom, true) c.Order = append(c.Order, newService.ID) // Create the service map if it doesn't exist //nolint:typecheck @@ -58,7 +58,7 @@ func (c *Config) AddService(oldServiceID string, newService *service.Service) (e } else { // Keeping the same ID if oldServiceID == newService.ID { - jLog.Info("Replacing service", logFrom, true) + jLog.Info("Replacing service", &logFrom, true) // Delete the old service c.Service[oldServiceID].PrepDelete(false) @@ -105,7 +105,7 @@ func (c *Config) RenameService(oldService string, newService *service.Service) { jLog.Info( fmt.Sprintf("%q", newService.ID), - util.LogFrom{Primary: "RenameService", Secondary: oldService}, + &util.LogFrom{Primary: "RenameService", Secondary: oldService}, true) // Replace the service in the order/config c.Order = util.ReplaceElement(c.Order, oldService, newService.ID) @@ -132,7 +132,7 @@ func (c *Config) DeleteService(serviceID string) { jLog.Info( "Deleting service", - util.LogFrom{Primary: "DeleteService", Secondary: serviceID}, + &util.LogFrom{Primary: "DeleteService", Secondary: serviceID}, true) // Remove the service from the Order c.Order = util.RemoveElement(c.Order, serviceID) diff --git a/config/help_test.go b/config/help_test.go index 5d1c4c18..a5818412 100644 --- a/config/help_test.go +++ b/config/help_test.go @@ -121,11 +121,11 @@ func testLoadBasic(file string, t *testing.T) (config *Config) { //#nosec G304 -- Loading the test config file data, err := os.ReadFile(file) jLog.Fatal(fmt.Sprintf("Error reading %q\n%s", file, err), - util.LogFrom{}, err != nil) + &util.LogFrom{}, err != nil) err = yaml.Unmarshal(data, config) jLog.Fatal(fmt.Sprintf("Unmarshal of %q failed\n%s", file, err), - util.LogFrom{}, err != nil) + &util.LogFrom{}, err != nil) saveChannel := make(chan bool, 32) config.SaveChannel = &saveChannel @@ -138,7 +138,7 @@ func testLoadBasic(file string, t *testing.T) (config *Config) { config.GetOrder(data) mutex.Lock() defer mutex.Unlock() - config.Init() + config.Init(true) for name, service := range config.Service { service.ID = name } diff --git a/config/init.go b/config/init.go index d04ee5a0..6ce94493 100644 --- a/config/init.go +++ b/config/init.go @@ -36,7 +36,7 @@ func LogInit(log *util.JLog) { } // Init will hand out the appropriate Defaults.X and HardDefaults.X pointer(s) -func (c *Config) Init() { +func (c *Config) Init(setLog bool) { c.OrderMutex.RLock() defer c.OrderMutex.RUnlock() @@ -49,14 +49,14 @@ func (c *Config) Init() { c.HardDefaults.Service.Status.SaveChannel = c.SaveChannel - jLog.SetTimestamps(*c.Settings.LogTimestamps()) - jLog.SetLevel(c.Settings.LogLevel()) + if setLog { + jLog.SetTimestamps(*c.Settings.LogTimestamps()) + jLog.SetLevel(c.Settings.LogLevel()) + } - i := 0 - for _, name := range c.Order { - i++ - jLog.Debug(fmt.Sprintf("%d/%d %s Init", i, len(c.Service), name), - util.LogFrom{}, true) + for i, name := range c.Order { + jLog.Debug(fmt.Sprintf("%d/%d %s Init", i+1, len(c.Service), name), + &util.LogFrom{}, true) c.Service[name].Init( &c.Defaults.Service, &c.HardDefaults.Service, &c.Notify, &c.Defaults.Notify, &c.HardDefaults.Notify, @@ -68,17 +68,19 @@ func (c *Config) Init() { func (c *Config) Load(file string, flagset *map[string]bool, log *util.JLog) { c.File = file // Give the log to the other packages - LogInit(log) + if log != nil { + LogInit(log) + } c.Settings.NilUndefinedFlags(flagset) //#nosec G304 -- Loading the file asked for by the user data, err := os.ReadFile(file) jLog.Fatal(fmt.Sprintf("Error reading %q\n%s", file, err), - util.LogFrom{}, err != nil) + &util.LogFrom{}, err != nil) err = yaml.Unmarshal(data, c) jLog.Fatal(fmt.Sprintf("Unmarshal of %q failed\n%s", file, err), - util.LogFrom{}, err != nil) + &util.LogFrom{}, err != nil) c.GetOrder(data) @@ -100,6 +102,6 @@ func (c *Config) Load(file string, flagset *map[string]bool, log *util.JLog) { // SaveHandler that listens for calls to save config changes. go c.SaveHandler() - c.Init() + c.Init(log != nil) c.CheckValues() } diff --git a/config/save.go b/config/save.go index 8c6d304d..ca392e42 100644 --- a/config/save.go +++ b/config/save.go @@ -65,7 +65,7 @@ func (c *Config) Save() { // Write the config to file (unordered slices, but with an order list) file, err := os.OpenFile(c.File, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) errMsg := fmt.Sprintf("error opening/creating file: %v", err) - jLog.Fatal(errMsg, util.LogFrom{}, err != nil) + jLog.Fatal(errMsg, &util.LogFrom{}, err != nil) defer file.Close() // Create the yaml encoder and set indentation @@ -77,19 +77,19 @@ func (c *Config) Save() { jLog.Fatal( fmt.Sprintf("error encoding %s:\n%v\n", c.File, err), - util.LogFrom{}, + &util.LogFrom{}, err != nil) err = file.Close() jLog.Fatal( fmt.Sprintf("error opening %s:\n%v\n", c.File, err), - util.LogFrom{}, + &util.LogFrom{}, err != nil) // Read the file to find what needs to be re-arranged data, err := os.ReadFile(c.File) msg := fmt.Sprintf("Error reading %q\n%s", c.File, err) - jLog.Fatal(msg, util.LogFrom{}, err != nil) + jLog.Fatal(msg, &util.LogFrom{}, err != nil) lines := strings.Split(string(util.NormaliseNewlines(data)), "\n") // Fix the ordering of the read data @@ -257,7 +257,7 @@ func (c *Config) Save() { // Open the file. file, err = os.OpenFile(c.File, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) errMsg = fmt.Sprintf("error opening/creating file: %v", err) - jLog.Fatal(errMsg, util.LogFrom{}, err != nil) + jLog.Fatal(errMsg, &util.LogFrom{}, err != nil) // Buffered writes to the file. writer := bufio.NewWriter(file) @@ -269,10 +269,10 @@ func (c *Config) Save() { // Flush the writes. err = writer.Flush() errMsg = fmt.Sprintf("error writing file: %v", err) - jLog.Fatal(errMsg, util.LogFrom{}, err != nil) + jLog.Fatal(errMsg, &util.LogFrom{}, err != nil) jLog.Info( fmt.Sprintf("Saved service updates to %s", c.File), - util.LogFrom{}, true) + &util.LogFrom{}, true) } // removeAllServiceDefaults removes the written default values from all Services diff --git a/config/settings.go b/config/settings.go index 44beb751..cc6b9abf 100644 --- a/config/settings.go +++ b/config/settings.go @@ -77,7 +77,7 @@ func (s *SettingsBase) MapEnvToStruct() { jLog.Fatal( "One or more 'ARGUS_' environment variables are incorrect:\n"+ strings.ReplaceAll(util.ErrorToString(err), "\\", "\n"), - util.LogFrom{}, true) + &util.LogFrom{}, true) } s.CheckValues() // Set hash values and remove empty structs. } @@ -277,7 +277,6 @@ func (s *Settings) LogTimestamps() *bool { // LogLevel. func (s *Settings) LogLevel() string { return strings.ToUpper(*util.FirstNonNilPtr( - s.FromFlags.Log.Level, s.FromFlags.Log.Level, s.Log.Level, s.HardDefaults.Log.Level)) @@ -327,7 +326,7 @@ func (s *Settings) WebCertFile() *string { if _, err := os.Stat(*certFile); err != nil { if !filepath.IsAbs(*certFile) { path, execErr := os.Executable() - jLog.Error(execErr, util.LogFrom{}, execErr != nil) + jLog.Error(execErr, &util.LogFrom{}, execErr != nil) err = fmt.Errorf(strings.Replace( err.Error(), " "+*certFile+":", @@ -335,7 +334,7 @@ func (s *Settings) WebCertFile() *string { 1, )) } - jLog.Fatal("settings.web.cert_file "+err.Error(), util.LogFrom{}, true) + jLog.Fatal("settings.web.cert_file "+err.Error(), &util.LogFrom{}, true) } return certFile } @@ -352,7 +351,7 @@ func (s *Settings) WebKeyFile() *string { if _, err := os.Stat(*keyFile); err != nil { if !filepath.IsAbs(*keyFile) { path, execErr := os.Executable() - jLog.Error(execErr, util.LogFrom{}, execErr != nil) + jLog.Error(execErr, &util.LogFrom{}, execErr != nil) err = fmt.Errorf(strings.Replace( err.Error(), " "+*keyFile+":", @@ -360,7 +359,7 @@ func (s *Settings) WebKeyFile() *string { 1, )) } - jLog.Fatal("settings.web.key_file "+err.Error(), util.LogFrom{}, true) + jLog.Fatal("settings.web.key_file "+err.Error(), &util.LogFrom{}, true) } return keyFile } diff --git a/config/verify.go b/config/verify.go index a5426c85..1f991793 100644 --- a/config/verify.go +++ b/config/verify.go @@ -49,7 +49,7 @@ func (c *Config) CheckValues() { if errs != nil { fmt.Println(strings.ReplaceAll(errs.Error(), "\\", "\n")) - jLog.Fatal("Config could not be parsed successfully.", util.LogFrom{}, true) + jLog.Fatal("Config could not be parsed successfully.", &util.LogFrom{}, true) } } diff --git a/db/handlers_test.go b/db/handlers_test.go index 13576346..34c91a8a 100644 --- a/db/handlers_test.go +++ b/db/handlers_test.go @@ -17,14 +17,10 @@ package db import ( - "fmt" - "os" - "strings" "testing" "time" dbtype "github.com/release-argus/Argus/db/types" - _ "modernc.org/sqlite" ) func TestAPI_UpdateRow(t *testing.T) { @@ -72,18 +68,16 @@ func TestAPI_UpdateRow(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - cfg := testConfig() - testAPI := api{config: cfg} - *testAPI.config.Settings.Data.DatabaseFile = fmt.Sprintf("%s.db", strings.ReplaceAll(name, " ", "_")) - defer os.Remove(*cfg.Settings.Data.DatabaseFile) - testAPI.initialise() + tAPI := testAPI(name, "TestAPI_UpdateRow") + defer dbCleanup(tAPI) + tAPI.initialise() // WHEN updateRow is called targeting single/multiple cells - testAPI.updateRow(tc.target, tc.cells) + tAPI.updateRow(tc.target, tc.cells) time.Sleep(100 * time.Millisecond) // THEN those cell(s) are changed in the DB - row := queryRow(t, testAPI.db, tc.target) + row := queryRow(t, tAPI.db, tc.target) for _, cell := range tc.cells { var got string switch cell.Column { @@ -103,7 +97,6 @@ func TestAPI_UpdateRow(t *testing.T) { cell.Column, cell.Value, got) } } - testAPI.db.Close() time.Sleep(100 * time.Millisecond) }) } @@ -127,14 +120,13 @@ func TestAPI_DeleteRow(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - cfg := testConfig() - testAPI := api{config: cfg} - *testAPI.config.Settings.Data.DatabaseFile = fmt.Sprintf("%s.db", strings.ReplaceAll(name, " ", "_")) - testAPI.initialise() + tAPI := testAPI(name, "TestAPI_DeleteRow") + defer dbCleanup(tAPI) + tAPI.initialise() // Ensure the row exists if tc.exists if tc.exists { - testAPI.updateRow( + tAPI.updateRow( tc.serviceID, []dbtype.Cell{ {Column: "latest_version", Value: "9.9.9"}, {Column: "deployed_version", Value: "8.8.8"}}, @@ -142,22 +134,20 @@ func TestAPI_DeleteRow(t *testing.T) { time.Sleep(100 * time.Millisecond) } // Check the row existance before the test - row := queryRow(t, testAPI.db, tc.serviceID) + row := queryRow(t, tAPI.db, tc.serviceID) if tc.exists && (row.LatestVersion() == "" || row.DeployedVersion() == "") { t.Errorf("expecting row to exist. got %#v", row) } // WHEN deleteRow is called targeting a row - testAPI.deleteRow(tc.serviceID) + tAPI.deleteRow(tc.serviceID) time.Sleep(100 * time.Millisecond) // THEN the row is deleted from the DB - row = queryRow(t, testAPI.db, tc.serviceID) + row = queryRow(t, tAPI.db, tc.serviceID) if row.LatestVersion() != "" || row.DeployedVersion() != "" { t.Errorf("expecting row to be deleted. got %#v", row) } - testAPI.db.Close() - os.Remove(*testAPI.config.Settings.Data.DatabaseFile) time.Sleep(100 * time.Millisecond) }) } @@ -165,14 +155,10 @@ func TestAPI_DeleteRow(t *testing.T) { func TestAPI_Handler(t *testing.T) { // GIVEN a DB with a few service status' - cfg := testConfig() - testAPI := api{config: cfg} - *testAPI.config.Settings.Data.DatabaseFile = "TestHandler.db" - defer os.Remove(*testAPI.config.Settings.Data.DatabaseFile) - defer os.Remove(*testAPI.config.Settings.Data.DatabaseFile + "-journal") - testAPI.initialise() - go testAPI.handler() - defer testAPI.db.Close() + tAPI := testAPI("TestAPI_Handler", "db") + defer dbCleanup(tAPI) + tAPI.initialise() + go tAPI.handler() // WHEN a message is sent to the DatabaseChannel targeting latest_version target := "keep0" @@ -180,7 +166,7 @@ func TestAPI_Handler(t *testing.T) { Column: "latest_version", Value: "9.9.9"} cell2 := dbtype.Cell{ Column: cell1.Column, Value: cell1.Value + "-dev"} - want := queryRow(t, testAPI.db, target) + want := queryRow(t, tAPI.db, target) want.SetLatestVersion(cell1.Value, false) msg1 := dbtype.Message{ ServiceID: target, @@ -190,11 +176,11 @@ func TestAPI_Handler(t *testing.T) { ServiceID: target, Cells: []dbtype.Cell{cell2}, } - *testAPI.config.DatabaseChannel <- msg1 + *tAPI.config.DatabaseChannel <- msg1 time.Sleep(250 * time.Millisecond) // THEN the cell was changed in the DB - got := queryRow(t, testAPI.db, target) + got := queryRow(t, tAPI.db, target) if got.LatestVersion() != want.LatestVersion() { t.Errorf("Expected %q to be updated to %q\ngot %#v\nwant %#v", cell1.Column, cell1.Value, got, want) @@ -203,14 +189,14 @@ func TestAPI_Handler(t *testing.T) { // ------------------------------ // WHEN a message is sent to the DatabaseChannel deleting a row - *testAPI.config.DatabaseChannel <- dbtype.Message{ + *tAPI.config.DatabaseChannel <- dbtype.Message{ ServiceID: target, Delete: true, } time.Sleep(250 * time.Millisecond) // THEN the row is deleted from the DB - got = queryRow(t, testAPI.db, target) + got = queryRow(t, tAPI.db, target) if got.LatestVersion() != "" || got.DeployedVersion() != "" { t.Errorf("Expected row to be deleted\ngot %#v\nwant %#v", got, want) } @@ -218,13 +204,13 @@ func TestAPI_Handler(t *testing.T) { // ------------------------------ // WHEN multiple messages are targeting the same row in quick succession - *testAPI.config.DatabaseChannel <- msg1 + *tAPI.config.DatabaseChannel <- msg1 wantLatestVersion := msg2.Cells[0].Value - *testAPI.config.DatabaseChannel <- msg2 + *tAPI.config.DatabaseChannel <- msg2 time.Sleep(250 * time.Millisecond) // THEN the last message is the one that is applied - got = queryRow(t, testAPI.db, target) + got = queryRow(t, tAPI.db, target) if got.LatestVersion() != wantLatestVersion { t.Errorf("Expected %q to be updated to %q\ngot %#v\nwant %#v", cell2.Column, cell2.Value, got, want) diff --git a/db/help_test.go b/db/help_test.go index 64f89da5..8684fc48 100644 --- a/db/help_test.go +++ b/db/help_test.go @@ -18,7 +18,9 @@ package db import ( "database/sql" + "fmt" "os" + "strings" "testing" "time" @@ -32,18 +34,21 @@ import ( var cfg *config.Config func TestMain(m *testing.M) { - log := util.NewJLog("DEBUG", false) - log.Testing = true databaseFile := "TestRun.db" - LogInit(log, databaseFile) + + // Log + jLog := util.NewJLog("DEBUG", false) + jLog.Testing = true + LogInit(jLog, databaseFile) cfg = testConfig() *cfg.Settings.Data.DatabaseFile = databaseFile - defer os.Remove(*cfg.Settings.Data.DatabaseFile) - go Run(cfg) + go Run(cfg, nil) time.Sleep(250 * time.Millisecond) // Time for db to start - os.Exit(m.Run()) + exitCode := m.Run() + os.Remove(*cfg.Settings.Data.DatabaseFile) + os.Exit(exitCode) } func stringPtr(val string) *string { @@ -82,6 +87,7 @@ func testConfig() (cfg *config.Config) { DatabaseChannel: &databaseChannel, SaveChannel: &saveChannel, } + // Services for svcName := range cfg.Service { svc := service.Service{ @@ -105,6 +111,20 @@ func testConfig() (cfg *config.Config) { return } +func testAPI(primary string, secondary string) *api { + testAPI := api{config: testConfig()} + + databaseFile := strings.ReplaceAll(fmt.Sprintf("%s-%s.db", primary, secondary), " ", "_") + testAPI.config.Settings.Data.DatabaseFile = &databaseFile + + return &testAPI +} +func dbCleanup(api *api) { + api.db.Close() + os.Remove(*api.config.Settings.Data.DatabaseFile) + os.Remove(*api.config.Settings.Data.DatabaseFile + "-journal") +} + func queryRow(t *testing.T, db *sql.DB, serviceID string) *svcstatus.Status { sqlStmt := ` SELECT diff --git a/db/init.go b/db/init.go index b7ecda8c..a7f0495d 100644 --- a/db/init.go +++ b/db/init.go @@ -27,8 +27,11 @@ import ( // LogInit for this package. func LogInit(log *util.JLog, databaseFile string) { - jLog = log - logFrom = util.LogFrom{Primary: "db", Secondary: databaseFile} + // Only set the log if it hasn't been set (avoid RACE condition) + if jLog == nil { + jLog = log + logFrom = &util.LogFrom{Primary: "db", Secondary: databaseFile} + } } func checkFile(path string) { @@ -64,8 +67,11 @@ func checkFile(path string) { } } -func Run(cfg *config.Config) { +func Run(cfg *config.Config, log *util.JLog) { api := api{config: cfg} + if log != nil { + LogInit(log, cfg.Settings.DataDatabaseFile()) + } api.initialise() defer api.db.Close() if len(api.config.Order) > 0 { diff --git a/db/init_test.go b/db/init_test.go index cb7b66da..17ed29d5 100644 --- a/db/init_test.go +++ b/db/init_test.go @@ -111,15 +111,14 @@ func TestCheckFile(t *testing.T) { func TestAPI_Initialise(t *testing.T) { // GIVEN a config with a database location - cfg := testConfig() - testAPI := api{config: cfg} - *testAPI.config.Settings.Data.DatabaseFile = "TestInitialise.db" + tAPI := testAPI("TestAPI_Initialise", "db") + defer dbCleanup(tAPI) // WHEN the db is initialised with it - testAPI.initialise() + tAPI.initialise() // THEN the status table was created in the db - rows, err := testAPI.db.Query(` + rows, err := tAPI.db.Query(` SELECT id, latest_version, latest_version_timestamp, @@ -142,26 +141,23 @@ func TestAPI_Initialise(t *testing.T) { ) err = rows.Scan(&id, &lv, &lvt, &dv, &dvt, &av) } - testAPI.db.Close() - os.Remove(*testAPI.config.Settings.Data.DatabaseFile) } func TestDBQueryService(t *testing.T) { // GIVEN a blank DB - cfg := testConfig() - testAPI := api{config: cfg} - *testAPI.config.Settings.Data.DatabaseFile = "TestQueryService.db" - testAPI.initialise() + tAPI := testAPI("TestDBQueryService", "db") + defer dbCleanup(tAPI) + tAPI.initialise() // Get a Service from the Config var serviceName string - for k := range testAPI.config.Service { + for k := range tAPI.config.Service { serviceName = k break } - svc := testAPI.config.Service[serviceName] + svc := tAPI.config.Service[serviceName] // WHEN the database contains data for a Service - testAPI.updateRow( + tAPI.updateRow( serviceName, []dbtype.Cell{ {Column: "id", Value: serviceName}, @@ -173,7 +169,7 @@ func TestDBQueryService(t *testing.T) { ) // THEN that data can be queried - got := queryRow(t, testAPI.db, serviceName) + got := queryRow(t, tAPI.db, serviceName) if (*svc).Status.LatestVersion() != got.LatestVersion() { t.Errorf("LatestVersion %q was not pushed to the db. Got %q", (*svc).Status.LatestVersion(), got.LatestVersion()) @@ -194,16 +190,13 @@ func TestDBQueryService(t *testing.T) { t.Errorf("ApprovedVersion %q was not pushed to the db. Got %q", (*svc).Status.ApprovedVersion(), got.ApprovedVersion()) } - testAPI.db.Close() - os.Remove(*testAPI.config.Settings.Data.DatabaseFile) } func TestAPI_RemoveUnknownServices(t *testing.T) { // GIVEN a DB with loads of service status' - cfg := testConfig() - testAPI := api{config: cfg} - *testAPI.config.Settings.Data.DatabaseFile = "TestRemoveUnknownServices.db" - testAPI.initialise() + tAPI := testAPI("TestAPI_RemoveUnknownServices", "db") + defer dbCleanup(tAPI) + tAPI.initialise() sqlStmt := ` INSERT OR REPLACE INTO status ( @@ -215,7 +208,7 @@ func TestAPI_RemoveUnknownServices(t *testing.T) { approved_version ) VALUES` - for id, svc := range testAPI.config.Service { + for id, svc := range tAPI.config.Service { sqlStmt += fmt.Sprintf(" (%q, %q, %q, %q, %q, %q),", id, svc.Status.LatestVersion(), @@ -225,16 +218,16 @@ func TestAPI_RemoveUnknownServices(t *testing.T) { svc.Status.ApprovedVersion(), ) } - _, err := testAPI.db.Exec(sqlStmt[:len(sqlStmt)-1] + ";") + _, err := tAPI.db.Exec(sqlStmt[:len(sqlStmt)-1] + ";") if err != nil { t.Fatal(err) } // WHEN the unknown Services are removed with removeUnknownServices - testAPI.removeUnknownServices() + tAPI.removeUnknownServices() // THEN the rows of Services not in .All are returned - rows, err := testAPI.db.Query(` + rows, err := tAPI.db.Query(` SELECT id, latest_version, latest_version_timestamp, @@ -258,18 +251,16 @@ func TestAPI_RemoveUnknownServices(t *testing.T) { av string ) err = rows.Scan(&id, &lv, &lvt, &dv, &dvt, &av) - svc := testAPI.config.Service[id] - if svc == nil || !util.Contains(testAPI.config.Order, id) { + svc := tAPI.config.Service[id] + if svc == nil || !util.Contains(tAPI.config.Order, id) { t.Errorf("%q should have been removed from the table", id) } } - if count != len(testAPI.config.Order) { + if count != len(tAPI.config.Order) { t.Errorf("Only %d were left in the table. Expected %d", - count, len(testAPI.config.Order)) + count, len(tAPI.config.Order)) } - testAPI.db.Close() - os.Remove(*testAPI.config.Settings.Data.DatabaseFile) } func TestAPI_Run(t *testing.T) { @@ -291,13 +282,15 @@ func TestAPI_Run(t *testing.T) { if err != nil { t.Fatal(err) } + defer os.Remove(*otherCfg.Settings.Data.DatabaseFile) err = os.WriteFile(*otherCfg.Settings.Data.DatabaseFile, bytesRead, os.FileMode(0644)) if err != nil { t.Fatal(err) } - testAPI := api{config: otherCfg} - testAPI.initialise() - got := queryRow(t, testAPI.db, target) + tAPI := api{config: otherCfg} + defer dbCleanup(&tAPI) + tAPI.initialise() + got := queryRow(t, tAPI.db, target) want := svcstatus.Status{} want.Init( 0, 0, 0, @@ -309,26 +302,21 @@ func TestAPI_Run(t *testing.T) { want.SetDeployedVersion("0.0.0", false) want.SetDeployedVersionTimestamp("2020-01-01T01:01:01Z") if got.LatestVersion() != want.LatestVersion() { - t.Errorf("Expected %q to be updated to %q\ngot %v\nwant %v", + t.Errorf("Expected %q to be updated to %q, not %q.\nWant %v", cell.Column, cell.Value, got, want.String()) } - testAPI.db.Close() - os.Remove(*cfg.Settings.Data.DatabaseFile) - os.Remove(*otherCfg.Settings.Data.DatabaseFile) } func TestAPI_extractServiceStatus(t *testing.T) { // GIVEN an API on a DB containing atleast 1 row - cfg := testConfig() - *cfg.Settings.Data.DatabaseFile = "TestAPI_extractServiceStatus.db" - defer os.Remove(*cfg.Settings.Data.DatabaseFile) - testAPI := api{config: cfg} - testAPI.initialise() - go testAPI.handler() + tAPI := testAPI("TestAPI_extractServiceStatus", "db") + defer dbCleanup(tAPI) + tAPI.initialise() + go tAPI.handler() wantStatus := make([]svcstatus.Status, len(cfg.Service)) // push a random Status for each Service to the DB index := 0 - for id, svc := range cfg.Service { + for id, svc := range tAPI.config.Service { id := id wantStatus[index].ServiceID = &id wantStatus[index].SetLatestVersion(fmt.Sprintf("%d.%d.%d", rand.Intn(10), rand.Intn(10), rand.Intn(10)), false) @@ -337,7 +325,7 @@ func TestAPI_extractServiceStatus(t *testing.T) { wantStatus[index].SetDeployedVersionTimestamp(time.Now().UTC().Format(time.RFC3339)) wantStatus[index].SetApprovedVersion(fmt.Sprintf("%d.%d.%d", rand.Intn(10), rand.Intn(10), rand.Intn(10)), false) - *cfg.DatabaseChannel <- dbtype.Message{ + *tAPI.config.DatabaseChannel <- dbtype.Message{ ServiceID: id, Cells: []dbtype.Cell{ {Column: "id", Value: id}, @@ -355,29 +343,30 @@ func TestAPI_extractServiceStatus(t *testing.T) { time.Sleep(250 * time.Millisecond) // WHEN extractServiceStatus is called - testAPI.extractServiceStatus() + tAPI.extractServiceStatus() // THEN the Status in the Config is updated + errMsg := "Expected %q to be updated to %q, got %q.\nWant %q" for i := range wantStatus { - row := queryRow(t, testAPI.db, *wantStatus[i].ServiceID) + row := queryRow(t, tAPI.db, *wantStatus[i].ServiceID) if row.LatestVersion() != wantStatus[i].LatestVersion() { - t.Errorf("Expected %q to be updated to %q\ngot %q, want %q", + t.Errorf(errMsg, "latest_version", row.LatestVersion(), row, wantStatus[i].String()) } if row.LatestVersionTimestamp() != wantStatus[i].LatestVersionTimestamp() { - t.Errorf("Expected %q to be updated to %q\ngot %q, want %q", + t.Errorf(errMsg, "latest_version_timestamp", row.LatestVersionTimestamp(), row, wantStatus[i].String()) } if row.DeployedVersion() != wantStatus[i].DeployedVersion() { - t.Errorf("Expected %q to be updated to %q\ngot %q, want %q", + t.Errorf(errMsg, "deployed_version", row.DeployedVersion(), row, wantStatus[i].String()) } if row.DeployedVersionTimestamp() != wantStatus[i].DeployedVersionTimestamp() { - t.Errorf("Expected %q to be updated to %q\ngot %q, want %q", + t.Errorf(errMsg, "deployed_version_timestamp", row.DeployedVersionTimestamp(), row, wantStatus[i].String()) } if row.ApprovedVersion() != wantStatus[i].ApprovedVersion() { - t.Errorf("Expected %q to be updated to %q\ngot %q, want %q", + t.Errorf(errMsg, "approved_version", row.ApprovedVersion(), row, wantStatus[i].String()) } } diff --git a/db/types.go b/db/types.go index 4b166ab7..9ec0e324 100644 --- a/db/types.go +++ b/db/types.go @@ -28,5 +28,5 @@ type api struct { var ( jLog *util.JLog - logFrom util.LogFrom + logFrom *util.LogFrom ) diff --git a/notifiers/shoutrrr/shoutrrr.go b/notifiers/shoutrrr/shoutrrr.go index ed6a0c65..71e7c62a 100644 --- a/notifiers/shoutrrr/shoutrrr.go +++ b/notifiers/shoutrrr/shoutrrr.go @@ -360,7 +360,7 @@ func (s *Shoutrrr) Send( useDelay bool, useMetrics bool, ) (errs error) { - logFrom := util.LogFrom{Primary: s.ID, Secondary: serviceInfo.ID} // For logging + logFrom := &util.LogFrom{Primary: s.ID, Secondary: serviceInfo.ID} // For logging if useDelay && s.GetDelay() != "0s" { // Delay sending the Shoutrrr message by the defined interval. @@ -402,7 +402,7 @@ func (s *Shoutrrr) parseSend( err []error, combinedErrs map[string]int, serviceName string, - logFrom util.LogFrom, + logFrom *util.LogFrom, ) (failed bool) { for i := range err { if err[i] != nil { @@ -441,7 +441,7 @@ func (s *Shoutrrr) send( message string, params *shoutrrr_types.Params, serviceName string, - logFrom util.LogFrom, + logFrom *util.LogFrom, ) (errs error) { combinedErrs := make(map[string]int) triesLeft := s.GetMaxTries() // Number of times to send Shoutrrr (until 200 received). diff --git a/service/deployed_version/query.go b/service/deployed_version/query.go index 3e7089a9..598d1950 100644 --- a/service/deployed_version/query.go +++ b/service/deployed_version/query.go @@ -63,7 +63,7 @@ func (l *Lookup) query(logFrom *util.LogFrom) (string, error) { if l.JSON != "" { version, err = util.GetValueByKey(rawBody, l.JSON, l.GetURL()) if err != nil { - jLog.Error(err, *logFrom, true) + jLog.Error(err, logFrom, true) //nolint:wrapcheck return "", err } @@ -80,7 +80,7 @@ func (l *Lookup) query(logFrom *util.LogFrom) (string, error) { if len(texts) == 0 { err := fmt.Errorf("regex %q didn't find a match on %q", l.Regex, version) - jLog.Warn(err, *logFrom, true) + jLog.Warn(err, logFrom, true) return "", err } @@ -97,7 +97,7 @@ func (l *Lookup) query(logFrom *util.LogFrom) (string, error) { "style of 'MAJOR.MINOR.PATCH' (https://semver.org/), or disabling semantic versioning "+ "(globally with defaults.service.semantic_versioning or just for this service with the semantic_versioning var)", version) - jLog.Error(err, *logFrom, true) + jLog.Error(err, logFrom, true) return "", err } } @@ -175,7 +175,7 @@ func (l *Lookup) HandleNewVersion(version string, writeToDB bool) { // Announce version change to WebSocket clients. jLog.Info( fmt.Sprintf("Updated to %q", version), - util.LogFrom{Primary: *l.Status.ServiceID}, + &util.LogFrom{Primary: *l.Status.ServiceID}, true) l.Status.AnnounceUpdate() } @@ -191,7 +191,7 @@ func (l *Lookup) httpRequest(logFrom *util.LogFrom) (rawBody []byte, err error) req, err := http.NewRequest(http.MethodGet, l.GetURL(), nil) if err != nil { - jLog.Error(err, *logFrom, true) + jLog.Error(err, logFrom, true) return } @@ -212,23 +212,23 @@ func (l *Lookup) httpRequest(logFrom *util.LogFrom) (rawBody []byte, err error) // Don't crash on invalid certs. if strings.Contains(err.Error(), "x509") { err = fmt.Errorf("x509 (certificate invalid)") - jLog.Warn(err, *logFrom, true) + jLog.Warn(err, logFrom, true) return } - jLog.Error(err, *logFrom, true) + jLog.Error(err, logFrom, true) return } // Ignore non-2XX responses. if resp.StatusCode < 200 || resp.StatusCode >= 300 { err = fmt.Errorf("non-2XX response code: %d", resp.StatusCode) - jLog.Warn(err, *logFrom, true) + jLog.Warn(err, logFrom, true) return } // Read the response body. defer resp.Body.Close() rawBody, err = io.ReadAll(resp.Body) - jLog.Error(err, *logFrom, err != nil) + jLog.Error(err, logFrom, err != nil) return } diff --git a/service/deployed_version/refresh.go b/service/deployed_version/refresh.go index 34b5fe48..aec5c211 100644 --- a/service/deployed_version/refresh.go +++ b/service/deployed_version/refresh.go @@ -86,7 +86,7 @@ func (l *Lookup) applyOverrides( l.Defaults, l.HardDefaults) if err := lookup.CheckValues(""); err != nil { - jLog.Error(err, *logFrom, true) + jLog.Error(err, logFrom, true) return nil, fmt.Errorf("values failed validity check:\n%w", err) } lookup.Status.Init( @@ -109,7 +109,7 @@ func (l *Lookup) Refresh( url *string, ) (version string, announceUpdate bool, err error) { serviceID := *l.Status.ServiceID - logFrom := util.LogFrom{Primary: "deployed_version/refresh", Secondary: serviceID} + logFrom := &util.LogFrom{Primary: "deployed_version/refresh", Secondary: serviceID} var lookup *Lookup lookup, err = l.applyOverrides( @@ -122,7 +122,7 @@ func (l *Lookup) Refresh( semanticVersioning, url, &serviceID, - &logFrom) + logFrom) if err != nil { return } @@ -143,7 +143,7 @@ func (l *Lookup) Refresh( regexTemplate != nil // Query the lookup. - version, err = lookup.Query(!overrides, &logFrom) + version, err = lookup.Query(!overrides, logFrom) if err != nil { return } @@ -171,7 +171,7 @@ func basicAuthFromString(jsonStr *string, previous *BasicAuth, logFrom *util.Log // Ignore the JSON if it failed to unmarshal if err != nil { jLog.Error(fmt.Sprintf("Failed converting JSON - %q\n%s", *jsonStr, util.ErrorToString(err)), - *logFrom, true) + logFrom, true) return previous } keys := util.GetKeysFromJSON(*jsonStr) @@ -203,7 +203,7 @@ func headersFromString(jsonStr *string, previous *[]Header, logFrom *util.LogFro // Ignore the JSON if it failed to unmarshal if err != nil { jLog.Error(fmt.Sprintf("Failed converting JSON - %q\n%s", *jsonStr, util.ErrorToString(err)), - *logFrom, true) + logFrom, true) return previous } diff --git a/service/handlers.go b/service/handlers.go index f2edff83..c3fec54a 100644 --- a/service/handlers.go +++ b/service/handlers.go @@ -79,7 +79,7 @@ func (s *Service) HandleUpdateActions(writeToDB bool) { if s.Dashboard.GetAutoApprove() { msg := fmt.Sprintf("Sending WebHooks/Running Commands for %q", s.Status.LatestVersion()) - jLog.Info(msg, util.LogFrom{Primary: s.ID}, true) + jLog.Info(msg, &util.LogFrom{Primary: s.ID}, true) // Run the Command(s) go func() { @@ -97,7 +97,7 @@ func (s *Service) HandleUpdateActions(writeToDB bool) { } }() } else { - jLog.Info("Waiting for approval on the Web UI", util.LogFrom{Primary: s.ID}, true) + jLog.Info("Waiting for approval on the Web UI", &util.LogFrom{Primary: s.ID}, true) s.Status.AnnounceQueryNewVersion() } @@ -185,7 +185,7 @@ func (s *Service) HandleCommand(command string) { // Find the command index := s.CommandController.Find(command) if index == nil { - jLog.Warn(command+" not found", util.LogFrom{Primary: "Command", Secondary: s.ID}, true) + jLog.Warn(command+" not found", &util.LogFrom{Primary: "Command", Secondary: s.ID}, true) return } diff --git a/service/latest_version/filter/docker.go b/service/latest_version/filter/docker.go index f7efca50..4eec7660 100644 --- a/service/latest_version/filter/docker.go +++ b/service/latest_version/filter/docker.go @@ -400,7 +400,7 @@ func (d *DockerCheckDefaults) getToken(dType string) (token string) { // getValidToken looks for an existing queryToken on this registry type that's currently valid. // // empty string if no valid queryToken is found. -func (d *DockerCheck) getValidToken(dType string) (queryToken string) { +func (d *DockerCheck) getValidToken() (queryToken string) { if d == nil { return } @@ -468,7 +468,7 @@ func (d *DockerCheckDefaults) getQueryToken(dType string) (queryToken string, va // getQueryToken for API queries. func (d *DockerCheck) getQueryToken() (queryToken string, err error) { dType := d.GetType() - queryToken = d.getValidToken(dType) + queryToken = d.getValidToken() if queryToken != "" { return } @@ -506,7 +506,7 @@ func (d *DockerCheck) getQueryToken() (queryToken string, err error) { } // Get the refreshed token - queryToken = d.getValidToken(dType) + queryToken = d.getValidToken() return } diff --git a/service/latest_version/filter/docker_test.go b/service/latest_version/filter/docker_test.go index d31f56fd..bccc0312 100644 --- a/service/latest_version/filter/docker_test.go +++ b/service/latest_version/filter/docker_test.go @@ -982,7 +982,7 @@ func TestDockerCheck_getValidToken(t *testing.T) { } // WHEN getValidToken is called on it - got := tc.dockerCheck.getValidToken(tc.dockerCheck.GetType()) + got := tc.dockerCheck.getValidToken() // THEN the token is what we expect if tc.want != got { diff --git a/service/latest_version/filter/regex.go b/service/latest_version/filter/regex.go index f592609b..2aadb5e7 100644 --- a/service/latest_version/filter/regex.go +++ b/service/latest_version/filter/regex.go @@ -39,7 +39,7 @@ func (r *Require) RegexCheckVersion( err := fmt.Errorf("regex not matched on version %q", version) r.Status.RegexMissVersion() - jLog.Info(err, *logFrom, r.Status.RegexMissesVersion() == 1) + jLog.Info(err, logFrom, r.Status.RegexMissesVersion() == 1) return err } @@ -84,7 +84,7 @@ func (r *Require) RegexCheckContent( jLog.Debug( fmt.Sprintf("%q RegexContent on %q, match=%t", r.RegexContent, searchArea[i], regexMatch), - *logFrom, true) + logFrom, true) } if !regexMatch { // if we're on the last asset @@ -94,7 +94,7 @@ func (r *Require) RegexCheckContent( "regex %q not matched on content for version %q", regexStr, version) r.Status.RegexMissContent() - jLog.Info(err, *logFrom, r.Status.RegexMissesContent() == 1) + jLog.Info(err, logFrom, r.Status.RegexMissesContent() == 1) return err } // continue searching the other assets diff --git a/service/latest_version/filter/require.go b/service/latest_version/filter/require.go index 28e4b16f..6d5ad184 100644 --- a/service/latest_version/filter/require.go +++ b/service/latest_version/filter/require.go @@ -158,7 +158,7 @@ func RequireFromStr(jsonStr *string, previous *Require, logFrom *util.LogFrom) ( jLog.Error( fmt.Sprintf("Failed converting JSON - %q\n%s", *jsonStr, util.ErrorToString(err)), - *logFrom, true) + logFrom, true) return nil, fmt.Errorf("require - %w", err) } diff --git a/service/latest_version/filter/urlcommand.go b/service/latest_version/filter/urlcommand.go index 8f7586b6..9665bd50 100644 --- a/service/latest_version/filter/urlcommand.go +++ b/service/latest_version/filter/urlcommand.go @@ -77,15 +77,15 @@ func (s *URLCommandSlice) UnmarshalYAML(unmarshal func(interface{}) error) (err } // Run all of the URLCommand(s) in this URLCommandSlice. -func (s *URLCommandSlice) Run(text string, logFrom util.LogFrom) (string, error) { +func (s *URLCommandSlice) Run(text string, logFrom *util.LogFrom) (string, error) { if s == nil { return text, nil } - logFrom.Secondary = "url_commands" + urlCommandLogFrom := &util.LogFrom{Primary: logFrom.Primary, Secondary: "url_commands"} var err error for commandIndex := range *s { - text, err = (*s)[commandIndex].run(text, &logFrom) + text, err = (*s)[commandIndex].run(text, urlCommandLogFrom) if err != nil { return text, err } @@ -101,7 +101,7 @@ func (c *URLCommand) run(text string, logFrom *util.LogFrom) (string, error) { if jLog.IsLevel("DEBUG") { jLog.Debug( fmt.Sprintf("Looking through:\n%q", text), - *logFrom, true) + logFrom, true) } var msg string @@ -125,7 +125,7 @@ func (c *URLCommand) run(text string, logFrom *util.LogFrom) (string, error) { msg = fmt.Sprintf("%s\nResolved to %s", msg, text) if jLog.IsLevel("DEBUG") { - jLog.Debug(msg, *logFrom, true) + jLog.Debug(msg, logFrom, true) } return text, err } @@ -149,7 +149,7 @@ func (c *URLCommand) regex(text string, logFrom *util.LogFrom) (string, error) { err = fmt.Errorf("%w on %q", err, text) } - jLog.Warn(err, *logFrom, true) + jLog.Warn(err, logFrom, true) return text, err } @@ -157,7 +157,7 @@ func (c *URLCommand) regex(text string, logFrom *util.LogFrom) (string, error) { if (len(texts) - index) < 1 { err := fmt.Errorf("%s (%s) returned %d elements on %q, but the index wants element number %d", c.Type, *c.Regex, len(texts), text, (index + 1)) - jLog.Warn(err, *logFrom, true) + jLog.Warn(err, logFrom, true) return text, err } @@ -173,7 +173,7 @@ func (c *URLCommand) split(text string, logFrom *util.LogFrom) (string, error) { if len(texts) == 1 { err := fmt.Errorf("%s didn't find any %q to split on", c.Type, *c.Text) - jLog.Warn(err, *logFrom, true) + jLog.Warn(err, logFrom, true) return text, err } @@ -187,7 +187,7 @@ func (c *URLCommand) split(text string, logFrom *util.LogFrom) (string, error) { if (len(texts) - index) < 1 { err := fmt.Errorf("%s (%s) returned %d elements on %q, but the index wants element number %d", c.Type, *c.Text, len(texts), text, (index + 1)) - jLog.Warn(err, *logFrom, true) + jLog.Warn(err, logFrom, true) return text, err } @@ -275,7 +275,7 @@ func URLCommandsFromStr(jsonStr *string, defaults *URLCommandSlice, logFrom *uti // Ignore the JSON if it failed to unmarshal if err != nil { jLog.Error(fmt.Sprintf("Failed converting JSON - %q\n%s", *jsonStr, util.ErrorToString(err)), - *logFrom, err != nil) + logFrom, err != nil) return defaults, fmt.Errorf("failed converting JSON - %w", err) } diff --git a/service/latest_version/filter/urlcommand_test.go b/service/latest_version/filter/urlcommand_test.go index 1d6b458b..e1cb08ba 100644 --- a/service/latest_version/filter/urlcommand_test.go +++ b/service/latest_version/filter/urlcommand_test.go @@ -303,7 +303,7 @@ func TestURLCommandSlice_Run(t *testing.T) { if tc.text != "" { text = tc.text } - text, err := tc.slice.Run(text, util.LogFrom{}) + text, err := tc.slice.Run(text, &util.LogFrom{}) // THEN the expected text was returned if tc.want != text { diff --git a/service/latest_version/github.go b/service/latest_version/github.go index 258f8910..a8af6c88 100644 --- a/service/latest_version/github.go +++ b/service/latest_version/github.go @@ -52,7 +52,7 @@ func (l *Lookup) filterGitHubReleases( if tag == "" { tag = releases[i].Name } - if tagName, err = l.URLCommands.Run(tag, *logFrom); err != nil { + if tagName, err = l.URLCommands.Run(tag, logFrom); err != nil { continue } @@ -114,25 +114,25 @@ func (l *Lookup) checkGitHubReleasesBody(body *[]byte, logFrom *util.LogFrom) (r if len(string(*body)) < 500 { if strings.Contains(string(*body), "rate limit") { err = errors.New("rate limit reached for GitHub") - jLog.Warn(err, *logFrom, true) + jLog.Warn(err, logFrom, true) return } if !strings.Contains(string(*body), `"tag_name"`) { err = errors.New("github access token is invalid") - jLog.Error(err, *logFrom, strings.Contains(string(*body), "Bad credentials")) + jLog.Error(err, logFrom, strings.Contains(string(*body), "Bad credentials")) err = fmt.Errorf("tag_name not found at %s\n%s", l.URL, string(*body)) - jLog.Error(err, *logFrom, true) + jLog.Error(err, logFrom, true) return } } if err = json.Unmarshal(*body, &releases); err != nil { - jLog.Error(err, *logFrom, true) + jLog.Error(err, logFrom, true) err = fmt.Errorf("unmarshal of GitHub API data failed\n%w", err) - jLog.Error(err, *logFrom, true) + jLog.Error(err, logFrom, true) } return } diff --git a/service/latest_version/query.go b/service/latest_version/query.go index fb441c31..c8d57648 100644 --- a/service/latest_version/query.go +++ b/service/latest_version/query.go @@ -54,7 +54,7 @@ func (l *Lookup) query(logFrom *util.LogFrom, checkNumber int) (bool, error) { // Verify that the version has changed. (GitHub may have just omitted the tag for some reason) if checkNumber == 0 { msg := fmt.Sprintf("Possibly found a new version (From %q to %q). Checking again", latestVersion, version) - jLog.Verbose(msg, *logFrom, latestVersion != "") + jLog.Verbose(msg, logFrom, latestVersion != "") time.Sleep(time.Second) return l.query(logFrom, 1) } @@ -65,7 +65,7 @@ func (l *Lookup) query(logFrom *util.LogFrom, checkNumber int) (bool, error) { if err != nil { err = fmt.Errorf("failed converting %q to a semantic version. If all versions are in this style, consider adding url_commands to get the version into the style of 'MAJOR.MINOR.PATCH' (https://semver.org/), or disabling semantic versioning (globally with defaults.service.semantic_versioning or just for this service with the semantic_versioning var)", version) - jLog.Error(err, *logFrom, true) + jLog.Error(err, logFrom, true) return false, err } @@ -82,7 +82,7 @@ func (l *Lookup) query(logFrom *util.LogFrom, checkNumber int) (bool, error) { if newVersion.LessThan(oldVersion) { err := fmt.Errorf("queried version %q is less than the deployed version %q", version, l.Status.LatestVersion()) - jLog.Warn(err, *logFrom, true) + jLog.Warn(err, logFrom, true) return false, err } } @@ -99,7 +99,7 @@ func (l *Lookup) query(logFrom *util.LogFrom, checkNumber int) (bool, error) { l.Status.SetDeployedVersion(version, true) } msg := fmt.Sprintf("Latest Release - %q", version) - jLog.Info(msg, *logFrom, true) + jLog.Info(msg, logFrom, true) l.Status.AnnounceFirstVersion() @@ -110,12 +110,12 @@ func (l *Lookup) query(logFrom *util.LogFrom, checkNumber int) (bool, error) { // New version found. l.Status.SetLatestVersion(version, true) msg := fmt.Sprintf("New Release - %q", version) - jLog.Info(msg, *logFrom, true) + jLog.Info(msg, logFrom, true) return true, nil } msg := fmt.Sprintf("Staying on %q as that's the latest version in the second check", version) - jLog.Verbose(msg, *logFrom, checkNumber == 1) + jLog.Verbose(msg, logFrom, checkNumber == 1) // Announce `LastQueried` l.Status.AnnounceQuery() // No version change. @@ -188,7 +188,7 @@ func (l *Lookup) httpRequest(logFrom *util.LogFrom) (rawBody []byte, err error) if err != nil { err = fmt.Errorf("failed creating http request for %q: %w", l.URL, err) - jLog.Error(err, *logFrom, true) + jLog.Error(err, logFrom, true) return } @@ -213,17 +213,17 @@ func (l *Lookup) httpRequest(logFrom *util.LogFrom) (rawBody []byte, err error) // Don't crash on invalid certs. if strings.Contains(err.Error(), "x509") { err = fmt.Errorf("x509 (certificate invalid)") - jLog.Warn(err, *logFrom, true) + jLog.Warn(err, logFrom, true) return } - jLog.Error(err, *logFrom, true) + jLog.Error(err, logFrom, true) return } // Read the response body. defer resp.Body.Close() rawBody, err = io.ReadAll(resp.Body) - jLog.Error(err, *logFrom, err != nil) + jLog.Error(err, logFrom, err != nil) if l.Type == "github" && err == nil { // 200 - Resource has changed if resp.StatusCode == http.StatusOK { @@ -236,13 +236,13 @@ func (l *Lookup) httpRequest(logFrom *util.LogFrom) (rawBody []byte, err error) // Flip the fallback flag l.GitHubData.SetTagFallback() if l.GitHubData.TagFallback() { - jLog.Verbose(fmt.Sprintf("/releases gave %v, trying /tags", string(rawBody)), *logFrom, true) + jLog.Verbose(fmt.Sprintf("/releases gave %v, trying /tags", string(rawBody)), logFrom, true) rawBody, err = l.httpRequest(logFrom) } // Has tags/releases } else { msg := fmt.Sprintf("Potentially found new releases (new ETag %s)", newETag) - jLog.Verbose(msg, *logFrom, true) + jLog.Verbose(msg, logFrom, true) } // 304 - Resource has not changed @@ -252,7 +252,7 @@ func (l *Lookup) httpRequest(logFrom *util.LogFrom) (rawBody []byte, err error) // Flip the fallback flag l.GitHubData.SetTagFallback() if l.GitHubData.TagFallback() { - jLog.Verbose("no tags found on /releases, trying /tags", *logFrom, true) + jLog.Verbose("no tags found on /releases, trying /tags", logFrom, true) rawBody, err = l.httpRequest(logFrom) } } @@ -281,14 +281,14 @@ func (l *Lookup) GetVersions( filteredReleases = l.filterGitHubReleases(logFrom) if len(filteredReleases) == 0 { err = fmt.Errorf("no releases were found matching the url_commands") - jLog.Warn(err, *logFrom, true) + jLog.Warn(err, logFrom, true) return } // url service } else { var version string - version, err = l.URLCommands.Run(body, *logFrom) + version, err = l.URLCommands.Run(body, logFrom) if err != nil { //nolint:wrapcheck return @@ -309,7 +309,7 @@ func (l *Lookup) GetVersion(rawBody []byte, logFrom *util.LogFrom) (version stri } } else if l.Type == "github" { // ReCheck this ETag's filteredReleases incase filters/releases changed - jLog.Verbose("Using cached releases (ETag unchanged)", *logFrom, true) + jLog.Verbose("Using cached releases (ETag unchanged)", logFrom, true) filteredReleases = l.filterGitHubReleases(logFrom) } @@ -354,20 +354,20 @@ func (l *Lookup) GetVersion(rawBody []byte, logFrom *util.LogFrom) (version stri if strings.HasSuffix(err.Error(), "\n") { err = fmt.Errorf(strings.TrimSuffix(err.Error(), "\n")) } - jLog.Warn(err, *logFrom, true) + jLog.Warn(err, logFrom, true) continue // else if the tag does exist (and we did search for one) } else if l.Require.Docker != nil { jLog.Info( fmt.Sprintf(`found %s container "%s:%s"`, l.Require.Docker.GetType(), l.Require.Docker.Image, l.Require.Docker.GetTag(version)), - *logFrom, true) + logFrom, true) } break } if version == "" { err = fmt.Errorf("no releases were found matching the url_commands and/or require") - jLog.Warn(err, *logFrom, true) + jLog.Warn(err, logFrom, true) } return } diff --git a/service/latest_version/refresh.go b/service/latest_version/refresh.go index 33b2bd34..117ab41c 100644 --- a/service/latest_version/refresh.go +++ b/service/latest_version/refresh.go @@ -122,7 +122,7 @@ func (l *Lookup) applyOverrides( } if err := lookup.CheckValues(""); err != nil { - jLog.Error(err, *logFrom, true) + jLog.Error(err, logFrom, true) return nil, fmt.Errorf("values failed validity check:\n%w", err) } @@ -147,7 +147,7 @@ func (l *Lookup) Refresh( usePreRelease *string, ) (version string, announceUpdate bool, err error) { serviceID := *l.Status.ServiceID - logFrom := util.LogFrom{Primary: "latest_version/refresh", Secondary: serviceID} + logFrom := &util.LogFrom{Primary: "latest_version/refresh", Secondary: serviceID} var lookup *Lookup lookup, err = l.applyOverrides( @@ -160,7 +160,7 @@ func (l *Lookup) Refresh( urlCommands, usePreRelease, &serviceID, - &logFrom) + logFrom) if err != nil { return } @@ -180,7 +180,7 @@ func (l *Lookup) Refresh( usePreRelease != nil // Query the lookup. - _, err = lookup.Query(!overrides, &logFrom) + _, err = lookup.Query(!overrides, logFrom) if err != nil { return } diff --git a/service/new.go b/service/new.go index 119c5adb..1f18d1c5 100644 --- a/service/new.go +++ b/service/new.go @@ -85,8 +85,8 @@ func FromPayload( dec1 := json.NewDecoder(bytes.NewReader(buf.Bytes())) err = dec1.Decode(newService) if err != nil { - jLog.Error(err, *logFrom, true) - jLog.Verbose(fmt.Sprintf("Payload: %s", buf.String()), *logFrom, true) + jLog.Error(err, logFrom, true) + jLog.Verbose(fmt.Sprintf("Payload: %s", buf.String()), logFrom, true) return } @@ -95,7 +95,7 @@ func FromPayload( var secretRefs oldSecretRefs err = dec2.Decode(&secretRefs) if err != nil { - jLog.Error(err, *logFrom, true) + jLog.Error(err, logFrom, true) return } diff --git a/service/track.go b/service/track.go index 50dd9465..9c2b900a 100644 --- a/service/track.go +++ b/service/track.go @@ -36,7 +36,7 @@ func (s *Slice) Track(ordering *[]string, orderMutex *sync.RWMutex) { jLog.Verbose( fmt.Sprintf("Tracking %s at %s every %s", (*s)[key].ID, (*s)[key].LatestVersion.ServiceURL(true), (*s)[key].Options.GetInterval()), - util.LogFrom{Primary: (*s)[key].ID}, + &util.LogFrom{Primary: (*s)[key].ID}, true) // Track this Service in a infinite loop goroutine. diff --git a/test/config.go b/test/config.go index a2ddbc77..2208c07c 100644 --- a/test/config.go +++ b/test/config.go @@ -38,15 +38,31 @@ func NilFlags(cfg *config.Config) { cfg.Settings.NilUndefinedFlags(&flagMap) } -func BareConfig() (cfg *config.Config) { +func BareConfig(nilFlags bool) (cfg *config.Config) { cfg = &config.Config{ Settings: config.Settings{ SettingsBase: config.SettingsBase{ Web: config.WebSettings{ RoutePrefix: StringPtr(""), }}}} - NilFlags(cfg) + + // NilFlags can be a RACE condition, so use it conditionally + if nilFlags { + NilFlags(cfg) + } else { + cfg.Settings.FromFlags.Log.Level = nil + cfg.Settings.FromFlags.Log.Timestamps = nil + cfg.Settings.FromFlags.Data.DatabaseFile = nil + cfg.Settings.FromFlags.Web.ListenHost = nil + cfg.Settings.FromFlags.Web.ListenPort = nil + cfg.Settings.FromFlags.Web.CertFile = nil + cfg.Settings.FromFlags.Web.KeyFile = nil + cfg.Settings.FromFlags.Web.RoutePrefix = nil + cfg.Settings.FromFlags.Web.BasicAuth = nil + } + cfg.HardDefaults.SetDefaults() cfg.Settings.SetDefaults() + return } diff --git a/test/config_test.go b/test/config_test.go index 9adfdc4c..985c9645 100644 --- a/test/config_test.go +++ b/test/config_test.go @@ -69,7 +69,7 @@ func TestBareConfig(t *testing.T) { } // WHEN the config is initialized - cfg := BareConfig() + cfg := BareConfig(true) strFlags = map[string]struct { flag *string cfg *string diff --git a/testing/commands.go b/testing/commands.go index e2f35a81..4ddde55c 100644 --- a/testing/commands.go +++ b/testing/commands.go @@ -33,7 +33,7 @@ func CommandTest( if *flag == "" { return } - logFrom := util.LogFrom{Primary: "Testing", Secondary: *flag} + logFrom := &util.LogFrom{Primary: "Testing", Secondary: *flag} log.Info( "", @@ -61,7 +61,7 @@ func CommandTest( service.CommandController == nil) //nolint:errcheck - service.CommandController.Exec(&logFrom) + service.CommandController.Exec(logFrom) if !log.Testing { os.Exit(0) } diff --git a/testing/service.go b/testing/service.go index 1f617347..61d26fb3 100644 --- a/testing/service.go +++ b/testing/service.go @@ -33,7 +33,7 @@ func ServiceTest( if *flag == "" { return } - logFrom := util.LogFrom{Primary: "Testing", Secondary: *flag} + logFrom := &util.LogFrom{Primary: "Testing", Secondary: *flag} log.Info( "", @@ -54,7 +54,7 @@ func ServiceTest( } if service != nil { - _, err := service.LatestVersion.Query(false, &logFrom) + _, err := service.LatestVersion.Query(false, logFrom) if err != nil { helpMsg := "" if service.LatestVersion.Type == "url" && strings.Count(service.LatestVersion.URL, "/") == 1 && !strings.HasPrefix(service.LatestVersion.URL, "http") { @@ -75,7 +75,7 @@ func ServiceTest( // DeployedVersionLookup if service.DeployedVersionLookup != nil { - version, err := service.DeployedVersionLookup.Query(false, &logFrom) + version, err := service.DeployedVersionLookup.Query(false, logFrom) log.Info( fmt.Sprintf( "Deployed version - %q", diff --git a/testing/shoutrrr.go b/testing/shoutrrr.go index a75c7756..8b888deb 100644 --- a/testing/shoutrrr.go +++ b/testing/shoutrrr.go @@ -37,10 +37,10 @@ func NotifyTest( return } shoutrrr.LogInit(log) - logFrom := util.LogFrom{Primary: "Testing", Secondary: *flag} + logFrom := &util.LogFrom{Primary: "Testing", Secondary: *flag} // Find the Shoutrrr to test - notify := findShoutrrr(*flag, cfg, log, &logFrom) + notify := findShoutrrr(*flag, cfg, log, logFrom) // Default webURL if not set if notify.ServiceStatus.WebURL == nil { @@ -112,7 +112,7 @@ func findShoutrrr( // Check if all values are set if err := notify.CheckValues(" "); err != nil { msg := fmt.Sprintf("notify:\n %s:\n%s\n", name, strings.ReplaceAll(err.Error(), "\\", "\n")) - log.Fatal(msg, *logFrom, true) + log.Fatal(msg, logFrom, true) } // Not found @@ -120,7 +120,7 @@ func findShoutrrr( all := getAllShoutrrrNames(cfg) msg := fmt.Sprintf("Notifier %q could not be found in config.notify or in any config.service\nDid you mean one of these?\n - %s\n", name, strings.Join(all, "\n - ")) - log.Fatal(msg, *logFrom, true) + log.Fatal(msg, logFrom, true) } } serviceID := "TESTING" diff --git a/util/log.go b/util/log.go index 520465ad..9f9906a5 100644 --- a/util/log.go +++ b/util/log.go @@ -69,7 +69,7 @@ func (l *JLog) SetLevel(level string) { value := levelMap[level] msg := fmt.Sprintf("%q is not a valid log.level. It should be one of ERROR, WARN, INFO, VERBOSE or DEBUG.", level) - l.Fatal(msg, LogFrom{}, value == 0 && level != "ERROR") + l.Fatal(msg, &LogFrom{}, value == 0 && level != "ERROR") l.Level = uint(value) } @@ -88,7 +88,7 @@ func (l *JLog) SetTimestamps(enable bool) { // from.Primary defined = `from.Primary ` // // from.Secondary defined = `from.Secondary ` -func FormatMessageSource(from LogFrom) (msg string) { +func FormatMessageSource(from *LogFrom) (msg string) { // from.Primary defined if from.Primary != "" { // from.Primary and from.Secondary are defined @@ -128,7 +128,7 @@ func (l *JLog) IsLevel(level string) bool { // Error log the msg. // // (if otherCondition is true) -func (l *JLog) Error(msg interface{}, from LogFrom, otherCondition bool) { +func (l *JLog) Error(msg interface{}, from *LogFrom, otherCondition bool) { if otherCondition { msgString := fmt.Sprintf("%s%v", FormatMessageSource(from), msg) // ERROR: msg from.Primary (from.Secondary) @@ -143,7 +143,7 @@ func (l *JLog) Error(msg interface{}, from LogFrom, otherCondition bool) { // Warn log msg if l.Level is > 0 (WARNING, INFO, VERBOSE or DEBUG). // // (if otherCondition is true) -func (l *JLog) Warn(msg interface{}, from LogFrom, otherCondition bool) { +func (l *JLog) Warn(msg interface{}, from *LogFrom, otherCondition bool) { if l.Level > 0 && otherCondition { msgString := fmt.Sprintf("%s%v", FormatMessageSource(from), msg) // WARNING: msg from.Primary (from.Secondary) @@ -158,7 +158,7 @@ func (l *JLog) Warn(msg interface{}, from LogFrom, otherCondition bool) { // Info log msg if l.Level is > 1 (INFO, VERBOSE or DEBUG). // // (if otherCondition is true) -func (l *JLog) Info(msg interface{}, from LogFrom, otherCondition bool) { +func (l *JLog) Info(msg interface{}, from *LogFrom, otherCondition bool) { if l.Level > 1 && otherCondition { msgString := fmt.Sprintf("%s%v", FormatMessageSource(from), msg) // INFO: msg from.Primary (from.Secondary) @@ -173,7 +173,7 @@ func (l *JLog) Info(msg interface{}, from LogFrom, otherCondition bool) { // Verbose log msg if l.Level is > 2 (VERBOSE or DEBUG). // // (if otherCondition is true) -func (l *JLog) Verbose(msg interface{}, from LogFrom, otherCondition bool) { +func (l *JLog) Verbose(msg interface{}, from *LogFrom, otherCondition bool) { if l.Level > 2 && otherCondition { msgString := fmt.Sprintf("%s%v", FormatMessageSource(from), msg) @@ -194,7 +194,7 @@ func (l *JLog) Verbose(msg interface{}, from LogFrom, otherCondition bool) { // Debug log msg if l.Level is 4 (DEBUG). // // (if otherCondition is true) -func (l *JLog) Debug(msg interface{}, from LogFrom, otherCondition bool) { +func (l *JLog) Debug(msg interface{}, from *LogFrom, otherCondition bool) { if l.Level == 4 && otherCondition { msgString := fmt.Sprintf("%s%v", FormatMessageSource(from), msg) @@ -213,7 +213,7 @@ func (l *JLog) Debug(msg interface{}, from LogFrom, otherCondition bool) { } // Fatal is equivalent to Error() followed by a call to os.Exit(1). -func (l *JLog) Fatal(msg interface{}, from LogFrom, otherCondition bool) { +func (l *JLog) Fatal(msg interface{}, from *LogFrom, otherCondition bool) { if otherCondition { l.Error(msg, from, true) if !l.Testing { diff --git a/util/log_test.go b/util/log_test.go index 1271d4dd..89538838 100644 --- a/util/log_test.go +++ b/util/log_test.go @@ -172,7 +172,7 @@ func TestFormatMessageSource(t *testing.T) { t.Parallel() // WHEN FormatMessageSource is called with this LogFrom - got := FormatMessageSource(tc.logFrom) + got := FormatMessageSource(&tc.logFrom) // THEN an empty string is returned if got != tc.want { @@ -315,7 +315,7 @@ func TestJLog_Error(t *testing.T) { log.SetOutput(&logOut) // WHEN Error is called with true - jLog.Error(fmt.Errorf(msg), LogFrom{}, tc.otherCondition) + jLog.Error(fmt.Errorf(msg), &LogFrom{}, tc.otherCondition) // THEN msg was logged if shouldPrint, with/without timestamps w.Close() @@ -393,7 +393,7 @@ func TestJLog_Warn(t *testing.T) { log.SetOutput(&logOut) // WHEN Warn is called with true - jLog.Warn(fmt.Errorf(msg), LogFrom{}, tc.otherCondition) + jLog.Warn(fmt.Errorf(msg), &LogFrom{}, tc.otherCondition) // THEN msg was logged if shouldPrint, with/without timestamps w.Close() @@ -471,7 +471,7 @@ func TestJLog_Info(t *testing.T) { log.SetOutput(&logOut) // WHEN Info is called with true - jLog.Info(fmt.Errorf(msg), LogFrom{}, tc.otherCondition) + jLog.Info(fmt.Errorf(msg), &LogFrom{}, tc.otherCondition) // THEN msg was logged if shouldPrint, with/without timestamps w.Close() @@ -559,7 +559,7 @@ func TestJLog_Verbose(t *testing.T) { } // WHEN Verbose is called with true - jLog.Verbose(fmt.Errorf(msg), LogFrom{}, tc.otherCondition) + jLog.Verbose(fmt.Errorf(msg), &LogFrom{}, tc.otherCondition) // THEN msg was logged if shouldPrint, with/without timestamps w.Close() @@ -659,7 +659,7 @@ func TestJLog_Debug(t *testing.T) { } // WHEN Debug is called with true - jLog.Debug(fmt.Errorf(msg), LogFrom{}, tc.otherCondition) + jLog.Debug(fmt.Errorf(msg), &LogFrom{}, tc.otherCondition) // THEN msg was logged if shouldPrint, with/without timestamps w.Close() @@ -773,7 +773,7 @@ func TestJLog_Fatal(t *testing.T) { } // WHEN Fatal is called with true - jLog.Fatal(fmt.Errorf(msg), LogFrom{}, tc.otherCondition) + jLog.Fatal(fmt.Errorf(msg), &LogFrom{}, tc.otherCondition) // THEN msg was logged if shouldPrint, with/without timestamps w.Close() diff --git a/util/util_test.go b/util/util_test.go index 8bf568b6..0fc22869 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -542,12 +542,12 @@ func TestFirstNonDefaultWithEnv(t *testing.T) { diffAddress: true, }, "1 non-default var (env var partial)": { - env: map[string]string{"TESTFIRSTNONDEFAULTWITHENV_ONE": "bar"}, + env: map[string]string{"TESTFIRSTNONDEFAULTWITHENV_TWO": "bar"}, slice: []string{ "", "", "", - "foo${TESTFIRSTNONDEFAULTWITHENV_ONE}"}, + "foo${TESTFIRSTNONDEFAULTWITHENV_TWO}"}, wantIndex: 3, wantText: "foobar", diffAddress: true, @@ -562,13 +562,13 @@ func TestFirstNonDefaultWithEnv(t *testing.T) { }, "2 non-default vars (empty env vars ignored)": { env: map[string]string{ - "TESTFIRSTNONDEFAULTWITHENV_TWO": "", - "TESTFIRSTNONDEFAULTWITHENV_THREE": "bar"}, + "TESTFIRSTNONDEFAULTWITHENV_THREE": "", + "TESTFIRSTNONDEFAULTWITHENV_FOUR": "bar"}, slice: []string{ - "${TESTFIRSTNONDEFAULTWITHENV_TWO}", + "${TESTFIRSTNONDEFAULTWITHENV_THREE}", "${TESTFIRSTNONDEFAULTWITHENV_UNSET}", "", - "${TESTFIRSTNONDEFAULTWITHENV_THREE}"}, + "${TESTFIRSTNONDEFAULTWITHENV_FOUR}"}, wantIndex: 3, wantText: "${TESTFIRSTNONDEFAULTWITHENV_UNSET}", diffAddress: true, diff --git a/web/api/v1/api.go b/web/api/v1/api.go index 760b4f77..796355ed 100644 --- a/web/api/v1/api.go +++ b/web/api/v1/api.go @@ -27,7 +27,6 @@ import ( // API is the API to use for the webserver. type API struct { Config *config.Config - Log *util.JLog BaseRouter *mux.Router Router *mux.Router RoutePrefix string @@ -35,19 +34,23 @@ type API struct { // NewAPI will create a new API with the provided config. func NewAPI(cfg *config.Config, log *util.JLog) *API { + LogInit(log) + baseRouter := mux.NewRouter().StrictSlash(true) routePrefix := cfg.Settings.WebRoutePrefix() api := &API{ Config: cfg, - Log: log, BaseRouter: baseRouter, RoutePrefix: routePrefix, } + + // For cases where routePrefix is "/", remove it to prevent "//" + routePrefix = strings.TrimSuffix(routePrefix, "/") // On baseRouter as Router may have basicAuth - baseRouter.Path(fmt.Sprintf("%s/api/v1/healthcheck", strings.TrimSuffix(routePrefix, "/"))).HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + baseRouter.Path(fmt.Sprintf("%s/api/v1/healthcheck", routePrefix)).HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logFrom := util.LogFrom{Primary: "apiHealthcheck", Secondary: getIP(r)} - api.Log.Verbose("-", logFrom, true) + jLog.Verbose("-", &logFrom, true) w.Header().Set("Connection", "close") fmt.Fprintf(w, "Alive") }) diff --git a/web/api/v1/help_test.go b/web/api/v1/help_test.go index 10bc3df5..1de380c9 100644 --- a/web/api/v1/help_test.go +++ b/web/api/v1/help_test.go @@ -19,6 +19,7 @@ package v1 import ( "os" "strings" + "sync" "testing" "github.com/gorilla/websocket" @@ -37,11 +38,23 @@ import ( "github.com/release-argus/Argus/webhook" ) +var ( + loadMutex sync.Mutex + loadCount int +) + func TestMain(m *testing.M) { // initialize jLog jLog := util.NewJLog("DEBUG", false) jLog.Testing = true - service.LogInit(jLog) + flags := make(map[string]bool) + path := "TestMain.yml" + testYAML_Argus(path) + var config config.Config + config.Load(path, &flags, jLog) + os.Remove(path) + jLog.SetLevel("DEBUG") + LogInit(jLog) // run other tests exitCode := m.Run() @@ -52,9 +65,7 @@ func TestMain(m *testing.M) { func testClient() Client { hub := NewHub() - api := API{} return Client{ - api: &api, hub: hub, ip: "1.1.1.1", conn: &websocket.Conn{}, @@ -62,13 +73,11 @@ func testClient() Client { } } -func testLoad(file string) *config.Config { +func testLoad(file string, jLog *util.JLog) *config.Config { var config config.Config flags := make(map[string]bool) - jLog := util.NewJLog("DEBUG", false) - jLog.Testing = true - config.Load(file, &flags, jLog) + config.Load(file, &flags, nil) announceChannel := make(chan []byte, 8) config.HardDefaults.Service.Status.AnnounceChannel = &announceChannel @@ -77,17 +86,22 @@ func testLoad(file string) *config.Config { func testAPI(name string) API { testYAML_Argus(name) - cfg := testLoad(name) + + // Only give the log once (to avoid potential RACE condition) + var loadLog *util.JLog + loadMutex.Lock() + if loadCount == 0 { + loadLog = jLog + loadCount++ + } + loadMutex.Unlock() + + cfg := testLoad(name, loadLog) accessToken := os.Getenv("GITHUB_TOKEN") if accessToken != "" { cfg.HardDefaults.Service.LatestVersion.AccessToken = &accessToken } - jLog := util.NewJLog("DEBUG", false) - jLog.Testing = true - return API{ - Config: cfg, - Log: jLog, - } + return API{Config: cfg} } func testService(id string) *service.Service { diff --git a/web/api/v1/http-api-actions.go b/web/api/v1/http-api-actions.go index cbebbc24..27f03ae8 100644 --- a/web/api/v1/http-api-actions.go +++ b/web/api/v1/http-api-actions.go @@ -30,18 +30,18 @@ import ( // // Required params: // -// service_name - Service ID to get. +// service_name - Service ID to get the actions of. func (api *API) httpServiceGetActions(w http.ResponseWriter, r *http.Request) { logFrom := util.LogFrom{Primary: "httpServiceActions", Secondary: getIP(r)} targetService, _ := url.QueryUnescape(mux.Vars(r)["service_name"]) - api.Log.Verbose(targetService, logFrom, true) + jLog.Verbose(targetService, &logFrom, true) api.Config.OrderMutex.RLock() svc := api.Config.Service[targetService] defer api.Config.OrderMutex.RUnlock() if svc == nil { err := fmt.Sprintf("service %q not found", targetService) - api.Log.Error(err, logFrom, true) + jLog.Error(err, &logFrom, true) failRequest(&w, err, http.StatusNotFound) return } @@ -70,7 +70,7 @@ func (api *API) httpServiceGetActions(w http.ResponseWriter, r *http.Request) { WebHook: webhookSummary} err := json.NewEncoder(w).Encode(msg) - api.Log.Error(err, logFrom, err != nil) + jLog.Error(err, &logFrom, err != nil) } type RunActionsPayload struct { @@ -81,11 +81,18 @@ type RunActionsPayload struct { // // Required params: // -// service_name - Service ID to get. +// service_name - Service ID to target. +// +// target - The action to take. Can be one of: +// - "ARGUS_ALL" - Approve all actions. +// - "ARGUS_FAILED" - Approve all failed actions. +// - "ARGUS_SKIP" - Skip this release. +// - "webhook_" - Approve a specific WebHook. +// - "command_" - Approve a specific Command. func (api *API) httpServiceRunActions(w http.ResponseWriter, r *http.Request) { - logFrom := util.LogFrom{Primary: "httpServiceRunActions", Secondary: getIP(r)} + logFrom := &util.LogFrom{Primary: "httpServiceRunActions", Secondary: getIP(r)} targetService, _ := url.QueryUnescape(mux.Vars(r)["service_name"]) - api.Log.Verbose(targetService, logFrom, true) + jLog.Verbose(targetService, logFrom, true) // Check the service exists. api.Config.OrderMutex.RLock() @@ -93,7 +100,7 @@ func (api *API) httpServiceRunActions(w http.ResponseWriter, r *http.Request) { defer api.Config.OrderMutex.RUnlock() if svc == nil { err := fmt.Sprintf("service %q not found", targetService) - api.Log.Error(err, logFrom, true) + jLog.Error(err, logFrom, true) failRequest(&w, err, http.StatusNotFound) return } @@ -103,19 +110,19 @@ func (api *API) httpServiceRunActions(w http.ResponseWriter, r *http.Request) { var payload RunActionsPayload err := json.NewDecoder(payloadBytes).Decode(&payload) if err != nil { - api.Log.Error(fmt.Sprintf("Invalid payload - %v", err), logFrom, true) + jLog.Error(fmt.Sprintf("Invalid payload - %v", err), logFrom, true) failRequest(&w, "invalid payload", http.StatusBadRequest) return } if payload.Target == nil { errMsg := "invalid payload, target service not provided" - api.Log.Error(errMsg, logFrom, true) + jLog.Error(errMsg, logFrom, true) failRequest(&w, errMsg, http.StatusBadRequest) return } if !svc.Options.GetActive() { errMsg := "service is inactive, actions can't be run for it" - api.Log.Error(errMsg, logFrom, true) + jLog.Error(errMsg, logFrom, true) failRequest(&w, errMsg, http.StatusBadRequest) return } @@ -124,13 +131,13 @@ func (api *API) httpServiceRunActions(w http.ResponseWriter, r *http.Request) { if *payload.Target == "ARGUS_SKIP" { msg := fmt.Sprintf("%q release skip - %q", targetService, svc.Status.LatestVersion()) - api.Log.Info(msg, logFrom, true) + jLog.Info(msg, logFrom, true) svc.HandleSkip() return } if svc.WebHook == nil && svc.Command == nil { - api.Log.Error(fmt.Sprintf("%q does not have any commands/webhooks to approve", targetService), logFrom, true) + jLog.Error(fmt.Sprintf("%q does not have any commands/webhooks to approve", targetService), logFrom, true) return } @@ -145,7 +152,7 @@ func (api *API) httpServiceRunActions(w http.ResponseWriter, r *http.Request) { "ARGUS_FAILED", "ALL UNSENT/FAILED"), "ARGUS_SKIP", "SKIP"), ) - api.Log.Info(msg, logFrom, true) + jLog.Info(msg, logFrom, true) switch *payload.Target { case "ARGUS_ALL", "ARGUS_FAILED": go svc.HandleFailedActions() diff --git a/web/api/v1/http-api-actions_test.go b/web/api/v1/http-api-actions_test.go index d1d2f7bf..66014044 100644 --- a/web/api/v1/http-api-actions_test.go +++ b/web/api/v1/http-api-actions_test.go @@ -284,7 +284,7 @@ func TestHTTP_httpServiceRunActions(t *testing.T) { "invalid payload": { serviceID: "__name__", payload: test.StringPtr("target: foo"), // not JSON - stdoutRegex: `invalid payload`, + stdoutRegex: `Invalid payload - invalid character`, }, "ARGUS_SKIP known service_id": { serviceID: "__name__", diff --git a/web/api/v1/http-api-config.go b/web/api/v1/http-api-config.go index 275c6272..54435884 100644 --- a/web/api/v1/http-api-config.go +++ b/web/api/v1/http-api-config.go @@ -24,8 +24,8 @@ import ( // wsConfig handles getting the config that's in use and sending it as YAML. func (api *API) httpConfig(w http.ResponseWriter, r *http.Request) { - logFrom := util.LogFrom{Primary: "httpConfig", Secondary: getIP(r)} - api.Log.Verbose("-", logFrom, true) + logFrom := &util.LogFrom{Primary: "httpConfig", Secondary: getIP(r)} + jLog.Verbose("-", logFrom, true) cfg := &api_type.Config{} @@ -108,5 +108,5 @@ func (api *API) httpConfig(w http.ResponseWriter, r *http.Request) { api.Config.OrderMutex.RUnlock() err := json.NewEncoder(w).Encode(cfg) - api.Log.Error(err, logFrom, err != nil) + jLog.Error(err, logFrom, err != nil) } diff --git a/web/api/v1/http-api-edit.go b/web/api/v1/http-api-edit.go index cdfbb84b..5af15925 100644 --- a/web/api/v1/http-api-edit.go +++ b/web/api/v1/http-api-edit.go @@ -46,8 +46,8 @@ func (api *API) httpVersionRefreshUncreated(w http.ResponseWriter, r *http.Reque if deployedVersionRefresh { logFromPrimary = "httpVersionRefreshUncreated_Deployed" } - logFrom := util.LogFrom{Primary: logFromPrimary, Secondary: getIP(r)} - api.Log.Verbose("-", logFrom, true) + logFrom := &util.LogFrom{Primary: logFromPrimary, Secondary: getIP(r)} + jLog.Verbose("-", logFrom, true) // Set headers w.Header().Set("Connection", "close") @@ -114,7 +114,7 @@ func (api *API) httpVersionRefreshUncreated(w http.ResponseWriter, r *http.Reque Error: util.ErrorToString(err), Date: time.Now().UTC(), }) - api.Log.Error(err, logFrom, err != nil) + jLog.Error(err, logFrom, err != nil) } // httpVersionRefresh refreshes the latest/deployed version of the target service. @@ -133,8 +133,8 @@ func (api *API) httpVersionRefresh(w http.ResponseWriter, r *http.Request) { if deployedVersionRefresh { logFromPrimary = "httpVersionRefresh_Deployed" } - logFrom := util.LogFrom{Primary: logFromPrimary, Secondary: getIP(r)} - api.Log.Verbose(targetService, logFrom, true) + logFrom := &util.LogFrom{Primary: logFromPrimary, Secondary: getIP(r)} + jLog.Verbose(targetService, logFrom, true) // Set headers w.Header().Set("Connection", "close") @@ -147,7 +147,7 @@ func (api *API) httpVersionRefresh(w http.ResponseWriter, r *http.Request) { defer api.Config.OrderMutex.RUnlock() if api.Config.Service[targetService] == nil { err := fmt.Sprintf("service %q not found", targetService) - api.Log.Error(err, logFrom, true) + jLog.Error(err, logFrom, true) failRequest(&w, err, http.StatusNotFound) return } @@ -206,7 +206,7 @@ func (api *API) httpVersionRefresh(w http.ResponseWriter, r *http.Request) { Error: util.ErrorToString(err), Date: time.Now().UTC(), }) - api.Log.Error(err, logFrom, err != nil) + jLog.Error(err, logFrom, err != nil) } // httpServiceDetail handles sending details about a Service @@ -218,8 +218,8 @@ func (api *API) httpServiceDetail(w http.ResponseWriter, r *http.Request) { // service to get details from (empty for create new) targetService, _ := url.QueryUnescape(mux.Vars(r)["service_name"]) - logFrom := util.LogFrom{Primary: "httpServiceDetail", Secondary: getIP(r)} - api.Log.Verbose(targetService, logFrom, true) + logFrom := &util.LogFrom{Primary: "httpServiceDetail", Secondary: getIP(r)} + jLog.Verbose(targetService, logFrom, true) // Set Headers w.Header().Set("Connection", "close") @@ -234,7 +234,7 @@ func (api *API) httpServiceDetail(w http.ResponseWriter, r *http.Request) { if svc == nil { err := fmt.Sprintf("service %q not found", targetService) - api.Log.Error(err, logFrom, true) + jLog.Error(err, logFrom, true) failRequest(&w, err, http.StatusNotFound) return } @@ -253,7 +253,7 @@ func (api *API) httpServiceDetail(w http.ResponseWriter, r *http.Request) { } err := json.NewEncoder(w).Encode(serviceJSON) - api.Log.Error(err, logFrom, err != nil) + jLog.Error(err, logFrom, err != nil) } // httpOtherServiceDetails handles sending details about the global notify/webhook's, defaults and hard defaults. @@ -261,8 +261,8 @@ func (api *API) httpServiceDetail(w http.ResponseWriter, r *http.Request) { // # GET func (api *API) httpOtherServiceDetails(w http.ResponseWriter, r *http.Request) { logFromPrimary := "httpOtherServiceDetails" - logFrom := util.LogFrom{Primary: logFromPrimary, Secondary: getIP(r)} - api.Log.Verbose("-", logFrom, true) + logFrom := &util.LogFrom{Primary: logFromPrimary, Secondary: getIP(r)} + jLog.Verbose("-", logFrom, true) // Set headers w.Header().Set("Connection", "close") @@ -275,7 +275,7 @@ func (api *API) httpOtherServiceDetails(w http.ResponseWriter, r *http.Request) Notify: convertAndCensorNotifySliceDefaults(&api.Config.Notify), WebHook: convertAndCensorWebHookSliceDefaults(&api.Config.WebHook), }) - api.Log.Error(err, logFrom, err != nil) + jLog.Error(err, logFrom, err != nil) } // httpServiceEdit handles creating/editing a Service. @@ -298,8 +298,8 @@ func (api *API) httpServiceEdit(w http.ResponseWriter, r *http.Request) { reqType = "edit" } - logFrom := util.LogFrom{Primary: "httpServiceEdit", Secondary: getIP(r)} - api.Log.Verbose(fmt.Sprintf("%s %s", reqType, targetService), + logFrom := &util.LogFrom{Primary: "httpServiceEdit", Secondary: getIP(r)} + jLog.Verbose(fmt.Sprintf("%s %s", reqType, targetService), logFrom, true) w.Header().Set("Connection", "close") @@ -331,9 +331,9 @@ func (api *API) httpServiceEdit(w http.ResponseWriter, r *http.Request) { &api.Config.WebHook, &api.Config.Defaults.WebHook, &api.Config.HardDefaults.WebHook, - &logFrom) + logFrom) if err != nil { - api.Log.Error(err, logFrom, true) + jLog.Error(err, logFrom, true) failRequest(&w, fmt.Sprintf(`%s %q failed (invalid json)\%s`, reqType, targetService, err.Error())) return @@ -349,7 +349,7 @@ func (api *API) httpServiceEdit(w http.ResponseWriter, r *http.Request) { // Check the values err = newService.CheckValues("") if err != nil { - api.Log.Error(err, logFrom, true) + jLog.Error(err, logFrom, true) // Remove the service name from the error err = errors.New(strings.Join(strings.Split(err.Error(), `\`)[1:], `\`)) @@ -361,7 +361,7 @@ func (api *API) httpServiceEdit(w http.ResponseWriter, r *http.Request) { // Ensure LatestVersion and DeployedVersion (if set) can fetch err = newService.CheckFetches() if err != nil { - api.Log.Error(err, logFrom, true) + jLog.Error(err, logFrom, true) failRequest(&w, fmt.Sprintf(`%s %q failed (fetches failed)\%s`, reqType, util.FirstNonDefault(targetService, newService.ID), err.Error())) @@ -395,8 +395,8 @@ func (api *API) httpServiceDelete(w http.ResponseWriter, r *http.Request) { targetService, _ := url.QueryUnescape(mux.Vars(r)["service_name"]) logFromPrimary := "httpServiceDelete" - logFrom := util.LogFrom{Primary: logFromPrimary, Secondary: getIP(r)} - api.Log.Verbose(targetService, logFrom, true) + logFrom := &util.LogFrom{Primary: logFromPrimary, Secondary: getIP(r)} + jLog.Verbose(targetService, logFrom, true) // If service doesn't exist, return 404. if api.Config.Service[targetService] == nil { @@ -436,21 +436,21 @@ func (api *API) httpNotifyTest(w http.ResponseWriter, r *http.Request) { w.Header().Set("Connection", "close") w.Header().Set("Content-Type", "application/json") - logFrom := util.LogFrom{Primary: "httpNotifyTest", Secondary: getIP(r)} - api.Log.Verbose("-", logFrom, true) + logFrom := &util.LogFrom{Primary: "httpNotifyTest", Secondary: getIP(r)} + jLog.Verbose("-", logFrom, true) // Payload payload := http.MaxBytesReader(w, r.Body, 102400) var buf bytes.Buffer if _, err := buf.ReadFrom(payload); err != nil { - api.Log.Error(err, logFrom, true) + jLog.Error(err, logFrom, true) failRequest(&w, err.Error()) return } var parsedPayload shoutrrr.TestPayload err := json.Unmarshal(buf.Bytes(), &parsedPayload) if err != nil { - api.Log.Error(err, logFrom, true) + jLog.Error(err, logFrom, true) failRequest(&w, err.Error()) return } @@ -477,7 +477,7 @@ func (api *API) httpNotifyTest(w http.ResponseWriter, r *http.Request) { api.Config.Defaults.Notify, api.Config.HardDefaults.Notify) if err != nil { - api.Log.Error(err, logFrom, true) + jLog.Error(err, logFrom, true) failRequest(&w, err.Error()) return } @@ -486,7 +486,7 @@ func (api *API) httpNotifyTest(w http.ResponseWriter, r *http.Request) { // Test the notify err = testNotify.TestSend(serviceURL) if err != nil { - api.Log.Error(err, logFrom, true) + jLog.Error(err, logFrom, true) failRequest(&w, err.Error()) return } diff --git a/web/api/v1/http-api-flags.go b/web/api/v1/http-api-flags.go index 7488b39e..b7e6ab9b 100644 --- a/web/api/v1/http-api-flags.go +++ b/web/api/v1/http-api-flags.go @@ -24,8 +24,8 @@ import ( // httpFlags returns the values of vars that can be set with flags. func (api *API) httpFlags(w http.ResponseWriter, r *http.Request) { - logFrom := util.LogFrom{Primary: "httpFlags", Secondary: getIP(r)} - api.Log.Verbose("-", logFrom, true) + logFrom := &util.LogFrom{Primary: "httpFlags", Secondary: getIP(r)} + jLog.Verbose("-", logFrom, true) // Create and send status page data msg := api_type.Flags{ @@ -40,5 +40,5 @@ func (api *API) httpFlags(w http.ResponseWriter, r *http.Request) { WebRoutePrefix: api.Config.Settings.WebRoutePrefix()} err := json.NewEncoder(w).Encode(msg) - api.Log.Error(err, logFrom, err != nil) + jLog.Error(err, logFrom, err != nil) } diff --git a/web/api/v1/http-api-service.go b/web/api/v1/http-api-service.go index 7d843530..3a726c04 100644 --- a/web/api/v1/http-api-service.go +++ b/web/api/v1/http-api-service.go @@ -29,19 +29,19 @@ type ServiceOrderAPI struct { } func (api *API) httpServiceOrder(w http.ResponseWriter, r *http.Request) { - logFrom := util.LogFrom{Primary: "httpServiceOrder", Secondary: getIP(r)} - api.Log.Verbose("-", logFrom, true) + logFrom := &util.LogFrom{Primary: "httpServiceOrder", Secondary: getIP(r)} + jLog.Verbose("-", logFrom, true) api.Config.OrderMutex.RLock() defer api.Config.OrderMutex.RUnlock() err := json.NewEncoder(w).Encode(ServiceOrderAPI{Order: api.Config.Order}) - api.Log.Error(err, logFrom, err != nil) + jLog.Error(err, logFrom, err != nil) } func (api *API) httpServiceSummary(w http.ResponseWriter, r *http.Request) { - logFrom := util.LogFrom{Primary: "httpServiceSummary", Secondary: getIP(r)} + logFrom := &util.LogFrom{Primary: "httpServiceSummary", Secondary: getIP(r)} targetService, _ := url.QueryUnescape(mux.Vars(r)["service_name"]) - api.Log.Verbose(targetService, logFrom, true) + jLog.Verbose(targetService, logFrom, true) // Check Service still exists in this ordering api.Config.OrderMutex.RLock() @@ -49,7 +49,7 @@ func (api *API) httpServiceSummary(w http.ResponseWriter, r *http.Request) { service := api.Config.Service[targetService] if service == nil { err := fmt.Sprintf("service %q not found", targetService) - api.Log.Error(err, logFrom, true) + jLog.Error(err, logFrom, true) failRequest(&w, err, http.StatusNotFound) return } @@ -58,5 +58,5 @@ func (api *API) httpServiceSummary(w http.ResponseWriter, r *http.Request) { summary := service.Summary() err := json.NewEncoder(w).Encode(summary) - api.Log.Error(err, logFrom, err != nil) + jLog.Error(err, logFrom, err != nil) } diff --git a/web/api/v1/http-api-status.go b/web/api/v1/http-api-status.go index 3f61b482..0f6bb797 100644 --- a/web/api/v1/http-api-status.go +++ b/web/api/v1/http-api-status.go @@ -26,8 +26,8 @@ import ( // httpRuntimeInfo returns runtime info about the server. func (api *API) httpRuntimeInfo(w http.ResponseWriter, r *http.Request) { - logFrom := util.LogFrom{Primary: "httpBuildInfo", Secondary: getIP(r)} - api.Log.Verbose("-", logFrom, true) + logFrom := &util.LogFrom{Primary: "httpBuildInfo", Secondary: getIP(r)} + jLog.Verbose("-", logFrom, true) // Create and send status page data msg := api_type.RuntimeInfo{ @@ -39,5 +39,5 @@ func (api *API) httpRuntimeInfo(w http.ResponseWriter, r *http.Request) { GoDebug: os.Getenv("GODEBUG")} err := json.NewEncoder(w).Encode(msg) - api.Log.Error(err, logFrom, err != nil) + jLog.Error(err, logFrom, err != nil) } diff --git a/web/api/v1/http.go b/web/api/v1/http.go index a93a2777..780d7fcd 100644 --- a/web/api/v1/http.go +++ b/web/api/v1/http.go @@ -194,8 +194,8 @@ func (api *API) SetupRoutesFavicon() { // httpVersion serves Argus version JSON over HTTP. func (api *API) httpVersion(w http.ResponseWriter, r *http.Request) { - logFrom := util.LogFrom{Primary: "httpVersion", Secondary: getIP(r)} - api.Log.Verbose("-", logFrom, true) + logFrom := &util.LogFrom{Primary: "httpVersion", Secondary: getIP(r)} + jLog.Verbose("-", logFrom, true) // Set headers w.Header().Set("Connection", "close") @@ -206,7 +206,7 @@ func (api *API) httpVersion(w http.ResponseWriter, r *http.Request) { BuildDate: util.BuildDate, GoVersion: util.GoVersion, }) - api.Log.Error(err, logFrom, err != nil) + jLog.Error(err, logFrom, err != nil) } // failRequest with a JSON response containing a message and status code. diff --git a/web/api/v1/http_test.go b/web/api/v1/http_test.go index 0636a802..4b11bbc7 100644 --- a/web/api/v1/http_test.go +++ b/web/api/v1/http_test.go @@ -35,7 +35,6 @@ import ( func TestHTTP_Version(t *testing.T) { // GIVEN an API and the Version,BuildDate and GoVersion vars defined api := API{} - api.Log = util.NewJLog("WARN", false) util.Version = "1.2.3" util.BuildDate = "2022-01-01T01:01:01Z" @@ -164,7 +163,7 @@ func TestHTTP_SetupRoutesFavicon(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - cfg := test.BareConfig() + cfg := test.BareConfig(true) cfg.Settings.Web.Favicon = testFaviconSettings(tc.urlPNG, tc.urlSVG) api := NewAPI(cfg, util.NewJLog("WARN", false)) api.SetupRoutesFavicon() @@ -402,6 +401,8 @@ func TestHTTP_DisableRoutes(t *testing.T) { }`, }, } + log := util.NewJLog("WARN", false) + log.Testing = true disableCombinations := test.Combinations(util.SortedKeys(tests)) // Split tests into groups @@ -421,7 +422,7 @@ func TestHTTP_DisableRoutes(t *testing.T) { } t.Run(strings.Join(disabledRoutes, ";"), func(t *testing.T) { - cfg := test.BareConfig() + cfg := test.BareConfig(false) cfg.Settings.Web.DisabledRoutes = disabledRoutes // Give every other test a route prefix routePrefix := "" @@ -429,7 +430,7 @@ func TestHTTP_DisableRoutes(t *testing.T) { routePrefix = "/test" cfg.Settings.Web.RoutePrefix = &routePrefix } - api := NewAPI(cfg, util.NewJLog("WARN", false)) + api := NewAPI(cfg, log) api.SetupRoutesAPI() ts := httptest.NewServer(api.Router) ts.Config.Handler = api.Router diff --git a/web/api/v1/util.go b/web/api/v1/util.go index a17ca3f7..0bd8b74e 100644 --- a/web/api/v1/util.go +++ b/web/api/v1/util.go @@ -30,6 +30,18 @@ import ( "github.com/release-argus/Argus/webhook" ) +var ( + jLog *util.JLog +) + +// LogInit for this package. +func LogInit(log *util.JLog) { + // Only set the log if it hasn't been set (avoid RACE condition) + if jLog == nil { + jLog = log + } +} + // getParam from a URL query string func getParam(queryParams *url.Values, param string) *string { if queryParams.Has(param) { diff --git a/web/api/v1/websocket-client.go b/web/api/v1/websocket-client.go index f4d1d970..9ff0cbb3 100644 --- a/web/api/v1/websocket-client.go +++ b/web/api/v1/websocket-client.go @@ -58,9 +58,6 @@ var upgrader = websocket.Upgrader{ // Client is a middleman between the websocket connection and the hub. type Client struct { - // The API. - api *API - // The WebSocket hub. hub *Hub @@ -126,9 +123,9 @@ func (c *Client) readPump() { defer func() { c.hub.unregister <- c err := c.conn.Close() - c.api.Log.Verbose( + jLog.Verbose( fmt.Sprintf("Closing the websocket connection failed (readPump)\n%s", util.ErrorToString(err)), - util.LogFrom{}, + &util.LogFrom{}, err != nil, ) }() @@ -157,10 +154,10 @@ func (c *Client) readPump() { break } - if c.api.Log.IsLevel("DEBUG") { - c.api.Log.Debug( + if jLog.IsLevel("DEBUG") { + jLog.Debug( fmt.Sprintf("READ %s", message), - util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, true) + &util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, true) } message = bytes.TrimSpace(bytes.ReplaceAll(message, newline, space)) @@ -168,9 +165,9 @@ func (c *Client) readPump() { var validation serverMessage err = json.Unmarshal(message, &validation) if err != nil { - c.api.Log.Warn( + jLog.Warn( fmt.Sprintf("Invalid message (missing/invalid version key)\n%s", message), - util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, + &util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, true, ) continue @@ -190,9 +187,9 @@ func (c *Client) writePump() { defer func() { ticker.Stop() err := c.conn.Close() - c.api.Log.Verbose( + jLog.Verbose( fmt.Sprintf("Closing the connection\n%s", util.ErrorToString(err)), - util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, + &util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, true, ) }() @@ -206,9 +203,9 @@ func (c *Client) writePump() { if !ok { // The hub closed the channel. err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}) - c.api.Log.Verbose( + jLog.Verbose( fmt.Sprintf("Closing the connection (writePump)\n%s", util.ErrorToString(err)), - util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, + &util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, true, ) return @@ -217,9 +214,9 @@ func (c *Client) writePump() { var msg api_type.WebSocketMessage err := json.Unmarshal(message, &msg) if err != nil { - c.api.Log.Error( + jLog.Error( fmt.Sprintf("Message failed to unmarshal %s", util.ErrorToString(err)), - util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, + &util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, true, ) continue @@ -233,15 +230,15 @@ func (c *Client) writePump() { switch msg.Type { case "VERSION", "WEBHOOK", "COMMAND", "SERVICE", "EDIT", "DELETE": err := c.conn.WriteJSON(msg) - c.api.Log.Error( + jLog.Error( fmt.Sprintf("Writing JSON to the websocket failed for %s\n%s", msg.Type, util.ErrorToString(err)), - util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, + &util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, err != nil, ) default: - c.api.Log.Error( + jLog.Error( fmt.Sprintf("Unknown TYPE %q\nFull message: %s", msg.Type, string(message)), - util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, + &util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, true, ) continue @@ -252,9 +249,9 @@ func (c *Client) writePump() { n := len(c.send) for i := 0; i < n; i++ { err := c.conn.WriteJSON(<-c.send) - c.api.Log.Error( + jLog.Error( fmt.Sprintf("WriteJSON for the queued chat messages\n%s\n", util.ErrorToString(err)), - util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, + &util.LogFrom{Primary: "WebSocket", Secondary: c.ip}, err != nil, ) } @@ -280,7 +277,6 @@ func ServeWs(api *API, hub *Hub, w http.ResponseWriter, r *http.Request) { conn.RemoteAddr() client := &Client{ - api: api, hub: hub, ip: getIP(r), conn: conn, diff --git a/web/api/v1/websocket-hub.go b/web/api/v1/websocket-hub.go index fbca7f9e..6b8b3c85 100644 --- a/web/api/v1/websocket-hub.go +++ b/web/api/v1/websocket-hub.go @@ -56,11 +56,14 @@ type AnnounceMSG struct { } // Run will start the WebSocket Hub. -func (h *Hub) Run(jLog *util.JLog) { +func (h *Hub) Run() { for { select { case client := <-h.register: - h.clients[client] = true + // Avoid unnecessary writes to the map + if _, ok := h.clients[client]; !ok { + h.clients[client] = true + } case client := <-h.unregister: if _, ok := h.clients[client]; ok { delete(h.clients, client) @@ -72,14 +75,14 @@ func (h *Hub) Run(jLog *util.JLog) { if jLog.IsLevel("DEBUG") { jLog.Debug( fmt.Sprintf("Broadcast %s", string(message)), - util.LogFrom{Primary: "WebSocket"}, + &util.LogFrom{Primary: "WebSocket"}, len(h.clients) > 0) } var msg AnnounceMSG if err := json.Unmarshal(message, &msg); err != nil { jLog.Warn( "Invalid JSON broadcast to the WebSocket", - util.LogFrom{Primary: "WebSocket"}, + &util.LogFrom{Primary: "WebSocket"}, true, ) n = len(h.Broadcast) diff --git a/web/api/v1/websocket-hub_test.go b/web/api/v1/websocket-hub_test.go index 2d9ab9ae..2126f167 100644 --- a/web/api/v1/websocket-hub_test.go +++ b/web/api/v1/websocket-hub_test.go @@ -20,8 +20,6 @@ import ( "encoding/json" "testing" "time" - - "github.com/release-argus/Argus/util" ) func TestNewHub(t *testing.T) { @@ -46,38 +44,41 @@ func TestNewHub(t *testing.T) { } func TestHub_RunWithRegister(t *testing.T) { - // GIVEN a WebSocket Hub and API + // GIVEN a WebSocket Hub and two clients hub := NewHub() - api := API{} - go hub.Run(util.NewJLog("WARN", false)) - time.Sleep(time.Second) - - // WHEN a new client connects + go hub.Run() client := testClient() - client.api = &api - client.hub = hub + otherClient := testClient() + + // WHEN a new client connects (two for synchronisation) hub.register <- &client - time.Sleep(time.Second) + hub.register <- &otherClient + hub.register <- &otherClient // THEN that client is registered to the Hub - // DATA RACE, but just for testing + // DATA RACE - Unsure why as register is a second before this read if !hub.clients[&client] { - t.Error("Client wasn't registerd to the Hub") + t.Error("Client wasn't registered to the Hub") } } func TestHub_RunWithUnregister(t *testing.T) { // GIVEN a Client is connected to the WebSocket Hub client := testClient() + otherClient := testClient() hub := client.hub - go hub.Run(util.NewJLog("WARN", false)) - time.Sleep(time.Second) + go hub.Run() hub.register <- &client - time.Sleep(time.Second) + hub.register <- &otherClient + hub.register <- &otherClient + if !hub.clients[&client] { + t.Error("Client wasn't registered to the Hub") + } - // WHEN that client disconnects - hub.unregister <- &client + // WHEN that client disconnects (two for synchronisation) hub.unregister <- &client + hub.unregister <- &otherClient + hub.unregister <- &otherClient // THEN that client is unregistered to the Hub if hub.clients[&client] { @@ -91,7 +92,7 @@ func TestHub_RunWithBroadcast(t *testing.T) { // and a valid message wants to be sent client := testClient() hub := client.hub - go hub.Run(util.NewJLog("DEBUG", false)) + go hub.Run() time.Sleep(time.Second) hub.register <- &client time.Sleep(2 * time.Second) @@ -120,7 +121,7 @@ func TestHub_RunWithInvalidBroadcast(t *testing.T) { // and an invalid message wants to be sent client := testClient() hub := client.hub - go hub.Run(util.NewJLog("WARN", false)) + go hub.Run() time.Sleep(time.Second) hub.register <- &client time.Sleep(time.Second) diff --git a/web/help_test.go b/web/help_test.go index 20d1596b..2b5e1657 100644 --- a/web/help_test.go +++ b/web/help_test.go @@ -25,7 +25,6 @@ import ( "fmt" "math/big" "net" - "net/http" "os" "sync" "testing" @@ -68,26 +67,27 @@ func stringifyPointer[T comparable](ptr *T) string { func TestMain(m *testing.M) { // initialize jLog - jLog = util.NewJLog("DEBUG", false) + jLog := util.NewJLog("DEBUG", false) jLog.Testing = true // GIVEN a valid config with a Service file := "TestMain.yml" - mainCfg = testConfig(file, nil) + mainCfg = testConfig(file, jLog, nil) os.Remove(file) defer os.Remove(*mainCfg.Settings.Data.DatabaseFile) port = mainCfg.Settings.Web.ListenPort + mainCfg.Settings.Web.ListenHost = stringPtr("localhost") // WHEN the Router is fetched for this Config router = newWebUI(mainCfg) - go http.ListenAndServe("localhost:"+*port, router) + go Run(mainCfg, jLog) // THEN Web UI is accessible for the tests code := m.Run() os.Exit(code) } -func testConfig(path string, t *testing.T) (cfg *config.Config) { +func testConfig(path string, jLog *util.JLog, t *testing.T) (cfg *config.Config) { testYAML_Argus(path, t) cfg = &config.Config{} diff --git a/web/web.go b/web/web.go index dd288feb..94172281 100644 --- a/web/web.go +++ b/web/web.go @@ -29,7 +29,7 @@ var jLog *util.JLog // NewRouter that serves the Prometheus metrics, // WebSocket and NodeJS frontend at the RoutePrefix. -func NewRouter(cfg *config.Config, jLog *util.JLog, hub *api_v1.Hub) *mux.Router { +func NewRouter(cfg *config.Config, hub *api_v1.Hub) *mux.Router { // Go api := api_v1.NewAPI(cfg, jLog) @@ -54,8 +54,8 @@ func NewRouter(cfg *config.Config, jLog *util.JLog, hub *api_v1.Hub) *mux.Router // newWebUI will set up everything web-related for Argus. func newWebUI(cfg *config.Config) *mux.Router { hub := api_v1.NewHub() - go hub.Run(jLog) - router := NewRouter(cfg, jLog, hub) + go hub.Run() + router := NewRouter(cfg, hub) // Hand out the broadcast channel cfg.HardDefaults.Service.Status.AnnounceChannel = &hub.Broadcast @@ -67,21 +67,25 @@ func newWebUI(cfg *config.Config) *mux.Router { } func Run(cfg *config.Config, log *util.JLog) { - jLog = log + // Only set if unset (avoid RACE condition in tests) + if log != nil && jLog == nil { + jLog = log + } + router := newWebUI(cfg) listenAddress := fmt.Sprintf("%s:%s", cfg.Settings.WebListenHost(), cfg.Settings.WebListenPort()) - jLog.Info("Listening on "+listenAddress+cfg.Settings.WebRoutePrefix(), util.LogFrom{}, true) + jLog.Info("Listening on "+listenAddress+cfg.Settings.WebRoutePrefix(), &util.LogFrom{}, true) if cfg.Settings.WebCertFile() != nil && cfg.Settings.WebKeyFile() != nil { jLog.Fatal( http.ListenAndServeTLS( listenAddress, *cfg.Settings.WebCertFile(), *cfg.Settings.WebKeyFile(), router), - util.LogFrom{}, true) + &util.LogFrom{}, true) } else { jLog.Fatal( http.ListenAndServe( listenAddress, router), - util.LogFrom{}, true) + &util.LogFrom{}, true) } } diff --git a/web/web_test.go b/web/web_test.go index 2b4445a5..d4ab9cf9 100644 --- a/web/web_test.go +++ b/web/web_test.go @@ -37,7 +37,7 @@ var router *mux.Router func TestMainWithRoutePrefix(t *testing.T) { // GIVEN a valid config with a Service - cfg := testConfig("TestMainWithRoutePrefix.yml", t) + cfg := testConfig("TestMainWithRoutePrefix.yml", nil, t) *cfg.Settings.Web.RoutePrefix = "/test" // WHEN the Web UI is started with this Config @@ -125,7 +125,7 @@ func TestAccessibleHTTPS(t *testing.T) { bodyRegex: fmt.Sprintf(`"goVersion":"%s"`, util.GoVersion)}, } - cfg := testConfig("TestAccessibleHTTPS.yml", t) + cfg := testConfig("TestAccessibleHTTPS.yml", nil, t) cfg.Settings.Web.CertFile = stringPtr("TestAccessibleHTTPS_cert.pem") cfg.Settings.Web.KeyFile = stringPtr("TestAccessibleHTTPS_key.pem") generateCertFiles(*cfg.Settings.Web.CertFile, *cfg.Settings.Web.KeyFile) @@ -133,7 +133,7 @@ func TestAccessibleHTTPS(t *testing.T) { defer os.Remove(*cfg.Settings.Web.KeyFile) router = newWebUI(cfg) - go Run(cfg, util.NewJLog("WARN", false)) + go Run(cfg, nil) time.Sleep(250 * time.Millisecond) address := fmt.Sprintf("https://localhost:%s", *cfg.Settings.Web.ListenPort) diff --git a/webhook/send.go b/webhook/send.go index 541fbf03..21e345be 100644 --- a/webhook/send.go +++ b/webhook/send.go @@ -62,8 +62,9 @@ func (w *WebHook) Send( serviceInfo *util.ServiceInfo, useDelay bool, ) (errs error) { - logFrom := util.LogFrom{Primary: w.ID, Secondary: serviceInfo.ID} // For logging - triesLeft := w.GetMaxTries() // Number of times to send WebHook (until DesiredStatusCode received). + logFrom := &util.LogFrom{Primary: w.ID, Secondary: serviceInfo.ID} + // Number of times to send WebHook (until DesiredStatusCode received). + triesLeft := w.GetMaxTries() if useDelay && w.GetDelay() != "0s" { // Delay sending the WebHook message by the defined interval. @@ -82,7 +83,7 @@ func (w *WebHook) Send( } // Try sending the WebHook. - err := w.try(&logFrom) + err := w.try(logFrom) // SUCCESS! if err == nil { @@ -134,7 +135,7 @@ func (w *WebHook) try(logFrom *util.LogFrom) (err error) { req := w.BuildRequest() if req == nil { err = fmt.Errorf("failed to get *http.request for webhook") - jLog.Error(err, *logFrom, true) + jLog.Error(err, logFrom, true) return } @@ -164,7 +165,7 @@ func (w *WebHook) try(logFrom *util.LogFrom) (err error) { desiredStatusCode := w.GetDesiredStatusCode() if bodyOkay && (resp.StatusCode == desiredStatusCode || (desiredStatusCode == 0 && (strconv.Itoa(resp.StatusCode)[:1] == "2"))) { msg := fmt.Sprintf("(%d) WebHook received", resp.StatusCode) - jLog.Info(msg, *logFrom, true) + jLog.Info(msg, logFrom, true) return } From 2ca97a93fdcf7a93c66c07d5298773e70083b0bd Mon Sep 17 00:00:00 2001 From: Joseph Kavanagh Date: Sat, 20 Apr 2024 16:21:39 +0100 Subject: [PATCH 2/4] test: serial stdout-dependent tests --- cmd/argus/main_test.go | 4 + commands/commands_test.go | 10 +++ config/defaults_test.go | 4 + config/edit_test.go | 6 +- config/help_test.go | 6 +- test/config.go => config/test/main.go | 9 +- .../test/main_test.go | 0 config/verify_test.go | 7 ++ db/init_test.go | 5 ++ .../shoutrrr/test/main.go | 6 +- .../shoutrrr/test/main_test.go | 0 notifiers/shoutrrr/verify_test.go | 4 + service/deployed_version/query_test.go | 7 +- service/latest_version/query_test.go | 11 +++ service/track_test.go | 5 ++ service/verify_test.go | 4 + test/main.go | 21 +++-- test/main_test.go | 46 ---------- testing/commands_test.go | 4 + testing/service_test.go | 4 + testing/shoutrrr_test.go | 7 ++ util/log_test.go | 20 +++++ util/util_test.go | 10 +++ web/api/v1/http-api-actions_test.go | 6 ++ web/api/v1/http-api-edit_test.go | 14 ++-- web/api/v1/http_test.go | 5 +- web/help_test.go | 2 - .../react-app/src/reducers/action-release.tsx | 84 +++++++++---------- web/ui/react-app/src/reducers/monitor.tsx | 2 +- web/ui/react-app/src/types/websocket.tsx | 10 +-- web/web_test.go | 8 +- webhook/send_test.go | 7 ++ webhook/verify_test.go | 4 + 33 files changed, 216 insertions(+), 126 deletions(-) rename test/config.go => config/test/main.go (88%) rename test/config_test.go => config/test/main_test.go (100%) rename test/shoutrrr.go => notifiers/shoutrrr/test/main.go (90%) rename test/shoutrrr_test.go => notifiers/shoutrrr/test/main_test.go (100%) diff --git a/cmd/argus/main_test.go b/cmd/argus/main_test.go index 1ede9455..c9ac3ace 100644 --- a/cmd/argus/main_test.go +++ b/cmd/argus/main_test.go @@ -24,6 +24,7 @@ import ( "testing" "time" + "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/util" ) @@ -76,6 +77,9 @@ func TestTheMain(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() file := fmt.Sprintf("%s.yml", name) os.Remove(tc.db) diff --git a/commands/commands_test.go b/commands/commands_test.go index da00c4f0..ac9c4eb2 100644 --- a/commands/commands_test.go +++ b/commands/commands_test.go @@ -25,6 +25,7 @@ import ( "testing" svcstatus "github.com/release-argus/Argus/service/status" + "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/util" ) @@ -92,6 +93,9 @@ func TestCommand_Exec(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() stdout := os.Stdout r, w, _ := os.Pipe() @@ -157,6 +161,9 @@ func TestController_ExecIndex(t *testing.T) { runNumber := 0 for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() stdout := os.Stdout r, w, _ := os.Pipe() @@ -224,6 +231,9 @@ func TestController_Exec(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() announce := make(chan []byte, 8) controller := testController(&announce) diff --git a/config/defaults_test.go b/config/defaults_test.go index f4d39fbe..1ded2ed6 100644 --- a/config/defaults_test.go +++ b/config/defaults_test.go @@ -29,6 +29,7 @@ import ( latestver "github.com/release-argus/Argus/service/latest_version" "github.com/release-argus/Argus/service/latest_version/filter" opt "github.com/release-argus/Argus/service/options" + "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/util" "github.com/release-argus/Argus/webhook" ) @@ -1072,6 +1073,9 @@ func TestDefaults_Print(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() stdout := os.Stdout r, w, _ := os.Pipe() diff --git a/config/edit_test.go b/config/edit_test.go index fc2aa60c..1dcd6794 100644 --- a/config/edit_test.go +++ b/config/edit_test.go @@ -66,7 +66,7 @@ func TestConfig_RenameService(t *testing.T) { file := fmt.Sprintf("TestConfig_RenameService_%s.yml", name) testYAML_Edit(file, t) logMutex.Lock() - cfg := testLoadBasic(file, t) // Global vars could otherwise DATA RACE + cfg := testLoadBasic(file, t) newSVC := testServiceURL(tc.newName) // WHEN the service is renamed @@ -141,7 +141,7 @@ func TestConfig_DeleteService(t *testing.T) { file := fmt.Sprintf("TestConfig_DeleteService_%s.yml", name) testYAML_Edit(file, t) logMutex.Lock() - cfg := testLoadBasic(file, t) // Global vars could otherwise DATA RACE + cfg := testLoadBasic(file, t) // WHEN the service is deleted cfg.DeleteService(tc.name) @@ -234,7 +234,7 @@ func TestConfig_AddService(t *testing.T) { file := fmt.Sprintf("TestConfig_AddService_%s.yml", strings.ReplaceAll(name, " ", "_")) testYAML_Edit(file, t) logMutex.Lock() - cfg := testLoadBasic(file, t) // Global vars could otherwise DATA RACE + cfg := testLoadBasic(file, t) if tc.nilMap { cfg.Service = nil cfg.Order = []string{} diff --git a/config/help_test.go b/config/help_test.go index a5818412..94c075a3 100644 --- a/config/help_test.go +++ b/config/help_test.go @@ -111,7 +111,7 @@ func testLoad(file string, t *testing.T) (config *Config) { return } -var mutex sync.Mutex +var configInitMutex sync.Mutex func testLoadBasic(file string, t *testing.T) (config *Config) { config = &Config{} @@ -136,9 +136,7 @@ func testLoadBasic(file string, t *testing.T) (config *Config) { config.HardDefaults.Service.Status.DatabaseChannel = config.DatabaseChannel config.GetOrder(data) - mutex.Lock() - defer mutex.Unlock() - config.Init(true) + config.Init(false) // Log already set in TestMain for name, service := range config.Service { service.ID = name } diff --git a/test/config.go b/config/test/main.go similarity index 88% rename from test/config.go rename to config/test/main.go index 2208c07c..99e6180c 100644 --- a/test/config.go +++ b/config/test/main.go @@ -16,8 +16,12 @@ package test -import "github.com/release-argus/Argus/config" +import ( + "github.com/release-argus/Argus/config" + "github.com/release-argus/Argus/test" +) +// NilFlags sets all flags to nil in the given config func NilFlags(cfg *config.Config) { flags := []string{ "log.level", @@ -38,12 +42,13 @@ func NilFlags(cfg *config.Config) { cfg.Settings.NilUndefinedFlags(&flagMap) } +// BareConfig returns a minimal config with no flags set func BareConfig(nilFlags bool) (cfg *config.Config) { cfg = &config.Config{ Settings: config.Settings{ SettingsBase: config.SettingsBase{ Web: config.WebSettings{ - RoutePrefix: StringPtr(""), + RoutePrefix: test.StringPtr(""), }}}} // NilFlags can be a RACE condition, so use it conditionally diff --git a/test/config_test.go b/config/test/main_test.go similarity index 100% rename from test/config_test.go rename to config/test/main_test.go diff --git a/config/verify_test.go b/config/verify_test.go index 7bf77812..2c4e5677 100644 --- a/config/verify_test.go +++ b/config/verify_test.go @@ -27,6 +27,7 @@ import ( "github.com/release-argus/Argus/service" latestver "github.com/release-argus/Argus/service/latest_version" opt "github.com/release-argus/Argus/service/options" + "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/webhook" ) @@ -120,6 +121,9 @@ func TestConfig_CheckValues(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() stdout := os.Stdout r, w, _ := os.Pipe() @@ -199,6 +203,9 @@ func TestConfig_Print(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() stdout := os.Stdout r, w, _ := os.Pipe() diff --git a/db/init_test.go b/db/init_test.go index 17ed29d5..8f2f0a99 100644 --- a/db/init_test.go +++ b/db/init_test.go @@ -29,6 +29,7 @@ import ( dbtype "github.com/release-argus/Argus/db/types" svcstatus "github.com/release-argus/Argus/service/status" + "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/util" _ "modernc.org/sqlite" ) @@ -385,6 +386,10 @@ func Test_UpdateColumnTypes(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() + stdout := os.Stdout r, w, _ := os.Pipe() os.Stdout = w diff --git a/test/shoutrrr.go b/notifiers/shoutrrr/test/main.go similarity index 90% rename from test/shoutrrr.go rename to notifiers/shoutrrr/test/main.go index 5b701a34..737c9021 100644 --- a/test/shoutrrr.go +++ b/notifiers/shoutrrr/test/main.go @@ -22,8 +22,10 @@ import ( "github.com/release-argus/Argus/notifiers/shoutrrr" svcstatus "github.com/release-argus/Argus/service/status" + "github.com/release-argus/Argus/test" ) +// testShoutrrrrGotifyToken returns the token for the Gotify test func testShoutrrrrGotifyToken() (token string) { token = os.Getenv("ARGUS_TEST_GOTIFY_TOKEN") if token == "" { @@ -33,6 +35,7 @@ func testShoutrrrrGotifyToken() (token string) { return } +// ShoutrrrDefaults returns a ShoutrrrDefaults instance for testing func ShoutrrrDefaults(failing bool, selfSignedCert bool) *shoutrrr.ShoutrrrDefaults { url := "valid.release-argus.io" if selfSignedCert { @@ -53,6 +56,7 @@ func ShoutrrrDefaults(failing bool, selfSignedCert bool) *shoutrrr.ShoutrrrDefau return shoutrrr } +// Shoutrrr returns a shoutrrr instance for testing func Shoutrrr(failing bool, selfSignedCert bool) *shoutrrr.Shoutrrr { url := "valid.release-argus.io" if selfSignedCert { @@ -79,7 +83,7 @@ func Shoutrrr(failing bool, selfSignedCert bool) *shoutrrr.Shoutrrr { shoutrrr.ID = "test" shoutrrr.ServiceStatus = &svcstatus.Status{ - ServiceID: StringPtr("service"), + ServiceID: test.StringPtr("service"), } shoutrrr.ServiceStatus.Fails.Shoutrrr.Init(1) shoutrrr.Failed = &shoutrrr.ServiceStatus.Fails.Shoutrrr diff --git a/test/shoutrrr_test.go b/notifiers/shoutrrr/test/main_test.go similarity index 100% rename from test/shoutrrr_test.go rename to notifiers/shoutrrr/test/main_test.go diff --git a/notifiers/shoutrrr/verify_test.go b/notifiers/shoutrrr/verify_test.go index fabbde2c..d60a7b9e 100644 --- a/notifiers/shoutrrr/verify_test.go +++ b/notifiers/shoutrrr/verify_test.go @@ -24,6 +24,7 @@ import ( "testing" svcstatus "github.com/release-argus/Argus/service/status" + "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/util" ) @@ -1383,6 +1384,9 @@ notify: for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() if tc.want != "" { tc.want += "\n" diff --git a/service/deployed_version/query_test.go b/service/deployed_version/query_test.go index 625f612e..6e653928 100644 --- a/service/deployed_version/query_test.go +++ b/service/deployed_version/query_test.go @@ -27,6 +27,7 @@ import ( dbtype "github.com/release-argus/Argus/db/types" opt "github.com/release-argus/Argus/service/options" svcstatus "github.com/release-argus/Argus/service/status" + "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/util" metric "github.com/release-argus/Argus/web/metrics" ) @@ -435,7 +436,9 @@ func TestLookup_Track(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - // t.Parallel() - cannot run in parallel because of stdout + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() for k, v := range tc.env { os.Setenv(k, v) @@ -530,6 +533,8 @@ func TestLookup_Track(t *testing.T) { t.Errorf("expected DatabaseChannel to have %d messages in queue, not %d", tc.wantDatabaseMesages, len(*tc.lookup.Status.DatabaseChannel)) } + + // Set Deleting to stop the Track tc.lookup.Status.SetDeleting() }) } diff --git a/service/latest_version/query_test.go b/service/latest_version/query_test.go index f1bb9e5a..2e8420f3 100644 --- a/service/latest_version/query_test.go +++ b/service/latest_version/query_test.go @@ -25,6 +25,7 @@ import ( "time" "github.com/release-argus/Argus/service/latest_version/filter" + "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/util" ) @@ -214,6 +215,9 @@ func TestLookup_Query(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() try := 0 temporaryFailureInNameResolution := true @@ -291,6 +295,10 @@ func TestLookup_Query(t *testing.T) { } func TestLookup_Query__EmptyListETagChanged(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() + // Lock so that default empty list ETag isn't changed by other tests emptyListETagTestMutex.Lock() defer emptyListETagTestMutex.Unlock() @@ -378,6 +386,9 @@ no releases were found matching the url_commands and/or require`}, for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() lookup := testLookup(false, false) lookup.GitHubData.SetETag("foo") diff --git a/service/track_test.go b/service/track_test.go index 28d3d3b4..0af62ff7 100644 --- a/service/track_test.go +++ b/service/track_test.go @@ -91,6 +91,9 @@ func TestSlice_Track(t *testing.T) { t.Fatalf("didn't expect Query to have done anything for %s\n%#v", i, (*slice)[i].Status.String()) } + + // Set Deleting to stop the Track + (*slice)[i].Status.SetDeleting() } }) } @@ -395,6 +398,8 @@ func TestService_Track(t *testing.T) { if len(didFinish) == 0 && !shouldFinish { t.Fatal("expected Track to finish when not active, or is deleting") } + + // Set Deleting to stop the Track svc.Status.SetDeleting() }) } diff --git a/service/verify_test.go b/service/verify_test.go index 30d1263d..3c91702d 100644 --- a/service/verify_test.go +++ b/service/verify_test.go @@ -29,6 +29,7 @@ import ( latestver "github.com/release-argus/Argus/service/latest_version" "github.com/release-argus/Argus/service/latest_version/filter" opt "github.com/release-argus/Argus/service/options" + "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/util" "github.com/release-argus/Argus/webhook" ) @@ -84,6 +85,9 @@ func TestSlice_Print(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() if tc.want != "" { tc.want += "\n" diff --git a/test/main.go b/test/main.go index 0192690c..d1c53d45 100644 --- a/test/main.go +++ b/test/main.go @@ -19,23 +19,33 @@ package test import ( "fmt" "strings" - - "github.com/release-argus/Argus/util" + "sync" ) +var StdoutMutex sync.Mutex // Only one test should write to stdout at a time + +// BoolPtr returns a pointer to the given boolean value func BoolPtr(val bool) *bool { return &val } + +// IntPtr returns a pointer to the given integer value func IntPtr(val int) *int { return &val } + +// StringPtr returns a pointer to the given string value func StringPtr(val string) *string { return &val } + +// UIntPtr returns a pointer to the given unsigned integer value func UIntPtr(val int) *uint { converted := uint(val) return &converted } + +// StringifyPtr returns a string representation of the given pointer func StringifyPtr[T comparable](ptr *T) string { str := "nil" if ptr != nil { @@ -44,11 +54,7 @@ func StringifyPtr[T comparable](ptr *T) string { return str } -func CopyMapPtr(tgt map[string]string) *map[string]string { - ptr := util.CopyMap(tgt) - return &ptr -} - +// TrimJSON removes unnecessary whitespace from a JSON string func TrimJSON(str string) string { str = strings.TrimSpace(str) str = strings.ReplaceAll(str, "\n", "") @@ -59,6 +65,7 @@ func TrimJSON(str string) string { return str } +// Combinations generates all possible combinations of the given input func Combinations[T comparable](input []T) [][]T { var result [][]T diff --git a/test/main_test.go b/test/main_test.go index 268d56bc..e8794d79 100644 --- a/test/main_test.go +++ b/test/main_test.go @@ -165,52 +165,6 @@ func TestStringifyPtr(t *testing.T) { } } -func TestCopyMapPtr(t *testing.T) { - // GIVEN a map - tests := map[string]struct { - tgt map[string]string - want map[string]string - }{ - "nil": { - tgt: nil, - want: nil, - }, - "empty": { - tgt: map[string]string{}, - want: map[string]string{}, - }, - "non-empty": { - tgt: map[string]string{ - "key": "value", - }, - want: map[string]string{ - "key": "value", - }, - }, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - t.Parallel() - - // WHEN CopyMapPtr is called - result := CopyMapPtr(tc.tgt) - - // THEN the result should be a pointer to a copy of the map - if len(*result) != len(tc.want) { - t.Errorf("length differs, expected %d but got %d", - len(tc.want), len(*result)) - } - for k, v := range tc.want { - if (*result)[k] != v { - t.Errorf("%q: expected %q but got %q", - k, v, (*result)[k]) - } - } - }) - } -} - func TestTrimJSON(t *testing.T) { // GIVEN a JSON string tests := map[string]struct { diff --git a/testing/commands_test.go b/testing/commands_test.go index 7126e430..f8bb0502 100644 --- a/testing/commands_test.go +++ b/testing/commands_test.go @@ -27,6 +27,7 @@ import ( "github.com/release-argus/Argus/config" "github.com/release-argus/Argus/service" opt "github.com/release-argus/Argus/service/options" + "github.com/release-argus/Argus/test" ) func TestCommandTest(t *testing.T) { @@ -102,6 +103,9 @@ func TestCommandTest(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() stdout := os.Stdout r, w, _ := os.Pipe() diff --git a/testing/service_test.go b/testing/service_test.go index 32b94a60..5cfd9d29 100644 --- a/testing/service_test.go +++ b/testing/service_test.go @@ -33,6 +33,7 @@ import ( "github.com/release-argus/Argus/service/latest_version/filter" opt "github.com/release-argus/Argus/service/options" svcstatus "github.com/release-argus/Argus/service/status" + "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/webhook" ) @@ -159,6 +160,9 @@ func TestServiceTest(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() stdout := os.Stdout r, w, _ := os.Pipe() diff --git a/testing/shoutrrr_test.go b/testing/shoutrrr_test.go index 99a1d9e5..62cec4bc 100644 --- a/testing/shoutrrr_test.go +++ b/testing/shoutrrr_test.go @@ -26,6 +26,7 @@ import ( "github.com/release-argus/Argus/config" "github.com/release-argus/Argus/notifiers/shoutrrr" "github.com/release-argus/Argus/service" + "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/util" "github.com/release-argus/Argus/webhook" ) @@ -382,6 +383,9 @@ func TestFindShoutrrr(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() stdout := os.Stdout r, w, _ := os.Pipe() @@ -543,6 +547,9 @@ func TestNotifyTest(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() stdout := os.Stdout r, w, _ := os.Pipe() diff --git a/util/log_test.go b/util/log_test.go index 89538838..3341d093 100644 --- a/util/log_test.go +++ b/util/log_test.go @@ -25,6 +25,8 @@ import ( "regexp" "strings" "testing" + + "github.com/release-argus/Argus/test" ) func TestNewJLog(t *testing.T) { @@ -306,6 +308,9 @@ func TestJLog_Error(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() jLog := NewJLog(tc.level, tc.timestamps) stdout := os.Stdout @@ -384,6 +389,9 @@ func TestJLog_Warn(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() jLog := NewJLog(tc.level, tc.timestamps) stdout := os.Stdout @@ -462,6 +470,9 @@ func TestJLog_Info(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() jLog := NewJLog(tc.level, tc.timestamps) stdout := os.Stdout @@ -545,6 +556,9 @@ func TestJLog_Verbose(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() msg := "argus" @@ -647,6 +661,9 @@ func TestJLog_Debug(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() jLog := NewJLog(tc.level, tc.timestamps) stdout := os.Stdout @@ -741,6 +758,9 @@ func TestJLog_Fatal(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() jLog := NewJLog(tc.level, tc.timestamps) stdout := os.Stdout diff --git a/util/util_test.go b/util/util_test.go index 0fc22869..45355286 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -25,6 +25,8 @@ import ( "regexp" "strings" "testing" + + "github.com/release-argus/Argus/test" ) func TestContains(t *testing.T) { @@ -630,6 +632,8 @@ func TestPrintlnIfNotDefault(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() msg := "var is not default from PrintlnIfNotDefault" stdout := os.Stdout @@ -673,6 +677,9 @@ func TestPrintlnIfNotNil(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() msg := "var is not default from PrintlnIfNotNil" stdout := os.Stdout @@ -716,6 +723,9 @@ func TestPrintlnIfNil(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() msg := "var is not default from PrintlnIfNil" stdout := os.Stdout diff --git a/web/api/v1/http-api-actions_test.go b/web/api/v1/http-api-actions_test.go index 66014044..c76aec7e 100644 --- a/web/api/v1/http-api-actions_test.go +++ b/web/api/v1/http-api-actions_test.go @@ -107,6 +107,9 @@ func TestHTTP_httpServiceGetActions(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() if tc.statusCode == 0 { tc.statusCode = http.StatusOK @@ -428,6 +431,9 @@ func TestHTTP_httpServiceRunActions(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() tc.serviceID = strings.ReplaceAll(tc.serviceID, "__name__", name) svc := testService(tc.serviceID) diff --git a/web/api/v1/http-api-edit_test.go b/web/api/v1/http-api-edit_test.go index d1fa1bd6..88a15b8c 100644 --- a/web/api/v1/http-api-edit_test.go +++ b/web/api/v1/http-api-edit_test.go @@ -33,6 +33,7 @@ import ( "github.com/gorilla/mux" "github.com/release-argus/Argus/notifiers/shoutrrr" + shoutrrr_test "github.com/release-argus/Argus/notifiers/shoutrrr/test" "github.com/release-argus/Argus/service" "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/util" @@ -821,16 +822,17 @@ func TestHTTP_NotifyTest(t *testing.T) { file := "TestHTTP_NotifyTest.yml" api := testAPI(file) defer os.RemoveAll(file) - validNotify := test.Shoutrrr(false, false) + validNotify := shoutrrr_test.Shoutrrr(false, false) api.Config.Notify = shoutrrr.SliceDefaults{} + options := util.CopyMap(validNotify.Options) + params := util.CopyMap(validNotify.Params) + urlFields := util.CopyMap(validNotify.URLFields) api.Config.Notify["test"] = shoutrrr.NewDefaults( "gotify", - test.CopyMapPtr(validNotify.Options), - test.CopyMapPtr(validNotify.Params), - test.CopyMapPtr(validNotify.URLFields)) + &options, ¶ms, &urlFields) api.Config.Service["test"].Notify = map[string]*shoutrrr.Shoutrrr{ - "test": test.Shoutrrr(false, false), - "no_main": test.Shoutrrr(false, false)} + "test": shoutrrr_test.Shoutrrr(false, false), + "no_main": shoutrrr_test.Shoutrrr(false, false)} tests := map[string]struct { queryParams map[string]string payload string diff --git a/web/api/v1/http_test.go b/web/api/v1/http_test.go index 4b11bbc7..13494540 100644 --- a/web/api/v1/http_test.go +++ b/web/api/v1/http_test.go @@ -27,6 +27,7 @@ import ( "testing" "github.com/release-argus/Argus/config" + config_test "github.com/release-argus/Argus/config/test" "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/util" api_type "github.com/release-argus/Argus/web/api/types" @@ -163,7 +164,7 @@ func TestHTTP_SetupRoutesFavicon(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - cfg := test.BareConfig(true) + cfg := config_test.BareConfig(true) cfg.Settings.Web.Favicon = testFaviconSettings(tc.urlPNG, tc.urlSVG) api := NewAPI(cfg, util.NewJLog("WARN", false)) api.SetupRoutesFavicon() @@ -422,7 +423,7 @@ func TestHTTP_DisableRoutes(t *testing.T) { } t.Run(strings.Join(disabledRoutes, ";"), func(t *testing.T) { - cfg := test.BareConfig(false) + cfg := config_test.BareConfig(false) cfg.Settings.Web.DisabledRoutes = disabledRoutes // Give every other test a route prefix routePrefix := "" diff --git a/web/help_test.go b/web/help_test.go index 2b5e1657..2da19ce8 100644 --- a/web/help_test.go +++ b/web/help_test.go @@ -26,7 +26,6 @@ import ( "math/big" "net" "os" - "sync" "testing" "time" @@ -44,7 +43,6 @@ import ( "github.com/release-argus/Argus/webhook" ) -var stdoutMutex sync.Mutex var mainCfg *config.Config var port *string diff --git a/web/ui/react-app/src/reducers/action-release.tsx b/web/ui/react-app/src/reducers/action-release.tsx index 391beb3f..36bcd300 100644 --- a/web/ui/react-app/src/reducers/action-release.tsx +++ b/web/ui/react-app/src/reducers/action-release.tsx @@ -26,53 +26,51 @@ export default function reducerActionModal( ) return state; - switch (action.sub_type) { - case "EVENT": - if (action.webhook_data) - for (const webhook_id in action.webhook_data) { - // Remove them from the sending list - newState.sentWH.splice( - newState.sentWH.indexOf( - `${action.service_data.id} ${webhook_id}` - ), - 1 - ); + if (action.sub_type == "EVENT") { + if (action.webhook_data) + for (const webhook_id in action.webhook_data) { + // Remove them from the sending list + newState.sentWH.splice( + newState.sentWH.indexOf( + `${action.service_data.id} ${webhook_id}` + ), + 1 + ); - // Record the success/fail (if it's the current modal service) - if ( - action.service_data.id === state.service_id && - newState.webhooks[webhook_id] !== undefined - ) - newState.webhooks[webhook_id] = { - failed: action.webhook_data[webhook_id].failed, - next_runnable: action.webhook_data[webhook_id].next_runnable, - }; - } + // Record the success/fail (if it's the current modal service) + if ( + action.service_data.id === state.service_id && + newState.webhooks[webhook_id] !== undefined + ) + newState.webhooks[webhook_id] = { + failed: action.webhook_data[webhook_id].failed, + next_runnable: action.webhook_data[webhook_id].next_runnable, + }; + } - if (action.command_data) - for (const command in action.command_data) { - // Remove them from the sending list - newState.sentC.splice( - newState.sentC.indexOf(`${action.service_data.id} ${command}`), - 1 - ); + if (action.command_data) + for (const command in action.command_data) { + // Remove them from the sending list + newState.sentC.splice( + newState.sentC.indexOf(`${action.service_data.id} ${command}`), + 1 + ); - // Record the success/fail (if it's the current modal service) - if ( - action.service_data.id === state.service_id && - newState.commands[command] !== undefined - ) - newState.commands[command] = { - failed: action.command_data[command].failed, - next_runnable: action.command_data[command].next_runnable, - }; - } - break; - default: - console.error(action); - throw new Error(); + // Record the success/fail (if it's the current modal service) + if ( + action.service_data.id === state.service_id && + newState.commands[command] !== undefined + ) + newState.commands[command] = { + failed: action.command_data[command].failed, + next_runnable: action.command_data[command].next_runnable, + }; + } + break; + } else { + console.error(action); + throw new Error(); } - break; // REFRESH // RESET diff --git a/web/ui/react-app/src/reducers/monitor.tsx b/web/ui/react-app/src/reducers/monitor.tsx index 46ddae67..fe2d96fd 100644 --- a/web/ui/react-app/src/reducers/monitor.tsx +++ b/web/ui/react-app/src/reducers/monitor.tsx @@ -55,7 +55,7 @@ export default function reducerMonitor( // UPDATED // INIT case "VERSION": { - const id = action.service_data?.id as string; + const id = action.service_data.id; if (state.service[id] === undefined) return state; switch (action.sub_type) { case "QUERY": { diff --git a/web/ui/react-app/src/types/websocket.tsx b/web/ui/react-app/src/types/websocket.tsx index 8be25f3e..3f3aa03e 100644 --- a/web/ui/react-app/src/types/websocket.tsx +++ b/web/ui/react-app/src/types/websocket.tsx @@ -13,16 +13,16 @@ export interface GorillaWebSocketMessage { data: WebSocketResponse; defaultPrevented: boolean; eventPhase: number; - explicitOriginalTarget: wsMessageDetail; + explicitOriginalTarget: WebSocketMessageDetail; isTrusted: boolean; lastEventId: string; origin: string; - originalTarget: wsMessageDetail; + originalTarget: WebSocketMessageDetail; ports: number[]; returnValue: boolean; source: null; - srcElement: wsMessageDetail; - target: wsMessageDetail; + srcElement: WebSocketMessageDetail; + target: WebSocketMessageDetail; timeStamp: number; type: string; } @@ -70,7 +70,7 @@ export type WebSocketResponse = service_data: ServiceSummaryType; }; -export interface wsMessageDetail { +export interface WebSocketMessageDetail { binaryType: string; bufferedAmount: number; extensions: string; diff --git a/web/web_test.go b/web/web_test.go index d4ab9cf9..5fc3e2e7 100644 --- a/web/web_test.go +++ b/web/web_test.go @@ -29,6 +29,7 @@ import ( "github.com/gorilla/mux" "github.com/gorilla/websocket" + "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/util" api_type "github.com/release-argus/Argus/web/api/types" ) @@ -41,7 +42,7 @@ func TestMainWithRoutePrefix(t *testing.T) { *cfg.Settings.Web.RoutePrefix = "/test" // WHEN the Web UI is started with this Config - go Run(cfg, util.NewJLog("WARN", false)) + go Run(cfg, nil) time.Sleep(500 * time.Millisecond) // THEN Web UI is accessible @@ -238,10 +239,11 @@ func TestWebSocket(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() ws := connectToWebSocket(t) - stdoutMutex.Lock() - defer stdoutMutex.Unlock() stdout := os.Stdout r, w, _ := os.Pipe() os.Stdout = w diff --git a/webhook/send_test.go b/webhook/send_test.go index 1306bc32..eabc85bf 100644 --- a/webhook/send_test.go +++ b/webhook/send_test.go @@ -26,6 +26,7 @@ import ( "github.com/prometheus/client_golang/prometheus/testutil" "github.com/release-argus/Argus/notifiers/shoutrrr" + "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/util" metric "github.com/release-argus/Argus/web/metrics" ) @@ -174,6 +175,9 @@ func TestWebHook_Send(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() try := 0 contextDeadlineExceeded := true @@ -282,6 +286,9 @@ func TestSlice_Send(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() try := 0 contextDeadlineExceeded := true diff --git a/webhook/verify_test.go b/webhook/verify_test.go index 8ba92986..bc01094e 100644 --- a/webhook/verify_test.go +++ b/webhook/verify_test.go @@ -25,6 +25,7 @@ import ( "testing" svcstatus "github.com/release-argus/Argus/service/status" + "github.com/release-argus/Argus/test" "github.com/release-argus/Argus/util" ) @@ -84,6 +85,9 @@ webhook: for name, tc := range tests { t.Run(name, func(t *testing.T) { + // t.Parallel() - Cannot run in parallel since we're using stdout + test.StdoutMutex.Lock() + defer test.StdoutMutex.Unlock() if tc.want != "" { tc.want += "\n" From 326ccc6e04e27c3d39a375a9aac574420bd17404 Mon Sep 17 00:00:00 2001 From: Joseph Kavanagh Date: Sun, 21 Apr 2024 17:50:18 +0100 Subject: [PATCH 3/4] test(testing): service_test.go - broken pipe * panic wasn't restoring stdout --- service/latest_version/query.go | 20 ++++++++------- testing/service.go | 44 ++++++++++++++++++--------------- testing/service_test.go | 4 +++ 3 files changed, 39 insertions(+), 29 deletions(-) diff --git a/service/latest_version/query.go b/service/latest_version/query.go index c8d57648..169a9150 100644 --- a/service/latest_version/query.go +++ b/service/latest_version/query.go @@ -175,7 +175,7 @@ func (l *Lookup) queryMetrics(err error) { } } -func (l *Lookup) httpRequest(logFrom *util.LogFrom) (rawBody []byte, err error) { +func (l *Lookup) httpRequest(logFrom *util.LogFrom) (rawBodyPtr *[]byte, err error) { customTransport := &http.Transport{} // HTTPS insecure skip verify. if l.GetAllowInvalidCerts() { @@ -222,7 +222,9 @@ func (l *Lookup) httpRequest(logFrom *util.LogFrom) (rawBody []byte, err error) // Read the response body. defer resp.Body.Close() + var rawBody []byte rawBody, err = io.ReadAll(resp.Body) + rawBodyPtr = &rawBody jLog.Error(err, logFrom, err != nil) if l.Type == "github" && err == nil { // 200 - Resource has changed @@ -237,7 +239,7 @@ func (l *Lookup) httpRequest(logFrom *util.LogFrom) (rawBody []byte, err error) l.GitHubData.SetTagFallback() if l.GitHubData.TagFallback() { jLog.Verbose(fmt.Sprintf("/releases gave %v, trying /tags", string(rawBody)), logFrom, true) - rawBody, err = l.httpRequest(logFrom) + rawBodyPtr, err = l.httpRequest(logFrom) } // Has tags/releases } else { @@ -253,7 +255,7 @@ func (l *Lookup) httpRequest(logFrom *util.LogFrom) (rawBody []byte, err error) l.GitHubData.SetTagFallback() if l.GitHubData.TagFallback() { jLog.Verbose("no tags found on /releases, trying /tags", logFrom, true) - rawBody, err = l.httpRequest(logFrom) + rawBodyPtr, err = l.httpRequest(logFrom) } } } @@ -264,14 +266,14 @@ func (l *Lookup) httpRequest(logFrom *util.LogFrom) (rawBody []byte, err error) // GetVersions will filter out releases from rawBody that are preReleases (if not wanted) and will sort releases if // semantic versioning is wanted func (l *Lookup) GetVersions( - rawBody []byte, + rawBody *[]byte, logFrom *util.LogFrom, ) (filteredReleases []github_types.Release, err error) { var releases []github_types.Release - body := string(rawBody) + body := string(*rawBody) // GitHub service. if l.Type == "github" { - releases, err = l.checkGitHubReleasesBody(&rawBody, logFrom) + releases, err = l.checkGitHubReleasesBody(rawBody, logFrom) if err != nil { return } @@ -299,10 +301,10 @@ func (l *Lookup) GetVersions( } // GetVersion will return the latest version from rawBody matching the URLCommands and Regex requirements -func (l *Lookup) GetVersion(rawBody []byte, logFrom *util.LogFrom) (version string, err error) { +func (l *Lookup) GetVersion(rawBody *[]byte, logFrom *util.LogFrom) (version string, err error) { var filteredReleases []github_types.Release // rawBody length = 0 if GitHub ETag is unchanged - if len(rawBody) != 0 { + if len(*rawBody) != 0 { filteredReleases, err = l.GetVersions(rawBody, logFrom) if err != nil { return @@ -337,7 +339,7 @@ func (l *Lookup) GetVersion(rawBody []byte, logFrom *util.LogFrom) (version stri body = filteredReleases[i].Assets // Web service } else { - body = string(rawBody) + body = string(*rawBody) } // If the Content doesn't match the provided RegEx if err = l.Require.RegexCheckContent(version, body, logFrom); err != nil { diff --git a/testing/service.go b/testing/service.go index 61d26fb3..e453d083 100644 --- a/testing/service.go +++ b/testing/service.go @@ -33,16 +33,17 @@ func ServiceTest( if *flag == "" { return } - logFrom := &util.LogFrom{Primary: "Testing", Secondary: *flag} + // Log what we are testing + logFrom := &util.LogFrom{Primary: "Testing", Secondary: *flag} log.Info( "", logFrom, true, ) - service := cfg.Service[*flag] - if service == nil { + // Check if service exists + if !util.Contains(cfg.Order, *flag) { log.Fatal( fmt.Sprintf( "Service %q could not be found in config.service\nDid you mean one of these?\n - %s", @@ -53,24 +54,26 @@ func ServiceTest( ) } - if service != nil { - _, err := service.LatestVersion.Query(false, logFrom) - if err != nil { - helpMsg := "" - if service.LatestVersion.Type == "url" && strings.Count(service.LatestVersion.URL, "/") == 1 && !strings.HasPrefix(service.LatestVersion.URL, "http") { - helpMsg = "This URL looks to be a GitHub repo, but the service's type is url, not github. Try using the github service type.\n" - } - log.Error( - fmt.Sprintf( - "No version matching the conditions specified could be found for %q at %q\n%s", - *flag, - service.LatestVersion.ServiceURL(true), - helpMsg, - ), - logFrom, - true, - ) + // Service we are testing + service := cfg.Service[*flag] + + // LatestVersion + _, err := service.LatestVersion.Query(false, logFrom) + if err != nil { + helpMsg := "" + if service.LatestVersion.Type == "url" && strings.Count(service.LatestVersion.URL, "/") == 1 && !strings.HasPrefix(service.LatestVersion.URL, "http") { + helpMsg = "This URL looks to be a GitHub repo, but the service's type is url, not github. Try using the github service type.\n" } + log.Error( + fmt.Sprintf( + "No version matching the conditions specified could be found for %q at %q\n%s", + *flag, + service.LatestVersion.ServiceURL(true), + helpMsg, + ), + logFrom, + true, + ) } // DeployedVersionLookup @@ -85,6 +88,7 @@ func ServiceTest( err == nil, ) } + if !log.Testing { os.Exit(0) } diff --git a/testing/service_test.go b/testing/service_test.go index 5cfd9d29..cdf302fd 100644 --- a/testing/service_test.go +++ b/testing/service_test.go @@ -170,6 +170,10 @@ func TestServiceTest(t *testing.T) { if tc.panicRegex != nil { // Switch Fatal to panic and disable this panic. defer func() { + // Reset stdout + w.Close() + os.Stdout = stdout + // Check the panic message r := recover() rStr := fmt.Sprint(r) re := regexp.MustCompile(*tc.panicRegex) From 3aedd9b24ffa0016cf4d8795d35ad652962fed90 Mon Sep 17 00:00:00 2001 From: Joseph Kavanagh Date: Sun, 21 Apr 2024 19:49:05 +0100 Subject: [PATCH 4/4] test: simply stdout tests with function to capture+release --- cmd/argus/main_test.go | 18 +--- commands/commands_test.go | 89 ++++++----------- config/defaults_test.go | 43 ++++---- config/save_test.go | 2 +- config/settings_test.go | 37 ++++--- config/verify_test.go | 47 +++------ db/init_test.go | 22 ++--- notifiers/shoutrrr/verify_test.go | 19 +--- service/deployed_version/query_test.go | 17 +--- service/latest_version/query_test.go | 60 ++++------- service/verify_test.go | 19 +--- test/main.go | 15 +++ testing/commands_test.go | 37 +++---- testing/service_test.go | 41 +++----- testing/shoutrrr_test.go | 60 +++++------ util/log_test.go | 131 ++++++++----------------- util/template_test.go | 1 + util/util_test.go | 58 ++++------- web/api/v1/http-api-actions_test.go | 40 +++----- web/web_test.go | 16 +-- webhook/send_test.go | 44 +++------ webhook/verify_test.go | 17 +--- 22 files changed, 297 insertions(+), 536 deletions(-) diff --git a/cmd/argus/main_test.go b/cmd/argus/main_test.go index c9ac3ace..ac8f6a3b 100644 --- a/cmd/argus/main_test.go +++ b/cmd/argus/main_test.go @@ -18,7 +18,6 @@ package main import ( "fmt" - "io" "os" "strings" "testing" @@ -78,8 +77,7 @@ func TestTheMain(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() file := fmt.Sprintf("%s.yml", name) os.Remove(tc.db) @@ -87,9 +85,6 @@ func TestTheMain(t *testing.T) { defer os.Remove(tc.db) resetFlags() configFile = &file - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w accessToken := os.Getenv("GITHUB_TOKEN") os.Setenv("ARGUS_SERVICE_LATEST_VERSION_ACCESS_TOKEN", accessToken) @@ -98,15 +93,12 @@ func TestTheMain(t *testing.T) { time.Sleep(3 * time.Second) // THEN the program will have printed everything expected - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) + stdout := releaseStdout() if tc.outputContains != nil { for _, text := range *tc.outputContains { - if !strings.Contains(output, text) { - t.Errorf("%q couldn't be found in the output:\n%s", - text, output) + if !strings.Contains(stdout, text) { + t.Errorf("%q couldn't be found in stdout:\n%s", + text, stdout) } } } diff --git a/commands/commands_test.go b/commands/commands_test.go index ac9c4eb2..9416623d 100644 --- a/commands/commands_test.go +++ b/commands/commands_test.go @@ -18,8 +18,6 @@ package command import ( "fmt" - "io" - "os" "reflect" "regexp" "testing" @@ -80,44 +78,36 @@ func TestCommand_Exec(t *testing.T) { tests := map[string]struct { cmd Command err error - outputRegex string + stdoutRegex string }{ "command that will pass": { cmd: Command{"date", "+%m-%d-%Y"}, - outputRegex: `[0-9]{2}-[0-9]{2}-[0-9]{4}\s+$`}, + stdoutRegex: `[0-9]{2}-[0-9]{2}-[0-9]{4}\s+$`}, "command that will fail": { cmd: Command{"false"}, err: fmt.Errorf("exit status 1"), - outputRegex: `exit status 1\s+$`}, + stdoutRegex: `exit status 1\s+$`}, } for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() - - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w + releaseStdout := test.CaptureStdout() // WHEN Exec is called on it err := tc.cmd.Exec(&util.LogFrom{}) - // THEN the output is expected + // THEN the stdout is expected if util.ErrorToString(err) != util.ErrorToString(tc.err) { t.Fatalf("err's differ\nwant: %s\ngot: %s", tc.err, err) } - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) - re := regexp.MustCompile(tc.outputRegex) - match := re.MatchString(output) + stdout := releaseStdout() + re := regexp.MustCompile(tc.stdoutRegex) + match := re.MatchString(stdout) if !match { t.Errorf("want match for %q\nnot: %q", - tc.outputRegex, output) + tc.stdoutRegex, stdout) } }) } @@ -142,52 +132,44 @@ func TestController_ExecIndex(t *testing.T) { tests := map[string]struct { index int err error - outputRegex string + stdoutRegex string noAnnounce bool }{ "command index out of range": { index: 2, - outputRegex: `^$`, + stdoutRegex: `^$`, noAnnounce: true}, "command index that will pass": { index: 0, - outputRegex: `[0-9]{2}-[0-9]{2}-[0-9]{4}\s+$`}, + stdoutRegex: `[0-9]{2}-[0-9]{2}-[0-9]{4}\s+$`}, "command index that will fail": { index: 1, err: fmt.Errorf("exit status 1"), - outputRegex: `exit status 1\s+$`}, + stdoutRegex: `exit status 1\s+$`}, } runNumber := 0 for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() - - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w + releaseStdout := test.CaptureStdout() // WHEN the Command @index is exectured err := controller.ExecIndex(&util.LogFrom{}, tc.index) - // THEN the output is expected + // THEN the stdout is expected // err if util.ErrorToString(err) != util.ErrorToString(tc.err) { t.Fatalf("err's differ\nwant: %s\ngot: %s", tc.err, err) } - // output - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) - re := regexp.MustCompile(tc.outputRegex) - match := re.MatchString(output) + // stdout + stdout := releaseStdout() + re := regexp.MustCompile(tc.stdoutRegex) + match := re.MatchString(stdout) if !match { t.Fatalf("want match for %q\nnot: %q", - tc.outputRegex, output) + tc.stdoutRegex, stdout) } // announced if !tc.noAnnounce { @@ -207,23 +189,23 @@ func TestController_Exec(t *testing.T) { nilController bool commands *Slice err error - outputRegex string + stdoutRegex string noAnnounce bool }{ "nil Controller": { nilController: true, - outputRegex: `^$`, + stdoutRegex: `^$`, noAnnounce: true}, "nil Command": { - outputRegex: `^$`, + stdoutRegex: `^$`, noAnnounce: true}, "single Command": { - outputRegex: `[0-9]{2}-[0-9]{2}-[0-9]{4}\s+$`, + stdoutRegex: `[0-9]{2}-[0-9]{2}-[0-9]{4}\s+$`, commands: &Slice{ {"date", "+%m-%d-%Y"}}}, "multiple Command's": { err: fmt.Errorf("\nexit status 1"), - outputRegex: `[0-9]{2}-[0-9]{2}-[0-9]{4}\s+.*'false'\s.*exit status 1\s+$`, + stdoutRegex: `[0-9]{2}-[0-9]{2}-[0-9]{4}\s+.*'false'\s.*exit status 1\s+$`, commands: &Slice{ {"date", "+%m-%d-%Y"}, {"false"}}}, @@ -232,14 +214,10 @@ func TestController_Exec(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() announce := make(chan []byte, 8) controller := testController(&announce) - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w // WHEN the Command @index is exectured controller.Command = tc.commands @@ -248,22 +226,19 @@ func TestController_Exec(t *testing.T) { } err := controller.Exec(&util.LogFrom{}) - // THEN the output is expected + // THEN the stdout is expected // err if util.ErrorToString(err) != util.ErrorToString(tc.err) { t.Fatalf("err's differ\nwant: %q\ngot: %q", util.ErrorToString(tc.err), util.ErrorToString(err)) } - // output - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) - re := regexp.MustCompile(tc.outputRegex) - match := re.MatchString(output) + // stdout + stdout := releaseStdout() + re := regexp.MustCompile(tc.stdoutRegex) + match := re.MatchString(stdout) if !match { t.Fatalf("want match for %q\nnot: %q", - tc.outputRegex, output) + tc.stdoutRegex, stdout) } // announced runNumber := 0 diff --git a/config/defaults_test.go b/config/defaults_test.go index 1ded2ed6..b428c19f 100644 --- a/config/defaults_test.go +++ b/config/defaults_test.go @@ -17,7 +17,6 @@ package config import ( - "io" "os" "regexp" "strings" @@ -920,19 +919,22 @@ func TestDefaults_MapEnvToStruct(t *testing.T) { // Catch fatal panics. defer func() { r := recover() - if r != nil { - if tc.errRegex == "" { - t.Fatalf("unexpected panic: %v", r) - } - switch r.(type) { - case string: - if !regexp.MustCompile(tc.errRegex).MatchString(r.(string)) { - t.Errorf("want error matching:\n%v\ngot:\n%v", - tc.errRegex, r.(string)) - } - default: - t.Fatalf("unexpected panic: %v", r) + // Ignore nil panics. + if r == nil { + return + } + + if tc.errRegex == "" { + t.Fatalf("unexpected panic: %v", r) + } + switch r.(type) { + case string: + if !regexp.MustCompile(tc.errRegex).MatchString(r.(string)) { + t.Errorf("want error matching:\n%v\ngot:\n%v", + tc.errRegex, r.(string)) } + default: + t.Fatalf("unexpected panic: %v", r) } }() @@ -1074,24 +1076,17 @@ func TestDefaults_Print(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() - - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w + releaseStdout := test.CaptureStdout() // WHEN Print is called tc.input.Print("") // THEN the expected number of lines are printed - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - got := strings.Count(string(out), "\n") + stdout := releaseStdout() + got := strings.Count(stdout, "\n") if got != tc.lines { t.Errorf("Print should have given %d lines, but gave %d\n%s", - tc.lines, got, out) + tc.lines, got, stdout) } }) } diff --git a/config/save_test.go b/config/save_test.go index ea72cd56..aa358408 100644 --- a/config/save_test.go +++ b/config/save_test.go @@ -35,7 +35,7 @@ func TestConfig_SaveHandler(t *testing.T) { // GIVEN a message is sent to the SaveHandler config := testConfig() // Disable fatal panics. - defer func() { _ = recover() }() + defer func() { recover() }() go func() { *config.SaveChannel <- true }() diff --git a/config/settings_test.go b/config/settings_test.go index 0a3974d8..b4bbca4d 100644 --- a/config/settings_test.go +++ b/config/settings_test.go @@ -514,19 +514,22 @@ func TestSettings_MapEnvToStruct(t *testing.T) { // Catch fatal panics. defer func() { r := recover() - if r != nil { - if tc.errRegex == "" { - t.Fatalf("unexpected panic: %v", r) - } - switch r.(type) { - case string: - if !regexp.MustCompile(tc.errRegex).MatchString(r.(string)) { - t.Errorf("want error matching:\n%v\ngot:\n%v", - tc.errRegex, t) - } - default: - t.Fatalf("unexpected panic: %v", r) + // Ignore nil panics + if r == nil { + return + } + + if tc.errRegex == "" { + t.Fatalf("unexpected panic: %v", r) + } + switch r.(type) { + case string: + if !regexp.MustCompile(tc.errRegex).MatchString(r.(string)) { + t.Errorf("want error matching:\n%v\ngot:\n%v", + tc.errRegex, t) } + default: + t.Fatalf("unexpected panic: %v", r) } }() @@ -643,9 +646,13 @@ func TestSettings_GetWebFile_NotExist(t *testing.T) { // Catch fatal panics. defer func() { r := recover() - if r != nil && - !(strings.Contains(r.(string), "no such file or directory") || - strings.Contains(r.(string), "cannot find the file specified")) { + // Ignore nil panics + if r == nil { + return + } + + if !(strings.Contains(r.(string), "no such file or directory") || + strings.Contains(r.(string), "cannot find the file specified")) { t.Errorf("expected an error about the file not existing, not %s", r.(string)) } diff --git a/config/verify_test.go b/config/verify_test.go index 2c4e5677..0fa017ae 100644 --- a/config/verify_test.go +++ b/config/verify_test.go @@ -17,8 +17,6 @@ package config import ( - "io" - "os" "regexp" "strings" "testing" @@ -122,12 +120,8 @@ func TestConfig_CheckValues(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w if tc.config != nil { for name, svc := range tc.config.Service { svc.ID = name @@ -136,27 +130,24 @@ func TestConfig_CheckValues(t *testing.T) { // Switch Fatal to panic and disable this panic. if !tc.noPanic { defer func() { - _ = recover() + recover() + stdout := releaseStdout() - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) - lines := strings.Split(output, "\n") + lines := strings.Split(stdout, "\n") if len(tc.errRegex) == 0 { t.Fatalf("want 0 errors, not %d:\n%v", len(lines), lines) } if len(tc.errRegex) > len(lines) { - t.Fatalf("want %d errors:\n['%s']\ngot %d errors:\n%v\noutput: %q", - len(tc.errRegex), strings.Join(tc.errRegex, `' '`), len(lines), lines, output) + t.Fatalf("want %d errors:\n['%s']\ngot %d errors:\n%v\nstdout: %q", + len(tc.errRegex), strings.Join(tc.errRegex, `' '`), len(lines), lines, stdout) } for i := range tc.errRegex { re := regexp.MustCompile(tc.errRegex[i]) match := re.MatchString(lines[i]) if !match { t.Errorf("want match for: %q\ngot: %q", - tc.errRegex[i], output) + tc.errRegex[i], stdout) return } } @@ -167,11 +158,8 @@ func TestConfig_CheckValues(t *testing.T) { tc.config.CheckValues() // THEN this call will/wont crash the program - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) - lines := strings.Split(output, `\n`) + stdout := releaseStdout() + lines := strings.Split(stdout, `\n`) if len(tc.errRegex) > len(lines) { t.Errorf("want %d errors:\n%v\ngot %d errors:\n%v", len(tc.errRegex), tc.errRegex, len(lines), lines) @@ -182,7 +170,7 @@ func TestConfig_CheckValues(t *testing.T) { match := re.MatchString(lines[i]) if !match { t.Errorf("want match for: %q\ngot: %q", - tc.errRegex[i], output) + tc.errRegex[i], stdout) return } } @@ -204,24 +192,17 @@ func TestConfig_Print(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() - - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w + releaseStdout := test.CaptureStdout() // WHEN Print is called with these flags config.Print(&tc.flag) // THEN config is printed onlt when the flag is true - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - got := strings.Count(string(out), "\n") + stdout := releaseStdout() + got := strings.Count(stdout, "\n") if got != tc.lines { t.Errorf("Print with %s wants %d lines but got %d\n%s", - name, tc.lines, got, string(out)) + name, tc.lines, got, stdout) } }) } diff --git a/db/init_test.go b/db/init_test.go index 8f2f0a99..40ab3f5a 100644 --- a/db/init_test.go +++ b/db/init_test.go @@ -19,7 +19,6 @@ package db import ( "database/sql" "fmt" - "io" "math/rand" "os" "regexp" @@ -88,6 +87,7 @@ func TestCheckFile(t *testing.T) { if tc.panicRegex != "" { defer func() { r := recover() + rStr := fmt.Sprint(r) re := regexp.MustCompile(tc.panicRegex) match := re.MatchString(rStr) @@ -387,12 +387,7 @@ func Test_UpdateColumnTypes(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() - - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w + releaseStdout := test.CaptureStdout() databaseFile := "Test_UpdateColumnTypes.db" db, err := sql.Open("sqlite", databaseFile) @@ -457,19 +452,16 @@ func Test_UpdateColumnTypes(t *testing.T) { latest_version, latest_version_timestamp, deployed_version, deployed_version_timestamp, approved_version, got.LatestVersion(), got.LatestVersionTimestamp(), got.DeployedVersion(), got.DeployedVersionTimestamp(), got.ApprovedVersion()) } - // AND the conversion was output to stdout - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) + // AND the conversion was printed to stdout + stdout := releaseStdout() want := "Finished updating column types" - contains := strings.Contains(output, want) + contains := strings.Contains(stdout, want) if tc.columnType == "TEXT" && contains { t.Errorf("Table started as %q, so should not have been updated. Got %q", - tc.columnType, output) + tc.columnType, stdout) } else if tc.columnType == "STRING" && !contains { t.Errorf("Table started as %q, so should have been updated. Got %q", - tc.columnType, output) + tc.columnType, stdout) } }) } diff --git a/notifiers/shoutrrr/verify_test.go b/notifiers/shoutrrr/verify_test.go index d60a7b9e..53b8792d 100644 --- a/notifiers/shoutrrr/verify_test.go +++ b/notifiers/shoutrrr/verify_test.go @@ -17,8 +17,6 @@ package shoutrrr import ( - "io" - "os" "regexp" "strings" "testing" @@ -1385,28 +1383,21 @@ notify: for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() if tc.want != "" { tc.want += "\n" } - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w // WHEN Print is called tc.slice.Print("") - // THEN it prints the expected output - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - strOut := string(out) + // THEN it prints the expected stdout + stdout := releaseStdout() tc.want = strings.TrimPrefix(tc.want, "\n") - if strOut != tc.want { + if stdout != tc.want { t.Errorf("Print should have given\n%q\nbut gave\n%q", - tc.want, strOut) + tc.want, stdout) } }) } diff --git a/service/deployed_version/query_test.go b/service/deployed_version/query_test.go index 6e653928..4517c057 100644 --- a/service/deployed_version/query_test.go +++ b/service/deployed_version/query_test.go @@ -17,7 +17,6 @@ package deployedver import ( - "io" "os" "regexp" "testing" @@ -437,19 +436,12 @@ func TestLookup_Track(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() for k, v := range tc.env { os.Setenv(k, v) defer os.Unsetenv(k) } - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - defer func() { - os.Stdout = stdout - }() if tc.lookup != nil { tc.lookup.AllowInvalidCerts = boolPtr(tc.allowInvalidCerts) tc.lookup.BasicAuth = tc.basicAuth @@ -498,6 +490,7 @@ func TestLookup_Track(t *testing.T) { t.Fatalf("expected Track to finish in %s, but it didn't", tc.wait) } + releaseStdout() return } haveQueried := false @@ -510,10 +503,8 @@ func TestLookup_Track(t *testing.T) { time.Sleep(time.Second) } time.Sleep(5 * time.Second) - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - t.Log(string(out)) + stdout := releaseStdout() + t.Log(stdout) if tc.wantDeployedVersion != tc.lookup.Status.DeployedVersion() { t.Errorf("expected DeployedVersion to be %q after query, not %q", tc.wantDeployedVersion, tc.lookup.Status.DeployedVersion()) diff --git a/service/latest_version/query_test.go b/service/latest_version/query_test.go index 2e8420f3..3b5d2f00 100644 --- a/service/latest_version/query_test.go +++ b/service/latest_version/query_test.go @@ -17,7 +17,6 @@ package latestver import ( - "io" "os" "regexp" "strings" @@ -94,7 +93,7 @@ func TestLookup_Query(t *testing.T) { requireRegexVersion string requireCommand []string requireDockerCheck *filter.DockerCheck - outputRegex string + stdoutRegex string errRegex string }{ "invalid url": { @@ -195,7 +194,7 @@ func TestLookup_Query(t *testing.T) { githubService: true, url: "go-vikunja/api", regex: stringPtr("v([0-9.]+)"), - outputRegex: `no tags found on /releases, trying /tags`, + stdoutRegex: `no tags found on /releases, trying /tags`, }, "github lookup with no access token": { githubService: true, @@ -216,15 +215,11 @@ func TestLookup_Query(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() try := 0 temporaryFailureInNameResolution := true for temporaryFailureInNameResolution != false { - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w + releaseStdout := test.CaptureStdout() try++ temporaryFailureInNameResolution = false lookup := testLookup(!tc.githubService, tc.allowInvalidCerts) @@ -255,9 +250,7 @@ func TestLookup_Query(t *testing.T) { _, err := lookup.Query(true, &util.LogFrom{}) // THEN any err is expected - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout + stdout := releaseStdout() e := util.ErrorToString(err) if tc.errRegex == "" { tc.errRegex = "^$" @@ -275,13 +268,12 @@ func TestLookup_Query(t *testing.T) { t.Fatalf("want match for %q\nnot: %q", tc.errRegex, e) } - // AND the output contains the expected strings - output := string(out) - re = regexp.MustCompile(tc.outputRegex) - match = re.MatchString(output) + // AND the stdout contains the expected strings + re = regexp.MustCompile(tc.stdoutRegex) + match = re.MatchString(stdout) if !match { t.Fatalf("match for %q not found in:\n%q", - tc.outputRegex, output) + tc.stdoutRegex, stdout) } // AND the LatestVersion is as expected if tc.wantLatestVersion != nil && @@ -296,8 +288,6 @@ func TestLookup_Query(t *testing.T) { func TestLookup_Query__EmptyListETagChanged(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() // Lock so that default empty list ETag isn't changed by other tests emptyListETagTestMutex.Lock() @@ -308,9 +298,7 @@ func TestLookup_Query__EmptyListETagChanged(t *testing.T) { try := 0 temporaryFailureInNameResolution := true for temporaryFailureInNameResolution != false { - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w + releaseStdout := test.CaptureStdout() try++ setEmptyListETag(invalidETag) temporaryFailureInNameResolution = false @@ -322,9 +310,7 @@ func TestLookup_Query__EmptyListETagChanged(t *testing.T) { _, err := lookup.Query(true, &util.LogFrom{}) // THEN any err is expected - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout + stdout := releaseStdout() e := util.ErrorToString(err) errRegex := "^$" re := regexp.MustCompile(errRegex) @@ -340,14 +326,13 @@ func TestLookup_Query__EmptyListETagChanged(t *testing.T) { t.Fatalf("want match for %q\nnot: %q", errRegex, e) } - // AND the output contains the expected strings - output := string(out) + // AND the stdout contains the expected strings wantOutputRegex := `/releases gave \[\], trying /tags` re = regexp.MustCompile(wantOutputRegex) - match = re.MatchString(output) + match = re.MatchString(stdout) if !match { t.Fatalf("match for %q not found in:\n%q", - wantOutputRegex, output) + wantOutputRegex, stdout) } } } @@ -387,8 +372,7 @@ no releases were found matching the url_commands and/or require`}, for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() lookup := testLookup(false, false) lookup.GitHubData.SetETag("foo") @@ -396,9 +380,6 @@ no releases were found matching the url_commands and/or require`}, lookup.Require.RegexVersion = tc.initialRequireRegexVersion lookup.URLCommands = tc.urlCommands - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w attempt := 0 // WHEN Query is called on it attempts number of times var errors string = "" @@ -417,10 +398,7 @@ no releases were found matching the url_commands and/or require`}, } // THEN any err is expected - w.Close() - o, _ := io.ReadAll(r) - out := string(o) - os.Stdout = stdout + stdout := releaseStdout() tc.errRegex = strings.ReplaceAll(tc.errRegex, "\n", "--") re := regexp.MustCompile(tc.errRegex) match := re.MatchString(errors) @@ -428,15 +406,15 @@ no releases were found matching the url_commands and/or require`}, t.Errorf("want match for %q\nnot: %q", tc.errRegex, errors) } - gotETagChanged := strings.Count(out, "new ETag") + gotETagChanged := strings.Count(stdout, "new ETag") if gotETagChanged != tc.eTagChanged { t.Errorf("new ETag - got=%d, want=%d\n%s", - gotETagChanged, tc.eTagChanged, out) + gotETagChanged, tc.eTagChanged, stdout) } - gotETagUnchangedUseCache := strings.Count(out, "Using cached releases") + gotETagUnchangedUseCache := strings.Count(stdout, "Using cached releases") if gotETagUnchangedUseCache != tc.eTagUnchangedUseCache { t.Errorf("ETag unchanged use cache - got=%d, want=%d\n%s", - gotETagUnchangedUseCache, tc.eTagUnchangedUseCache, out) + gotETagUnchangedUseCache, tc.eTagUnchangedUseCache, stdout) } }) } diff --git a/service/verify_test.go b/service/verify_test.go index 3c91702d..53310653 100644 --- a/service/verify_test.go +++ b/service/verify_test.go @@ -17,8 +17,6 @@ package service import ( - "io" - "os" "regexp" "strings" "testing" @@ -86,29 +84,22 @@ func TestSlice_Print(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() if tc.want != "" { tc.want += "\n" } tc.want = strings.ReplaceAll(tc.want, "\t", "") - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w // WHEN Print is called tc.slice.Print("", tc.ordering) - // THEN it prints the expected output - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - strOut := string(out) + // THEN it prints the expected stdout + stdout := releaseStdout() tc.want = strings.TrimPrefix(tc.want, "\n") - if strOut != tc.want { + if stdout != tc.want { t.Errorf("Print should have given\n%q\nbut gave\n%q", - tc.want, strOut) + tc.want, stdout) } }) } diff --git a/test/main.go b/test/main.go index d1c53d45..2c76e874 100644 --- a/test/main.go +++ b/test/main.go @@ -18,11 +18,26 @@ package test import ( "fmt" + "io" + "os" "strings" "sync" ) var StdoutMutex sync.Mutex // Only one test should write to stdout at a time +func CaptureStdout() func() string { + stdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + StdoutMutex.Lock() + return func() string { + w.Close() + out, _ := io.ReadAll(r) + os.Stdout = stdout + StdoutMutex.Unlock() + return string(out) + } +} // BoolPtr returns a pointer to the given boolean value func BoolPtr(val bool) *bool { diff --git a/testing/commands_test.go b/testing/commands_test.go index f8bb0502..6b84f75f 100644 --- a/testing/commands_test.go +++ b/testing/commands_test.go @@ -18,8 +18,6 @@ package testing import ( "fmt" - "io" - "os" "regexp" "testing" @@ -35,12 +33,12 @@ func TestCommandTest(t *testing.T) { tests := map[string]struct { flag string slice service.Slice - outputRegex *string + stdoutRegex *string panicRegex *string }{ "flag is empty": { flag: "", - outputRegex: stringPtr("^$"), + stdoutRegex: stringPtr("^$"), slice: service.Slice{ "argus": { ID: "argus", @@ -54,7 +52,7 @@ func TestCommandTest(t *testing.T) { "unknown service in flag": { flag: "something", panicRegex: stringPtr(" could not be found "), - outputRegex: stringPtr("should have panic'd before reaching this"), + stdoutRegex: stringPtr("should have panic'd before reaching this"), slice: service.Slice{ "argus": { ID: "argus", @@ -67,7 +65,7 @@ func TestCommandTest(t *testing.T) { }, "known service in flag successful command": { flag: "argus", - outputRegex: stringPtr(`Executing 'echo command did run'\s+.*command did run\s+`), + stdoutRegex: stringPtr(`Executing 'echo command did run'\s+.*command did run\s+`), slice: service.Slice{ "argus": { ID: "argus", @@ -80,7 +78,7 @@ func TestCommandTest(t *testing.T) { }, "known service in flag failing command": { flag: "argus", - outputRegex: stringPtr(`.*Executing 'ls /root'\s+.*exit status [1-9]\s+`), + stdoutRegex: stringPtr(`.*Executing 'ls /root'\s+.*exit status [1-9]\s+`), slice: service.Slice{ "argus": { ID: "argus", @@ -94,7 +92,7 @@ func TestCommandTest(t *testing.T) { "service with no commands": { flag: "argus", panicRegex: stringPtr(" does not have any `command` defined"), - outputRegex: stringPtr("should have panic'd before reaching this"), + stdoutRegex: stringPtr("should have panic'd before reaching this"), slice: service.Slice{ "argus": { ID: "argus"}}, @@ -104,16 +102,14 @@ func TestCommandTest(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w if tc.panicRegex != nil { // Switch Fatal to panic and disable this panic. defer func() { r := recover() + releaseStdout() + rStr := fmt.Sprint(r) re := regexp.MustCompile(*tc.panicRegex) match := re.MatchString(rStr) @@ -145,17 +141,14 @@ func TestCommandTest(t *testing.T) { } CommandTest(&tc.flag, &cfg, jLog) - // THEN we get the expected output - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) - if tc.outputRegex != nil { - re := regexp.MustCompile(*tc.outputRegex) - match := re.MatchString(output) + // THEN we get the expected stdout + stdout := releaseStdout() + if tc.stdoutRegex != nil { + re := regexp.MustCompile(*tc.stdoutRegex) + match := re.MatchString(stdout) if !match { t.Errorf("want match for %q\ngot: %q", - *tc.outputRegex, output) + *tc.stdoutRegex, stdout) } } }) diff --git a/testing/service_test.go b/testing/service_test.go index cdf302fd..4083278f 100644 --- a/testing/service_test.go +++ b/testing/service_test.go @@ -18,7 +18,6 @@ package testing import ( "fmt" - "io" "os" "regexp" "testing" @@ -42,12 +41,12 @@ func TestServiceTest(t *testing.T) { tests := map[string]struct { flag string slice service.Slice - outputRegex *string + stdoutRegex *string panicRegex *string }{ "flag is empty": { flag: "", - outputRegex: stringPtr("^$"), + stdoutRegex: stringPtr("^$"), slice: service.Slice{ "argus": { ID: "argus", @@ -71,7 +70,7 @@ func TestServiceTest(t *testing.T) { }, "github service": { flag: "argus", - outputRegex: stringPtr(`argus\)?, Latest Release - "[0-9]+\.[0-9]+\.[0-9]+"`), + stdoutRegex: stringPtr(`argus\)?, Latest Release - "[0-9]+\.[0-9]+\.[0-9]+"`), slice: service.Slice{ "argus": { LatestVersion: *latestver.New( @@ -91,7 +90,7 @@ func TestServiceTest(t *testing.T) { }, "url service type but github owner/repo url": { flag: "argus", - outputRegex: stringPtr("This URL looks to be a GitHub repo, but the service's type is url, not github"), + stdoutRegex: stringPtr("This URL looks to be a GitHub repo, but the service's type is url, not github"), slice: service.Slice{ "argus": { ID: "argus", @@ -111,7 +110,7 @@ func TestServiceTest(t *testing.T) { }, "url service": { flag: "argus", - outputRegex: stringPtr(`Latest Release - "[0-9]+\.[0-9]+\.[0-9]+"`), + stdoutRegex: stringPtr(`Latest Release - "[0-9]+\.[0-9]+\.[0-9]+"`), slice: service.Slice{ "argus": { ID: "argus", @@ -131,7 +130,7 @@ func TestServiceTest(t *testing.T) { }, "service with deployed version lookup": { flag: "argus", - outputRegex: stringPtr(`Latest Release - "[0-9]+\.[0-9]+\.[0-9]+"\s.*Deployed version - "[0-9]+\.[0-9]+\.[0-9]+"`), + stdoutRegex: stringPtr(`Latest Release - "[0-9]+\.[0-9]+\.[0-9]+"\s.*Deployed version - "[0-9]+\.[0-9]+\.[0-9]+"`), slice: service.Slice{ "argus": { ID: "argus", @@ -161,20 +160,15 @@ func TestServiceTest(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w if tc.panicRegex != nil { // Switch Fatal to panic and disable this panic. defer func() { - // Reset stdout - w.Close() - os.Stdout = stdout - // Check the panic message r := recover() + releaseStdout() + + // Check the panic message rStr := fmt.Sprint(r) re := regexp.MustCompile(*tc.panicRegex) match := re.MatchString(rStr) @@ -212,17 +206,14 @@ func TestServiceTest(t *testing.T) { } ServiceTest(&tc.flag, &cfg, jLog) - // THEN we get the expected output - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) - if tc.outputRegex != nil { - re := regexp.MustCompile(*tc.outputRegex) - match := re.MatchString(output) + // THEN we get the expected stdout + stdout := releaseStdout() + if tc.stdoutRegex != nil { + re := regexp.MustCompile(*tc.stdoutRegex) + match := re.MatchString(stdout) if !match { t.Errorf("want match on %q\ngot:\n%s", - *tc.outputRegex, output) + *tc.stdoutRegex, stdout) } } }) diff --git a/testing/shoutrrr_test.go b/testing/shoutrrr_test.go index 62cec4bc..8c4f6529 100644 --- a/testing/shoutrrr_test.go +++ b/testing/shoutrrr_test.go @@ -18,8 +18,6 @@ package testing import ( "fmt" - "io" - "os" "regexp" "testing" @@ -119,7 +117,7 @@ func TestFindShoutrrr(t *testing.T) { tests := map[string]struct { flag string cfg *config.Config - outputRegex *string + stdoutRegex *string panicRegex *string foundInRoot *bool }{ @@ -172,7 +170,7 @@ func TestFindShoutrrr(t *testing.T) { }, "matching search of notifier in Root": { flag: "bosh", - outputRegex: stringPtr("^$"), + stdoutRegex: stringPtr("^$"), cfg: &config.Config{ Service: service.Slice{ "argus": { @@ -189,7 +187,7 @@ func TestFindShoutrrr(t *testing.T) { }, "matching search of notifier in Service": { flag: "baz", - outputRegex: stringPtr("^$"), + stdoutRegex: stringPtr("^$"), cfg: &config.Config{ Service: service.Slice{ "argus": { @@ -206,7 +204,7 @@ func TestFindShoutrrr(t *testing.T) { }, "matching search of notifier in Root and a Service": { flag: "bar", - outputRegex: stringPtr("^$"), + stdoutRegex: stringPtr("^$"), cfg: &config.Config{ Service: service.Slice{ "argus": { @@ -384,16 +382,14 @@ func TestFindShoutrrr(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w if tc.panicRegex != nil { // Switch Fatal to panic and disable this panic. defer func() { r := recover() + releaseStdout() + rStr := fmt.Sprint(r) re := regexp.MustCompile(*tc.panicRegex) match := re.MatchString(rStr) @@ -414,17 +410,14 @@ func TestFindShoutrrr(t *testing.T) { // WHEN findShoutrrr is called with the test Config got := findShoutrrr(tc.flag, tc.cfg, jLog, &util.LogFrom{}) - // THEN we get the expected output - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) - if tc.outputRegex != nil { - re := regexp.MustCompile(*tc.outputRegex) - match := re.MatchString(output) + // THEN we get the expected stdout + stdout := releaseStdout() + if tc.stdoutRegex != nil { + re := regexp.MustCompile(*tc.stdoutRegex) + match := re.MatchString(stdout) if !match { t.Fatalf("want match for %q\nnot: %q", - *tc.outputRegex, output) + *tc.stdoutRegex, stdout) } } // if the notifier should have been found in the root or in a service @@ -465,11 +458,11 @@ func TestNotifyTest(t *testing.T) { flag string slice service.Slice rootSlice shoutrrr.SliceDefaults - outputRegex *string + stdoutRegex *string panicRegex *string }{ "empty flag": {flag: "", - outputRegex: stringPtr("^$"), + stdoutRegex: stringPtr("^$"), slice: service.Slice{ "argus": { Notify: shoutrrr.Slice{ @@ -548,12 +541,8 @@ func TestNotifyTest(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w serviceHardDefaults := service.Defaults{} serviceHardDefaults.SetDefaults() shoutrrrHardDefaults := shoutrrr.SliceDefaults{} @@ -568,6 +557,8 @@ func TestNotifyTest(t *testing.T) { // Switch Fatal to panic and disable this panic. defer func() { r := recover() + releaseStdout() + rStr := fmt.Sprint(r) re := regexp.MustCompile(*tc.panicRegex) match := re.MatchString(rStr) @@ -584,17 +575,14 @@ func TestNotifyTest(t *testing.T) { Notify: tc.rootSlice} NotifyTest(&tc.flag, &cfg, jLog) - // THEN we get the expected output - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) - if tc.outputRegex != nil { - re := regexp.MustCompile(*tc.outputRegex) - match := re.MatchString(output) + // THEN we get the expected stdout + stdout := releaseStdout() + if tc.stdoutRegex != nil { + re := regexp.MustCompile(*tc.stdoutRegex) + match := re.MatchString(stdout) if !match { t.Errorf("want match for %q\nnot: %q", - *tc.outputRegex, output) + *tc.stdoutRegex, stdout) } } }) diff --git a/util/log_test.go b/util/log_test.go index 3341d093..ac15a6d9 100644 --- a/util/log_test.go +++ b/util/log_test.go @@ -19,9 +19,7 @@ package util import ( "bytes" "fmt" - "io" "log" - "os" "regexp" "strings" "testing" @@ -97,6 +95,7 @@ func TestSetLevel(t *testing.T) { // Switch Fatal to panic and disable this panic. defer func() { r := recover() + rStr := fmt.Sprint(r) re := regexp.MustCompile(*tc.panicRegex) match := re.MatchString(rStr) @@ -309,13 +308,9 @@ func TestJLog_Error(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() jLog := NewJLog(tc.level, tc.timestamps) - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w var logOut bytes.Buffer log.SetOutput(&logOut) @@ -323,13 +318,10 @@ func TestJLog_Error(t *testing.T) { jLog.Error(fmt.Errorf(msg), &LogFrom{}, tc.otherCondition) // THEN msg was logged if shouldPrint, with/without timestamps - w.Close() - out, _ := io.ReadAll(r) - got := string(out) - os.Stdout = stdout + stdout := releaseStdout() var regex string if tc.timestamps { - got = logOut.String() + stdout = logOut.String() regex = fmt.Sprintf("^[0-9]{4}\\/[0-9]{2}\\/[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2} ERROR: %s\n$", msg) } else if !tc.otherCondition { regex = "^$" @@ -337,10 +329,10 @@ func TestJLog_Error(t *testing.T) { regex = fmt.Sprintf("^ERROR: %s\n$", msg) } reg := regexp.MustCompile(regex) - match := reg.MatchString(got) + match := reg.MatchString(stdout) if !match { t.Errorf("ERROR printed didn't match %q\nGot %q", - regex, got) + regex, stdout) } }) } @@ -390,13 +382,9 @@ func TestJLog_Warn(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() jLog := NewJLog(tc.level, tc.timestamps) - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w var logOut bytes.Buffer log.SetOutput(&logOut) @@ -404,24 +392,21 @@ func TestJLog_Warn(t *testing.T) { jLog.Warn(fmt.Errorf(msg), &LogFrom{}, tc.otherCondition) // THEN msg was logged if shouldPrint, with/without timestamps - w.Close() - out, _ := io.ReadAll(r) - got := string(out) - os.Stdout = stdout + stdout := releaseStdout() var regex string if !tc.shouldPrint { regex = "^$" } else if tc.timestamps { - got = logOut.String() + stdout = logOut.String() regex = fmt.Sprintf("^[0-9]{4}\\/[0-9]{2}\\/[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2} WARNING: %s\n$", msg) } else { regex = fmt.Sprintf("^WARNING: %s\n$", msg) } reg := regexp.MustCompile(regex) - match := reg.MatchString(got) + match := reg.MatchString(stdout) if !match { t.Errorf("WARNING printed didn't match %q\nGot %q", - regex, got) + regex, stdout) } }) } @@ -471,13 +456,9 @@ func TestJLog_Info(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() jLog := NewJLog(tc.level, tc.timestamps) - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w var logOut bytes.Buffer log.SetOutput(&logOut) @@ -485,24 +466,21 @@ func TestJLog_Info(t *testing.T) { jLog.Info(fmt.Errorf(msg), &LogFrom{}, tc.otherCondition) // THEN msg was logged if shouldPrint, with/without timestamps - w.Close() - out, _ := io.ReadAll(r) - got := string(out) - os.Stdout = stdout + stdout := releaseStdout() var regex string if !tc.shouldPrint { regex = "^$" } else if tc.timestamps { - got = logOut.String() + stdout = logOut.String() regex = fmt.Sprintf("^[0-9]{4}\\/[0-9]{2}\\/[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2} INFO: %s\n$", msg) } else { regex = fmt.Sprintf("^INFO: %s\n$", msg) } reg := regexp.MustCompile(regex) - match := reg.MatchString(got) + match := reg.MatchString(stdout) if !match { t.Errorf("INFO printed didn't match %q\nGot %q", - regex, got) + regex, stdout) } }) } @@ -557,15 +535,11 @@ func TestJLog_Verbose(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() msg := "argus" jLog := NewJLog(tc.level, tc.timestamps) - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w var logOut bytes.Buffer log.SetOutput(&logOut) if tc.customMsg != nil { @@ -576,10 +550,7 @@ func TestJLog_Verbose(t *testing.T) { jLog.Verbose(fmt.Errorf(msg), &LogFrom{}, tc.otherCondition) // THEN msg was logged if shouldPrint, with/without timestamps - w.Close() - out, _ := io.ReadAll(r) - got := string(out) - os.Stdout = stdout + stdout := releaseStdout() var regex string if tc.customMsg != nil && len(*tc.customMsg) > tc.expectedLength { msg = (*tc.customMsg)[:tc.expectedLength] + "..." @@ -587,25 +558,25 @@ func TestJLog_Verbose(t *testing.T) { if !tc.shouldPrint { regex = "^$" } else if tc.timestamps { - got = logOut.String() + stdout = logOut.String() regex = fmt.Sprintf("^[0-9]{4}\\/[0-9]{2}\\/[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2} VERBOSE: %s\n$", msg) } else { regex = fmt.Sprintf("^VERBOSE: %s\n$", msg) } reg := regexp.MustCompile(regex) - match := reg.MatchString(got) + match := reg.MatchString(stdout) if !match { t.Errorf("VERBOSE printed didn't match %q\nGot %q", - regex, got) + regex, stdout) } if tc.customMsg != nil { tc.expectedLength += len("VERBOSE: ") - if strings.HasSuffix(got, "...\n") { + if strings.HasSuffix(stdout, "...\n") { tc.expectedLength += len("...\n") } - if len(got) != tc.expectedLength { + if len(stdout) != tc.expectedLength { t.Errorf("VERBOSE message length not limited to %d\nGot %d\n%q", - tc.expectedLength, len(got), got) + tc.expectedLength, len(stdout), stdout) } } }) @@ -662,13 +633,9 @@ func TestJLog_Debug(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() jLog := NewJLog(tc.level, tc.timestamps) - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w var logOut bytes.Buffer log.SetOutput(&logOut) if tc.customMsg != nil { @@ -679,10 +646,7 @@ func TestJLog_Debug(t *testing.T) { jLog.Debug(fmt.Errorf(msg), &LogFrom{}, tc.otherCondition) // THEN msg was logged if shouldPrint, with/without timestamps - w.Close() - out, _ := io.ReadAll(r) - got := string(out) - os.Stdout = stdout + stdout := releaseStdout() var regex string if tc.customMsg != nil && len(*tc.customMsg) > tc.expectedLength { msg = (*tc.customMsg)[:tc.expectedLength] + "..." @@ -690,25 +654,25 @@ func TestJLog_Debug(t *testing.T) { if !tc.shouldPrint { regex = "^$" } else if tc.timestamps { - got = logOut.String() + stdout = logOut.String() regex = fmt.Sprintf("^[0-9]{4}\\/[0-9]{2}\\/[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2} DEBUG: %s\n$", msg) } else { regex = fmt.Sprintf("^DEBUG: %s\n$", msg) } reg := regexp.MustCompile(regex) - match := reg.MatchString(got) + match := reg.MatchString(stdout) if !match { t.Errorf("DEBUG printed didn't match %q\nGot %q", - regex, got) + regex, stdout) } if tc.customMsg != nil { tc.expectedLength += len("DEBUG: ") - if strings.HasSuffix(got, "...\n") { + if strings.HasSuffix(stdout, "...\n") { tc.expectedLength += len("...\n") } - if len(got) != tc.expectedLength { + if len(stdout) != tc.expectedLength { t.Errorf("DEBUG message length not limited to %d\nGot %d\n%q", - tc.expectedLength, len(got), got) + tc.expectedLength, len(stdout), stdout) } } }) @@ -759,35 +723,27 @@ func TestJLog_Fatal(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() jLog := NewJLog(tc.level, tc.timestamps) - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - defer func() { - os.Stdout = stdout - }() var logOut bytes.Buffer log.SetOutput(&logOut) if tc.shouldPrint { jLog.Testing = true defer func() { - _ = recover() + recover() + stdout := releaseStdout() + regex := fmt.Sprintf("^ERROR: %s\n$", msg) - w.Close() - out, _ := io.ReadAll(r) - got := string(out) if tc.timestamps { - got = logOut.String() + stdout = logOut.String() regex = fmt.Sprintf("^[0-9]{4}\\/[0-9]{2}\\/[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2} ERROR: %s\n$", msg) } reg := regexp.MustCompile(regex) - match := reg.MatchString(got) + match := reg.MatchString(stdout) if !match { t.Errorf("ERROR wasn't printed/didn't match %q\nGot %q", - regex, got) + regex, stdout) } }() } @@ -796,16 +752,13 @@ func TestJLog_Fatal(t *testing.T) { jLog.Fatal(fmt.Errorf(msg), &LogFrom{}, tc.otherCondition) // THEN msg was logged if shouldPrint, with/without timestamps - w.Close() - out, _ := io.ReadAll(r) - got := string(out) - os.Stdout = stdout + stdout := releaseStdout() regex := "^$" reg := regexp.MustCompile(regex) - match := reg.MatchString(got) + match := reg.MatchString(stdout) if !match { t.Errorf("ERROR printed didn't match %q\nGot %q", - regex, got) + regex, stdout) } }) } diff --git a/util/template_test.go b/util/template_test.go index 2465c427..8c09218b 100644 --- a/util/template_test.go +++ b/util/template_test.go @@ -49,6 +49,7 @@ func TestTemplate_String(t *testing.T) { // Switch Fatal to panic and disable this panic. defer func() { r := recover() + rStr := fmt.Sprint(r) re := regexp.MustCompile(*tc.panicRegex) match := re.MatchString(rStr) diff --git a/util/util_test.go b/util/util_test.go index 45355286..71bc0d99 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -20,7 +20,6 @@ import ( "crypto/sha256" "encoding/json" "fmt" - "io" "os" "regexp" "strings" @@ -632,32 +631,25 @@ func TestPrintlnIfNotDefault(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() msg := "var is not default from PrintlnIfNotDefault" - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w // WHEN PrintlnIfNotDefault is called PrintlnIfNotDefault(tc.element, msg) // THEN the var is printed when it should be - w.Close() - out, _ := io.ReadAll(r) - got := string(out) - os.Stdout = stdout + stdout := releaseStdout() if !tc.didPrint { - if got != "" { + if stdout != "" { t.Fatalf("printed %q", - got) + stdout) } return } - if got != msg+"\n" { + if stdout != msg+"\n" { t.Errorf("unexpected print %q", - got) + stdout) } }) } @@ -678,32 +670,25 @@ func TestPrintlnIfNotNil(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() msg := "var is not default from PrintlnIfNotNil" - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w // WHEN PrintlnIfNotNil is called PrintlnIfNotNil(tc.element, msg) // THEN the var is printed when it should be - w.Close() - out, _ := io.ReadAll(r) - got := string(out) - os.Stdout = stdout + stdout := releaseStdout() if !tc.didPrint { - if got != "" { + if stdout != "" { t.Fatalf("printed %q", - got) + stdout) } return } - if got != msg+"\n" { + if stdout != msg+"\n" { t.Errorf("unexpected print %q", - got) + stdout) } }) } @@ -724,32 +709,25 @@ func TestPrintlnIfNil(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() msg := "var is not default from PrintlnIfNil" - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w // WHEN PrintlnIfNil is called PrintlnIfNil(tc.element, msg) // THEN the var is printed when it should be - w.Close() - out, _ := io.ReadAll(r) - got := string(out) - os.Stdout = stdout + stdout := releaseStdout() if !tc.didPrint { - if got != "" { + if stdout != "" { t.Fatalf("printed %q", - got) + stdout) } return } - if got != msg+"\n" { + if stdout != msg+"\n" { t.Errorf("unexpected print %q", - got) + stdout) } }) } diff --git a/web/api/v1/http-api-actions_test.go b/web/api/v1/http-api-actions_test.go index c76aec7e..a5b8c4a7 100644 --- a/web/api/v1/http-api-actions_test.go +++ b/web/api/v1/http-api-actions_test.go @@ -108,8 +108,7 @@ func TestHTTP_httpServiceGetActions(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() if tc.statusCode == 0 { tc.statusCode = http.StatusOK @@ -154,9 +153,6 @@ func TestHTTP_httpServiceGetActions(t *testing.T) { cfg.Order = append(cfg.Order, name) cfg.OrderMutex.Unlock() defer cfg.DeleteService(name) - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w target := "/api/v1/service/actions/" target += url.QueryEscape(tc.serviceID) @@ -171,18 +167,15 @@ func TestHTTP_httpServiceGetActions(t *testing.T) { defer res.Body.Close() // THEN we get the expected response - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) + stdout := releaseStdout() // stdout finishes if tc.stdoutRegex != "" { tc.stdoutRegex = strings.ReplaceAll(tc.stdoutRegex, "__name__", name) re := regexp.MustCompile(tc.stdoutRegex) - match := re.MatchString(output) + match := re.MatchString(stdout) if !match { t.Errorf("match on %q not found in\n%q", - tc.stdoutRegex, output) + tc.stdoutRegex, stdout) } } message, _ := io.ReadAll(res.Body) @@ -432,8 +425,7 @@ func TestHTTP_httpServiceRunActions(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() tc.serviceID = strings.ReplaceAll(tc.serviceID, "__name__", name) svc := testService(tc.serviceID) @@ -489,10 +481,6 @@ func TestHTTP_httpServiceRunActions(t *testing.T) { api.Config.Order = append(api.Config.Order, name) api.Config.OrderMutex.Unlock() defer api.Config.DeleteService(name) - // Stdout setup - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w // WHEN the HTTP request is sent to run the action(s) target := tc.target @@ -570,11 +558,8 @@ func TestHTTP_httpServiceRunActions(t *testing.T) { for expecting != 0 { message := <-*api.Config.HardDefaults.Service.Status.AnnounceChannel if message == nil { - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) - t.Log(time.Now(), output) + stdout := releaseStdout() + t.Log(time.Now(), stdout) t.Errorf("expecting %d more messages but got %v", expecting, message) return @@ -592,21 +577,18 @@ func TestHTTP_httpServiceRunActions(t *testing.T) { t.Fatalf("wasn't expecting another message but got one\n%#v\n%s", extraMessages, string(raw)) } - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) + stdout := releaseStdout() // stdout finishes if tc.stdoutRegex != "" { re := regexp.MustCompile(tc.stdoutRegex) - match := re.MatchString(output) + match := re.MatchString(stdout) if !match { t.Errorf("match on %q not found in\n%q", - tc.stdoutRegex, output) + tc.stdoutRegex, stdout) } return } - t.Log(output) + t.Log(stdout) // Check version was skipped if util.DefaultIfNil(tc.target) == "ARGUS_SKIP" { if tc.wantSkipMessage && diff --git a/web/web_test.go b/web/web_test.go index 5fc3e2e7..9457d696 100644 --- a/web/web_test.go +++ b/web/web_test.go @@ -19,7 +19,6 @@ package web import ( "encoding/json" "fmt" - "io" "net/http" "net/http/httptest" "os" @@ -240,13 +239,9 @@ func TestWebSocket(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() ws := connectToWebSocket(t) - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w // WHEN we send a message if err := ws.WriteMessage(websocket.TextMessage, []byte(tc.msg)); err != nil { @@ -258,15 +253,12 @@ func TestWebSocket(t *testing.T) { // THEN we receive the expected response ws.Close() time.Sleep(250 * time.Millisecond) - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) + stdout := releaseStdout() re := regexp.MustCompile(tc.stdoutRegex) - match := re.MatchString(output) + match := re.MatchString(stdout) if !match { t.Errorf("match on %q not found in\n\n%s", - tc.stdoutRegex, output) + tc.stdoutRegex, stdout) } }) } diff --git a/webhook/send_test.go b/webhook/send_test.go index eabc85bf..d6b01876 100644 --- a/webhook/send_test.go +++ b/webhook/send_test.go @@ -17,8 +17,6 @@ package webhook import ( - "io" - "os" "regexp" "strings" "testing" @@ -176,17 +174,13 @@ func TestWebHook_Send(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() try := 0 contextDeadlineExceeded := true for contextDeadlineExceeded != false { try++ contextDeadlineExceeded = false - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w + releaseStdout := test.CaptureStdout() webhook := testWebHook(tc.wouldFail, false, tc.customHeaders) if tc.deleting { webhook.ServiceStatus.SetDeleting() @@ -219,15 +213,12 @@ func TestWebHook_Send(t *testing.T) { // THEN the logs are expected completedAt := time.Now() - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) + stdout := releaseStdout() re := regexp.MustCompile(tc.stdoutRegex) - output = strings.ReplaceAll(output, "\n", "-n") - match := re.MatchString(output) + stdout = strings.ReplaceAll(stdout, "\n", "-n") + match := re.MatchString(stdout) if !match { - if strings.Contains(output, "context deadline exceeded") { + if strings.Contains(stdout, "context deadline exceeded") { contextDeadlineExceeded = true if try != 3 { time.Sleep(time.Second) @@ -235,7 +226,7 @@ func TestWebHook_Send(t *testing.T) { } } t.Errorf("match on %q not found in\n%q", - tc.stdoutRegex, output) + tc.stdoutRegex, stdout) } // AND the delay is expected if tc.delay != "" { @@ -287,8 +278,6 @@ func TestSlice_Send(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() try := 0 contextDeadlineExceeded := true @@ -297,9 +286,7 @@ func TestSlice_Send(t *testing.T) { contextDeadlineExceeded = false tc.repeat++ // repeat to check delay usage as map order is random for tc.repeat != 0 { - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w + releaseStdout := test.CaptureStdout() if tc.slice != nil { for id := range *tc.slice { (*tc.slice)[id].ID = id @@ -313,15 +300,12 @@ func TestSlice_Send(t *testing.T) { tc.slice.Send(&util.ServiceInfo{ID: name}, tc.useDelay) // THEN the logs are expected - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - output := string(out) - output = strings.ReplaceAll(output, "\n", "-n") + stdout := releaseStdout() + stdout = strings.ReplaceAll(stdout, "\n", "-n") re := regexp.MustCompile(tc.stdoutRegex) - match := re.MatchString(output) + match := re.MatchString(stdout) if !match { - if strings.Contains(output, "context deadline exceeded") { + if strings.Contains(stdout, "context deadline exceeded") { contextDeadlineExceeded = true if try != 3 { time.Sleep(time.Second) @@ -330,15 +314,15 @@ func TestSlice_Send(t *testing.T) { } if tc.stdoutRegexAlt != "" { re = regexp.MustCompile(tc.stdoutRegexAlt) - match = re.MatchString(output) + match = re.MatchString(stdout) if !match { t.Errorf("match on %q not found in\n%q", - tc.stdoutRegexAlt, output) + tc.stdoutRegexAlt, stdout) } return } t.Errorf("match on %q not found in\n%q", - tc.stdoutRegex, output) + tc.stdoutRegex, stdout) } tc.repeat-- } diff --git a/webhook/verify_test.go b/webhook/verify_test.go index bc01094e..e9309315 100644 --- a/webhook/verify_test.go +++ b/webhook/verify_test.go @@ -18,8 +18,6 @@ package webhook import ( "fmt" - "io" - "os" "regexp" "strings" "testing" @@ -86,28 +84,21 @@ webhook: for name, tc := range tests { t.Run(name, func(t *testing.T) { // t.Parallel() - Cannot run in parallel since we're using stdout - test.StdoutMutex.Lock() - defer test.StdoutMutex.Unlock() + releaseStdout := test.CaptureStdout() if tc.want != "" { tc.want += "\n" } - stdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w // WHEN Print is called tc.slice.Print("") // THEN it prints the expected output - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = stdout - strOut := string(out) + stdout := releaseStdout() tc.want = strings.TrimPrefix(tc.want, "\n") - if strOut != tc.want { + if stdout != tc.want { t.Errorf("Print should have given\n%q\nbut gave\n%q", - tc.want, strOut) + tc.want, stdout) } }) }