From a850b9e047925436b4729c7f78ebc2f153bb3c51 Mon Sep 17 00:00:00 2001 From: Pierangelo Di Pilato Date: Thu, 10 Mar 2022 19:00:48 +0100 Subject: [PATCH] Use t.Setenv instead of os.Setenv in tests (#2454) Go 1.17 introduced a new handy API for setting env vars scoped for a single test so we can avoid the hard to read set and reset env loops. Signed-off-by: Pierangelo Di Pilato --- changeset/commit_test.go | 7 ++----- leaderelection/config_test.go | 14 +++++--------- leaderelection/context_test.go | 21 ++++----------------- logging/config_test.go | 8 ++------ metrics/config_observability_test.go | 8 ++------ metrics/config_test.go | 7 ++----- metrics/prometheus_exporter_test.go | 7 ++----- network/domain_test.go | 28 +++++++++------------------- test/gke/client_test.go | 5 +---- test/prow/env_test.go | 9 ++------- test/prow/prow_test.go | 15 +++------------ webhook/env_test.go | 7 ++----- 12 files changed, 36 insertions(+), 100 deletions(-) diff --git a/changeset/commit_test.go b/changeset/commit_test.go index 057bccfe0c..b1724c1403 100644 --- a/changeset/commit_test.go +++ b/changeset/commit_test.go @@ -19,7 +19,6 @@ package changeset import ( "errors" "fmt" - "os" "testing" ) @@ -82,10 +81,8 @@ func TestReadFile(t *testing.T) { }} for _, test := range tests { t.Run(test.name, func(t *testing.T) { - if test.koDataPathEnvDoesNotExist { - os.Clearenv() - } else { - os.Setenv(koDataPathEnvName, test.koDataPath) + if !test.koDataPathEnvDoesNotExist { + t.Setenv(koDataPathEnvName, test.koDataPath) } got, err := Get() diff --git a/leaderelection/config_test.go b/leaderelection/config_test.go index 2f2ca83a05..3b95b448c0 100644 --- a/leaderelection/config_test.go +++ b/leaderelection/config_test.go @@ -18,7 +18,6 @@ package leaderelection import ( "fmt" - "os" "strconv" "strings" "testing" @@ -26,6 +25,7 @@ import ( "github.com/google/go-cmp/cmp" corev1 "k8s.io/api/core/v1" + "knative.dev/pkg/kmap" ) @@ -248,20 +248,16 @@ func TestNewStatefulSetConfig(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { if tc.pod != "" { - os.Setenv(controllerOrdinalEnv, tc.pod) - defer os.Unsetenv(controllerOrdinalEnv) + t.Setenv(controllerOrdinalEnv, tc.pod) } if tc.service != "" { - os.Setenv(serviceNameEnv, tc.service) - defer os.Unsetenv(serviceNameEnv) + t.Setenv(serviceNameEnv, tc.service) } if tc.port != "" { - os.Setenv(servicePortEnv, tc.port) - defer os.Unsetenv(servicePortEnv) + t.Setenv(servicePortEnv, tc.port) } if tc.protocol != "" { - os.Setenv(serviceProtocolEnv, tc.protocol) - defer os.Unsetenv(serviceProtocolEnv) + t.Setenv(serviceProtocolEnv, tc.protocol) } ssc, err := newStatefulSetConfig() diff --git a/leaderelection/context_test.go b/leaderelection/context_test.go index 618bd49656..b54e8ffeef 100644 --- a/leaderelection/context_test.go +++ b/leaderelection/context_test.go @@ -19,7 +19,6 @@ package leaderelection import ( "context" "fmt" - "os" "testing" "time" @@ -225,12 +224,8 @@ func TestNewStatefulSetBucketAndSet(t *testing.T) { "http://as-2.autoscaler.knative-testing.svc.cluster.local:80", } - os.Setenv(controllerOrdinalEnv, "as-2") - os.Setenv(serviceNameEnv, "autoscaler") - t.Cleanup(func() { - os.Unsetenv(controllerOrdinalEnv) - os.Unsetenv(serviceNameEnv) - }) + t.Setenv(controllerOrdinalEnv, "as-2") + t.Setenv(serviceNameEnv, "autoscaler") _, _, err := NewStatefulSetBucketAndSet(2) if err == nil { @@ -271,16 +266,8 @@ func TestWithStatefulSetBuilder(t *testing.T) { } enq := func(reconciler.Bucket, types.NamespacedName) {} - if os.Setenv(controllerOrdinalEnv, "as-2") != nil { - t.Fatalf("Failed to set env var %s=%s", controllerOrdinalEnv, "as-2") - } - if os.Setenv(serviceNameEnv, "autoscaler") != nil { - t.Fatalf("Failed to set env var %s=%s", serviceNameEnv, "autoscaler") - } - t.Cleanup(func() { - os.Unsetenv(controllerOrdinalEnv) - os.Unsetenv(serviceNameEnv) - }) + t.Setenv(controllerOrdinalEnv, "as-2") + t.Setenv(serviceNameEnv, "autoscaler") ctx = WithDynamicLeaderElectorBuilder(ctx, nil, cc) if !HasLeaderElection(ctx) { diff --git a/logging/config_test.go b/logging/config_test.go index 506d57bef7..dac48ae2ab 100644 --- a/logging/config_test.go +++ b/logging/config_test.go @@ -18,7 +18,6 @@ package logging import ( "fmt" - "os" "testing" "github.com/google/go-cmp/cmp" @@ -481,14 +480,11 @@ func TestConfigMapName(t *testing.T) { if got, want := ConfigMapName(), "config-logging"; got != want { t.Errorf("ConfigMapName = %q, want: %q", got, want) } - t.Cleanup(func() { - os.Unsetenv(configMapNameEnv) - }) - os.Setenv(configMapNameEnv, "") + t.Setenv(configMapNameEnv, "") if got, want := ConfigMapName(), "config-logging"; got != want { t.Errorf("ConfigMapName = %q, want: %q", got, want) } - os.Setenv(configMapNameEnv, "slowly-dying-inside") + t.Setenv(configMapNameEnv, "slowly-dying-inside") if got, want := ConfigMapName(), "slowly-dying-inside"; got != want { t.Errorf("ConfigMapName = %q, want: %q", got, want) } diff --git a/metrics/config_observability_test.go b/metrics/config_observability_test.go index f042507bb3..5338f510b1 100644 --- a/metrics/config_observability_test.go +++ b/metrics/config_observability_test.go @@ -17,7 +17,6 @@ limitations under the License. package metrics import ( - "os" "testing" "github.com/google/go-cmp/cmp" @@ -146,14 +145,11 @@ func TestConfigMapName(t *testing.T) { if got, want := ConfigMapName(), "config-observability"; got != want { t.Errorf("ConfigMapName = %q, want: %q", got, want) } - t.Cleanup(func() { - os.Unsetenv(configMapNameEnv) - }) - os.Setenv(configMapNameEnv, "") + t.Setenv(configMapNameEnv, "") if got, want := ConfigMapName(), "config-observability"; got != want { t.Errorf("ConfigMapName = %q, want: %q", got, want) } - os.Setenv(configMapNameEnv, "why-is-living-so-hard?") + t.Setenv(configMapNameEnv, "why-is-living-so-hard?") if got, want := ConfigMapName(), "why-is-living-so-hard?"; got != want { t.Errorf("ConfigMapName = %q, want: %q", got, want) } diff --git a/metrics/config_test.go b/metrics/config_test.go index 50b6f0e9ea..ce7d0a7335 100644 --- a/metrics/config_test.go +++ b/metrics/config_test.go @@ -18,7 +18,6 @@ package metrics import ( "context" "math" - "os" "strconv" "strings" "testing" @@ -380,8 +379,7 @@ func TestGetMetricsConfig_fromEnv(t *testing.T) { for _, test := range successTests { t.Run(test.name, func(t *testing.T) { - os.Setenv(test.varName, test.varValue) - defer os.Unsetenv(test.varName) + t.Setenv(test.varName, test.varValue) mc, err := createMetricsConfig(ctx, test.ops) if err != nil { @@ -395,8 +393,7 @@ func TestGetMetricsConfig_fromEnv(t *testing.T) { for _, test := range failureTests { t.Run(test.name, func(t *testing.T) { - os.Setenv(test.varName, test.varValue) - defer os.Unsetenv(test.varName) + t.Setenv(test.varName, test.varValue) mc, err := createMetricsConfig(ctx, test.ops) if mc != nil { diff --git a/metrics/prometheus_exporter_test.go b/metrics/prometheus_exporter_test.go index ef2dcf6871..e5ae2e47c6 100644 --- a/metrics/prometheus_exporter_test.go +++ b/metrics/prometheus_exporter_test.go @@ -14,7 +14,6 @@ package metrics import ( "context" - "os" "testing" "time" @@ -103,12 +102,10 @@ func TestNewPrometheusExporter_fromEnv(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { if tc.prometheusPortVarName != "" { - os.Setenv(tc.prometheusPortVarName, tc.prometheusPortVarValue) - defer os.Unsetenv(tc.prometheusPortVarName) + t.Setenv(tc.prometheusPortVarName, tc.prometheusPortVarValue) } if tc.prometheusHostVarName != "" { - os.Setenv(tc.prometheusHostVarName, tc.prometheusHostVarValue) - defer os.Unsetenv(tc.prometheusHostVarName) + t.Setenv(tc.prometheusHostVarName, tc.prometheusHostVarValue) } mc, err := createMetricsConfig(context.Background(), tc.ops) if err != nil { diff --git a/network/domain_test.go b/network/domain_test.go index e6a61d2358..3aa27d0e83 100644 --- a/network/domain_test.go +++ b/network/domain_test.go @@ -17,7 +17,6 @@ limitations under the License. package network import ( - "os" "strings" "testing" ) @@ -75,24 +74,15 @@ options ndots:5 want: defaultDomainName, }} - domainWas := os.Getenv(clusterDomainEnvKey) - t.Cleanup(func() { - if len(domainWas) > 0 { - _ = os.Setenv(clusterDomainEnvKey, domainWas) - } else { - _ = os.Unsetenv(clusterDomainEnvKey) - } - }) - for _, tt := range tests { - if len(tt.env) > 0 { - _ = os.Setenv(clusterDomainEnvKey, tt.env) - } else { - _ = os.Unsetenv(clusterDomainEnvKey) - } - got := getClusterDomainName(strings.NewReader(tt.resolvConf)) - if got != tt.want { - t.Errorf("Test %s failed expected: %s but got: %s", tt.name, tt.want, got) - } + t.Run(tt.name, func(t *testing.T) { + if len(tt.env) > 0 { + t.Setenv(clusterDomainEnvKey, tt.env) + } + got := getClusterDomainName(strings.NewReader(tt.resolvConf)) + if got != tt.want { + t.Errorf("Test %s failed expected: %s but got: %s", tt.name, tt.want, got) + } + }) } } diff --git a/test/gke/client_test.go b/test/gke/client_test.go index 69659b00cc..88abbb68dc 100644 --- a/test/gke/client_test.go +++ b/test/gke/client_test.go @@ -29,10 +29,7 @@ const credEnvKey = "GOOGLE_APPLICATION_CREDENTIALS" // func NewSDKClient(opts ...option.ClientOption) (SDKOperations, error) { func TestNewSDKClient(t *testing.T) { pwd, _ := os.Getwd() - if err := os.Setenv(credEnvKey, filepath.Join(pwd, "fake/credentials.json")); err != nil { - t.Errorf("Failed to set %s to fake/credentials.json: %v", credEnvKey, err) - } - defer os.Unsetenv(credEnvKey) + t.Setenv(credEnvKey, filepath.Join(pwd, "fake/credentials.json")) datas := []struct { req option.ClientOption diff --git a/test/prow/env_test.go b/test/prow/env_test.go index 9b1db00805..2b1ece3f90 100644 --- a/test/prow/env_test.go +++ b/test/prow/env_test.go @@ -17,16 +17,11 @@ limitations under the License. package prow import ( - "os" "testing" ) func TestGetEnvConfig(t *testing.T) { - isCI := os.Getenv("CI") - // Set it to the original value - defer os.Setenv("CI", isCI) - - os.Setenv("CI", "true") + t.Setenv("CI", "true") ec, err := GetEnvConfig() t.Log("EnvConfig is:", ec) if err != nil { @@ -36,7 +31,7 @@ func TestGetEnvConfig(t *testing.T) { t.Fatal("Expected CI to be true but is false") } - os.Setenv("CI", "false") + t.Setenv("CI", "false") if _, err = GetEnvConfig(); err == nil { t.Fatal("Expected an error if called from a non-CI environment but got nil") } diff --git a/test/prow/prow_test.go b/test/prow/prow_test.go index 84cde9cf27..67b47bc8a7 100644 --- a/test/prow/prow_test.go +++ b/test/prow/prow_test.go @@ -19,7 +19,6 @@ limitations under the License. package prow import ( - "os" "testing" ) @@ -66,30 +65,22 @@ func TestInvalidJobPath(t *testing.T) { } func TestIsCI(t *testing.T) { - isCI := os.Getenv("CI") - // Set it to the original value - defer os.Setenv("CI", isCI) - - os.Setenv("CI", "true") + t.Setenv("CI", "true") if ic := IsCI(); !ic { t.Fatal("Expected: true, actual: false") } } func TestGetArtifacts(t *testing.T) { - dir := os.Getenv("ARTIFACTS") - // Set it to the original value - defer os.Setenv("ARTIFACTS", dir) - // Test we can read from the env var - os.Setenv("ARTIFACTS", "test") + t.Setenv("ARTIFACTS", "test") v := GetLocalArtifactsDir() if v != "test" { t.Fatalf("Actual artifacts dir: '%s' and Expected: 'test'", v) } // Test we can use the default - os.Setenv("ARTIFACTS", "") + t.Setenv("ARTIFACTS", "") v = GetLocalArtifactsDir() if v != "artifacts" { t.Fatalf("Actual artifacts dir: '%s' and Expected: 'artifacts'", v) diff --git a/webhook/env_test.go b/webhook/env_test.go index c08eaba779..3f3c705ba0 100644 --- a/webhook/env_test.go +++ b/webhook/env_test.go @@ -17,7 +17,6 @@ limitations under the License. package webhook import ( - "os" "testing" ) @@ -70,7 +69,7 @@ func TestPort(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // portEnvKey is unset when testing missing input. if tc.name != testMissingInputName { - os.Setenv(portEnvKey, tc.in) + t.Setenv(portEnvKey, tc.in) } defer func() { @@ -79,7 +78,6 @@ func TestPort(t *testing.T) { } else if r != nil && !tc.wantPanic { t.Error("Got unexpected panic") } - os.Unsetenv(portEnvKey) }() if got := PortFromEnv(testDefaultPort); got != tc.want { @@ -104,7 +102,7 @@ func TestWebhookName(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // webhookNameEnv is unset when testing missing input. if tc.name != testMissingInputName { - os.Setenv(webhookNameEnv, tc.in) + t.Setenv(webhookNameEnv, tc.in) } defer func() { @@ -113,7 +111,6 @@ func TestWebhookName(t *testing.T) { } else if r != nil && !tc.wantPanic { t.Error("Got unexpected panic") } - os.Unsetenv(webhookNameEnv) }() if got := NameFromEnv(); got != tc.want {