From d129266cc038e90cda7c993edb7cb90afcb11433 Mon Sep 17 00:00:00 2001 From: Antonin Bas Date: Fri, 23 Feb 2024 01:49:03 -0800 Subject: [PATCH] Fix race condition in pkg/apiserver/certificate unit tests (#6004) There was some "interference" between TestSelfSignedCertProviderRotate and TestSelfSignedCertProviderRun. The root cause is that the certutil.GenerateSelfSignedCertKey does not support a custom clock implementation and always calls time.Now() to determine the current time. It then adds a year to the current time to set the expiration time of the certificate. This means that when rotateSelfSignedCertificate() is called as part of TestSelfSignedCertProviderRotate, the new certificate is already expired, and rotateSelfSignedCertificate() will be called immediately a second time. By this time however, TestSelfSignedCertProviderRotate has already exited, and we are already running the next test, TestSelfSignedCertProviderRun. This creates a race condition because the next test will overwrite generateSelfSignedCertKey with a mock version, right as it is called by the second call to rotateSelfSignedCertificate() from the previous test's provider. To avoid this race condition, we make generateSelfSignedCertKey a member of selfSignedCertProvider. Fixes #5977 Signed-off-by: Antonin Bas Co-authored-by: Quan Tian --- .../certificate/selfsignedcert_provider.go | 47 ++++++++++++++----- .../selfsignedcert_provider_test.go | 26 ++++------ 2 files changed, 44 insertions(+), 29 deletions(-) diff --git a/pkg/apiserver/certificate/selfsignedcert_provider.go b/pkg/apiserver/certificate/selfsignedcert_provider.go index b7f194808ee..7ae466ee54c 100644 --- a/pkg/apiserver/certificate/selfsignedcert_provider.go +++ b/pkg/apiserver/certificate/selfsignedcert_provider.go @@ -44,11 +44,11 @@ import ( "antrea.io/antrea/pkg/util/env" ) -var ( - loopbackAddresses = []net.IP{net.ParseIP("127.0.0.1"), net.IPv6loopback} - // Declared for unit testing. - generateSelfSignedCertKey = certutil.GenerateSelfSignedCertKey -) +var loopbackAddresses = []net.IP{net.ParseIP("127.0.0.1"), net.IPv6loopback} + +// generateSelfSignedCertKeyFn represents a function which can create a self-signed certificate and +// key for the given host. +type generateSelfSignedCertKeyFn func(host string, alternateIPs []net.IP, alternateDNS []string) ([]byte, []byte, error) type selfSignedCertProvider struct { client kubernetes.Interface @@ -69,23 +69,46 @@ type selfSignedCertProvider struct { cert []byte key []byte verifyOptions *x509.VerifyOptions + + // generateSelfSignedCertKey is the function used to generate self-signed certificates and keys. + // We use a struct member for unit testing. + generateSelfSignedCertKey generateSelfSignedCertKeyFn } var _ dynamiccertificates.CAContentProvider = &selfSignedCertProvider{} var _ dynamiccertificates.ControllerRunner = &selfSignedCertProvider{} -func newSelfSignedCertProvider(client kubernetes.Interface, secureServing *options.SecureServingOptionsWithLoopback, caConfig *CAConfig) (*selfSignedCertProvider, error) { +type providerOption func(p *selfSignedCertProvider) + +func withGenerateSelfSignedCertKeyFn(fn generateSelfSignedCertKeyFn) providerOption { + return func(p *selfSignedCertProvider) { + p.generateSelfSignedCertKey = fn + } +} + +func withClock(clock clockutils.Clock) providerOption { + return func(p *selfSignedCertProvider) { + p.clock = clock + } +} + +func newSelfSignedCertProvider(client kubernetes.Interface, secureServing *options.SecureServingOptionsWithLoopback, caConfig *CAConfig, options ...providerOption) (*selfSignedCertProvider, error) { // Set the CertKey and CertDirectory to generate the certificate files. secureServing.ServerCert.CertDirectory = caConfig.SelfSignedCertDir secureServing.ServerCert.CertKey.CertFile = filepath.Join(caConfig.SelfSignedCertDir, caConfig.PairName+".crt") secureServing.ServerCert.CertKey.KeyFile = filepath.Join(caConfig.SelfSignedCertDir, caConfig.PairName+".key") provider := &selfSignedCertProvider{ - client: client, - secureServing: secureServing, - caConfig: caConfig, - queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "selfSignedCertProvider"), - clock: clockutils.RealClock{}, + client: client, + secureServing: secureServing, + caConfig: caConfig, + queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "selfSignedCertProvider"), + clock: clockutils.RealClock{}, + generateSelfSignedCertKey: certutil.GenerateSelfSignedCertKey, + } + + for _, option := range options { + option(provider) } if caConfig.TLSSecretName != "" { @@ -233,7 +256,7 @@ func (p *selfSignedCertProvider) rotateSelfSignedCertificate() error { } if p.shouldRotateCertificate(cert) { klog.InfoS("Generating self-signed cert") - if cert, key, err = generateSelfSignedCertKey(p.caConfig.ServiceName, loopbackAddresses, GetAntreaServerNames(p.caConfig.ServiceName)); err != nil { + if cert, key, err = p.generateSelfSignedCertKey(p.caConfig.ServiceName, loopbackAddresses, GetAntreaServerNames(p.caConfig.ServiceName)); err != nil { return fmt.Errorf("unable to generate self-signed cert: %v", err) } // If Secret is specified, we should save the new certificate and key to it. diff --git a/pkg/apiserver/certificate/selfsignedcert_provider_test.go b/pkg/apiserver/certificate/selfsignedcert_provider_test.go index 95b214636b3..2123f3ab307 100644 --- a/pkg/apiserver/certificate/selfsignedcert_provider_test.go +++ b/pkg/apiserver/certificate/selfsignedcert_provider_test.go @@ -48,7 +48,7 @@ var ( testOneYearCert3, testOneYearKey3, _ = certutil.GenerateSelfSignedCertKeyWithFixtures("localhost", loopbackAddresses, nil, "") ) -func newTestSelfSignedCertProvider(t *testing.T, client *fakeclientset.Clientset, tlsSecretName string, minValidDuration time.Duration) *selfSignedCertProvider { +func newTestSelfSignedCertProvider(t *testing.T, client *fakeclientset.Clientset, tlsSecretName string, minValidDuration time.Duration, options ...providerOption) *selfSignedCertProvider { secureServing := genericoptions.NewSecureServingOptions().WithLoopback() caConfig := &CAConfig{ TLSSecretName: tlsSecretName, @@ -57,7 +57,7 @@ func newTestSelfSignedCertProvider(t *testing.T, client *fakeclientset.Clientset ServiceName: testServiceName, PairName: testPairName, } - p, err := newSelfSignedCertProvider(client, secureServing, caConfig) + p, err := newSelfSignedCertProvider(client, secureServing, caConfig, options...) require.NoError(t, err) return p } @@ -107,8 +107,7 @@ func TestSelfSignedCertProviderRotate(t *testing.T) { defer cancel() client := fakeclientset.NewSimpleClientset() fakeClock := clocktesting.NewFakeClock(time.Now()) - p := newTestSelfSignedCertProvider(t, client, testSecretName, time.Hour*24*90) - p.clock = fakeClock + p := newTestSelfSignedCertProvider(t, client, testSecretName, time.Hour*24*90, withClock(fakeClock)) certInFile, err := os.ReadFile(p.secureServing.ServerCert.CertKey.CertFile) require.NoError(t, err) keyInFile, _ := os.ReadFile(p.secureServing.ServerCert.CertKey.KeyFile) @@ -161,7 +160,7 @@ func TestSelfSignedCertProviderRotate(t *testing.T) { assert.NotEqual(c, map[string][]byte{ corev1.TLSCertKey: testOneYearCert, corev1.TLSPrivateKeyKey: testOneYearKey, - }, gotSecret.Data, "Secret doesn't match") + }, gotSecret.Data, "Secret should not match") }, 2*time.Second, 50*time.Millisecond) } @@ -264,7 +263,6 @@ func TestSelfSignedCertProviderRun(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - defer mockGenerateSelfSignedCertKey(testOneYearCert2, testOneYearKey2)() ctx, cancel := context.WithCancel(context.Background()) defer cancel() var objs []runtime.Object @@ -272,7 +270,11 @@ func TestSelfSignedCertProviderRun(t *testing.T) { objs = append(objs, tt.existingSecret) } client := fakeclientset.NewSimpleClientset(objs...) - p := newTestSelfSignedCertProvider(t, client, tt.tlsSecretName, tt.minValidDuration) + // mock the generateSelfSignedCertKey fuction + generateSelfSignedCertKey := func(_ string, _ []net.IP, _ []string) ([]byte, []byte, error) { + return testOneYearCert2, testOneYearKey2, nil + } + p := newTestSelfSignedCertProvider(t, client, tt.tlsSecretName, tt.minValidDuration, withGenerateSelfSignedCertKeyFn(generateSelfSignedCertKey)) go p.Run(ctx, 1) if tt.updatedSecret != nil { client.CoreV1().Secrets(tt.updatedSecret.Namespace).Update(ctx, tt.updatedSecret, metav1.UpdateOptions{}) @@ -291,13 +293,3 @@ func TestSelfSignedCertProviderRun(t *testing.T) { }) } } - -func mockGenerateSelfSignedCertKey(cert, key []byte) func() { - originalFn := generateSelfSignedCertKey - generateSelfSignedCertKey = func(_ string, _ []net.IP, _ []string) ([]byte, []byte, error) { - return cert, key, nil - } - return func() { - generateSelfSignedCertKey = originalFn - } -}