Skip to content

Commit

Permalink
Use t.Setenv instead of os.Setenv in tests (#2454)
Browse files Browse the repository at this point in the history
Go 1.17 introduced a new handy API for setting env vars scoped for
a single test so we can avoid the hard to read set and reset env
loops.

Signed-off-by: Pierangelo Di Pilato <[email protected]>
  • Loading branch information
pierDipi authored Mar 10, 2022
1 parent d2cdc68 commit a850b9e
Show file tree
Hide file tree
Showing 12 changed files with 36 additions and 100 deletions.
7 changes: 2 additions & 5 deletions changeset/commit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package changeset
import (
"errors"
"fmt"
"os"
"testing"
)

Expand Down Expand Up @@ -82,10 +81,8 @@ func TestReadFile(t *testing.T) {
}}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if test.koDataPathEnvDoesNotExist {
os.Clearenv()
} else {
os.Setenv(koDataPathEnvName, test.koDataPath)
if !test.koDataPathEnvDoesNotExist {
t.Setenv(koDataPathEnvName, test.koDataPath)
}

got, err := Get()
Expand Down
14 changes: 5 additions & 9 deletions leaderelection/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ package leaderelection

import (
"fmt"
"os"
"strconv"
"strings"
"testing"
"time"

"github.com/google/go-cmp/cmp"
corev1 "k8s.io/api/core/v1"

"knative.dev/pkg/kmap"
)

Expand Down Expand Up @@ -248,20 +248,16 @@ func TestNewStatefulSetConfig(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
if tc.pod != "" {
os.Setenv(controllerOrdinalEnv, tc.pod)
defer os.Unsetenv(controllerOrdinalEnv)
t.Setenv(controllerOrdinalEnv, tc.pod)
}
if tc.service != "" {
os.Setenv(serviceNameEnv, tc.service)
defer os.Unsetenv(serviceNameEnv)
t.Setenv(serviceNameEnv, tc.service)
}
if tc.port != "" {
os.Setenv(servicePortEnv, tc.port)
defer os.Unsetenv(servicePortEnv)
t.Setenv(servicePortEnv, tc.port)
}
if tc.protocol != "" {
os.Setenv(serviceProtocolEnv, tc.protocol)
defer os.Unsetenv(serviceProtocolEnv)
t.Setenv(serviceProtocolEnv, tc.protocol)
}

ssc, err := newStatefulSetConfig()
Expand Down
21 changes: 4 additions & 17 deletions leaderelection/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package leaderelection
import (
"context"
"fmt"
"os"
"testing"
"time"

Expand Down Expand Up @@ -225,12 +224,8 @@ func TestNewStatefulSetBucketAndSet(t *testing.T) {
"http://as-2.autoscaler.knative-testing.svc.cluster.local:80",
}

os.Setenv(controllerOrdinalEnv, "as-2")
os.Setenv(serviceNameEnv, "autoscaler")
t.Cleanup(func() {
os.Unsetenv(controllerOrdinalEnv)
os.Unsetenv(serviceNameEnv)
})
t.Setenv(controllerOrdinalEnv, "as-2")
t.Setenv(serviceNameEnv, "autoscaler")

_, _, err := NewStatefulSetBucketAndSet(2)
if err == nil {
Expand Down Expand Up @@ -271,16 +266,8 @@ func TestWithStatefulSetBuilder(t *testing.T) {
}
enq := func(reconciler.Bucket, types.NamespacedName) {}

if os.Setenv(controllerOrdinalEnv, "as-2") != nil {
t.Fatalf("Failed to set env var %s=%s", controllerOrdinalEnv, "as-2")
}
if os.Setenv(serviceNameEnv, "autoscaler") != nil {
t.Fatalf("Failed to set env var %s=%s", serviceNameEnv, "autoscaler")
}
t.Cleanup(func() {
os.Unsetenv(controllerOrdinalEnv)
os.Unsetenv(serviceNameEnv)
})
t.Setenv(controllerOrdinalEnv, "as-2")
t.Setenv(serviceNameEnv, "autoscaler")

ctx = WithDynamicLeaderElectorBuilder(ctx, nil, cc)
if !HasLeaderElection(ctx) {
Expand Down
8 changes: 2 additions & 6 deletions logging/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package logging

import (
"fmt"
"os"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -481,14 +480,11 @@ func TestConfigMapName(t *testing.T) {
if got, want := ConfigMapName(), "config-logging"; got != want {
t.Errorf("ConfigMapName = %q, want: %q", got, want)
}
t.Cleanup(func() {
os.Unsetenv(configMapNameEnv)
})
os.Setenv(configMapNameEnv, "")
t.Setenv(configMapNameEnv, "")
if got, want := ConfigMapName(), "config-logging"; got != want {
t.Errorf("ConfigMapName = %q, want: %q", got, want)
}
os.Setenv(configMapNameEnv, "slowly-dying-inside")
t.Setenv(configMapNameEnv, "slowly-dying-inside")
if got, want := ConfigMapName(), "slowly-dying-inside"; got != want {
t.Errorf("ConfigMapName = %q, want: %q", got, want)
}
Expand Down
8 changes: 2 additions & 6 deletions metrics/config_observability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package metrics

import (
"os"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -146,14 +145,11 @@ func TestConfigMapName(t *testing.T) {
if got, want := ConfigMapName(), "config-observability"; got != want {
t.Errorf("ConfigMapName = %q, want: %q", got, want)
}
t.Cleanup(func() {
os.Unsetenv(configMapNameEnv)
})
os.Setenv(configMapNameEnv, "")
t.Setenv(configMapNameEnv, "")
if got, want := ConfigMapName(), "config-observability"; got != want {
t.Errorf("ConfigMapName = %q, want: %q", got, want)
}
os.Setenv(configMapNameEnv, "why-is-living-so-hard?")
t.Setenv(configMapNameEnv, "why-is-living-so-hard?")
if got, want := ConfigMapName(), "why-is-living-so-hard?"; got != want {
t.Errorf("ConfigMapName = %q, want: %q", got, want)
}
Expand Down
7 changes: 2 additions & 5 deletions metrics/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package metrics
import (
"context"
"math"
"os"
"strconv"
"strings"
"testing"
Expand Down Expand Up @@ -380,8 +379,7 @@ func TestGetMetricsConfig_fromEnv(t *testing.T) {

for _, test := range successTests {
t.Run(test.name, func(t *testing.T) {
os.Setenv(test.varName, test.varValue)
defer os.Unsetenv(test.varName)
t.Setenv(test.varName, test.varValue)

mc, err := createMetricsConfig(ctx, test.ops)
if err != nil {
Expand All @@ -395,8 +393,7 @@ func TestGetMetricsConfig_fromEnv(t *testing.T) {

for _, test := range failureTests {
t.Run(test.name, func(t *testing.T) {
os.Setenv(test.varName, test.varValue)
defer os.Unsetenv(test.varName)
t.Setenv(test.varName, test.varValue)

mc, err := createMetricsConfig(ctx, test.ops)
if mc != nil {
Expand Down
7 changes: 2 additions & 5 deletions metrics/prometheus_exporter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ package metrics

import (
"context"
"os"
"testing"
"time"

Expand Down Expand Up @@ -103,12 +102,10 @@ func TestNewPrometheusExporter_fromEnv(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.prometheusPortVarName != "" {
os.Setenv(tc.prometheusPortVarName, tc.prometheusPortVarValue)
defer os.Unsetenv(tc.prometheusPortVarName)
t.Setenv(tc.prometheusPortVarName, tc.prometheusPortVarValue)
}
if tc.prometheusHostVarName != "" {
os.Setenv(tc.prometheusHostVarName, tc.prometheusHostVarValue)
defer os.Unsetenv(tc.prometheusHostVarName)
t.Setenv(tc.prometheusHostVarName, tc.prometheusHostVarValue)
}
mc, err := createMetricsConfig(context.Background(), tc.ops)
if err != nil {
Expand Down
28 changes: 9 additions & 19 deletions network/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package network

import (
"os"
"strings"
"testing"
)
Expand Down Expand Up @@ -75,24 +74,15 @@ options ndots:5
want: defaultDomainName,
}}

domainWas := os.Getenv(clusterDomainEnvKey)
t.Cleanup(func() {
if len(domainWas) > 0 {
_ = os.Setenv(clusterDomainEnvKey, domainWas)
} else {
_ = os.Unsetenv(clusterDomainEnvKey)
}
})

for _, tt := range tests {
if len(tt.env) > 0 {
_ = os.Setenv(clusterDomainEnvKey, tt.env)
} else {
_ = os.Unsetenv(clusterDomainEnvKey)
}
got := getClusterDomainName(strings.NewReader(tt.resolvConf))
if got != tt.want {
t.Errorf("Test %s failed expected: %s but got: %s", tt.name, tt.want, got)
}
t.Run(tt.name, func(t *testing.T) {
if len(tt.env) > 0 {
t.Setenv(clusterDomainEnvKey, tt.env)
}
got := getClusterDomainName(strings.NewReader(tt.resolvConf))
if got != tt.want {
t.Errorf("Test %s failed expected: %s but got: %s", tt.name, tt.want, got)
}
})
}
}
5 changes: 1 addition & 4 deletions test/gke/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ const credEnvKey = "GOOGLE_APPLICATION_CREDENTIALS"
// func NewSDKClient(opts ...option.ClientOption) (SDKOperations, error) {
func TestNewSDKClient(t *testing.T) {
pwd, _ := os.Getwd()
if err := os.Setenv(credEnvKey, filepath.Join(pwd, "fake/credentials.json")); err != nil {
t.Errorf("Failed to set %s to fake/credentials.json: %v", credEnvKey, err)
}
defer os.Unsetenv(credEnvKey)
t.Setenv(credEnvKey, filepath.Join(pwd, "fake/credentials.json"))

datas := []struct {
req option.ClientOption
Expand Down
9 changes: 2 additions & 7 deletions test/prow/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,11 @@ limitations under the License.
package prow

import (
"os"
"testing"
)

func TestGetEnvConfig(t *testing.T) {
isCI := os.Getenv("CI")
// Set it to the original value
defer os.Setenv("CI", isCI)

os.Setenv("CI", "true")
t.Setenv("CI", "true")
ec, err := GetEnvConfig()
t.Log("EnvConfig is:", ec)
if err != nil {
Expand All @@ -36,7 +31,7 @@ func TestGetEnvConfig(t *testing.T) {
t.Fatal("Expected CI to be true but is false")
}

os.Setenv("CI", "false")
t.Setenv("CI", "false")
if _, err = GetEnvConfig(); err == nil {
t.Fatal("Expected an error if called from a non-CI environment but got nil")
}
Expand Down
15 changes: 3 additions & 12 deletions test/prow/prow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ limitations under the License.
package prow

import (
"os"
"testing"
)

Expand Down Expand Up @@ -66,30 +65,22 @@ func TestInvalidJobPath(t *testing.T) {
}

func TestIsCI(t *testing.T) {
isCI := os.Getenv("CI")
// Set it to the original value
defer os.Setenv("CI", isCI)

os.Setenv("CI", "true")
t.Setenv("CI", "true")
if ic := IsCI(); !ic {
t.Fatal("Expected: true, actual: false")
}
}

func TestGetArtifacts(t *testing.T) {
dir := os.Getenv("ARTIFACTS")
// Set it to the original value
defer os.Setenv("ARTIFACTS", dir)

// Test we can read from the env var
os.Setenv("ARTIFACTS", "test")
t.Setenv("ARTIFACTS", "test")
v := GetLocalArtifactsDir()
if v != "test" {
t.Fatalf("Actual artifacts dir: '%s' and Expected: 'test'", v)
}

// Test we can use the default
os.Setenv("ARTIFACTS", "")
t.Setenv("ARTIFACTS", "")
v = GetLocalArtifactsDir()
if v != "artifacts" {
t.Fatalf("Actual artifacts dir: '%s' and Expected: 'artifacts'", v)
Expand Down
7 changes: 2 additions & 5 deletions webhook/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package webhook

import (
"os"
"testing"
)

Expand Down Expand Up @@ -70,7 +69,7 @@ func TestPort(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
// portEnvKey is unset when testing missing input.
if tc.name != testMissingInputName {
os.Setenv(portEnvKey, tc.in)
t.Setenv(portEnvKey, tc.in)
}

defer func() {
Expand All @@ -79,7 +78,6 @@ func TestPort(t *testing.T) {
} else if r != nil && !tc.wantPanic {
t.Error("Got unexpected panic")
}
os.Unsetenv(portEnvKey)
}()

if got := PortFromEnv(testDefaultPort); got != tc.want {
Expand All @@ -104,7 +102,7 @@ func TestWebhookName(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
// webhookNameEnv is unset when testing missing input.
if tc.name != testMissingInputName {
os.Setenv(webhookNameEnv, tc.in)
t.Setenv(webhookNameEnv, tc.in)
}

defer func() {
Expand All @@ -113,7 +111,6 @@ func TestWebhookName(t *testing.T) {
} else if r != nil && !tc.wantPanic {
t.Error("Got unexpected panic")
}
os.Unsetenv(webhookNameEnv)
}()

if got := NameFromEnv(); got != tc.want {
Expand Down

0 comments on commit a850b9e

Please sign in to comment.