diff --git a/pkg/apiserver/certificate/selfsignedcert_provider.go b/pkg/apiserver/certificate/selfsignedcert_provider.go index b7f194808ee..7ae466ee54c 100644 --- a/pkg/apiserver/certificate/selfsignedcert_provider.go +++ b/pkg/apiserver/certificate/selfsignedcert_provider.go @@ -44,11 +44,11 @@ import ( "antrea.io/antrea/pkg/util/env" ) -var ( - loopbackAddresses = []net.IP{net.ParseIP("127.0.0.1"), net.IPv6loopback} - // Declared for unit testing. - generateSelfSignedCertKey = certutil.GenerateSelfSignedCertKey -) +var loopbackAddresses = []net.IP{net.ParseIP("127.0.0.1"), net.IPv6loopback} + +// generateSelfSignedCertKeyFn represents a function which can create a self-signed certificate and +// key for the given host. +type generateSelfSignedCertKeyFn func(host string, alternateIPs []net.IP, alternateDNS []string) ([]byte, []byte, error) type selfSignedCertProvider struct { client kubernetes.Interface @@ -69,23 +69,46 @@ type selfSignedCertProvider struct { cert []byte key []byte verifyOptions *x509.VerifyOptions + + // generateSelfSignedCertKey is the function used to generate self-signed certificates and keys. + // We use a struct member for unit testing. + generateSelfSignedCertKey generateSelfSignedCertKeyFn } var _ dynamiccertificates.CAContentProvider = &selfSignedCertProvider{} var _ dynamiccertificates.ControllerRunner = &selfSignedCertProvider{} -func newSelfSignedCertProvider(client kubernetes.Interface, secureServing *options.SecureServingOptionsWithLoopback, caConfig *CAConfig) (*selfSignedCertProvider, error) { +type providerOption func(p *selfSignedCertProvider) + +func withGenerateSelfSignedCertKeyFn(fn generateSelfSignedCertKeyFn) providerOption { + return func(p *selfSignedCertProvider) { + p.generateSelfSignedCertKey = fn + } +} + +func withClock(clock clockutils.Clock) providerOption { + return func(p *selfSignedCertProvider) { + p.clock = clock + } +} + +func newSelfSignedCertProvider(client kubernetes.Interface, secureServing *options.SecureServingOptionsWithLoopback, caConfig *CAConfig, options ...providerOption) (*selfSignedCertProvider, error) { // Set the CertKey and CertDirectory to generate the certificate files. secureServing.ServerCert.CertDirectory = caConfig.SelfSignedCertDir secureServing.ServerCert.CertKey.CertFile = filepath.Join(caConfig.SelfSignedCertDir, caConfig.PairName+".crt") secureServing.ServerCert.CertKey.KeyFile = filepath.Join(caConfig.SelfSignedCertDir, caConfig.PairName+".key") provider := &selfSignedCertProvider{ - client: client, - secureServing: secureServing, - caConfig: caConfig, - queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "selfSignedCertProvider"), - clock: clockutils.RealClock{}, + client: client, + secureServing: secureServing, + caConfig: caConfig, + queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "selfSignedCertProvider"), + clock: clockutils.RealClock{}, + generateSelfSignedCertKey: certutil.GenerateSelfSignedCertKey, + } + + for _, option := range options { + option(provider) } if caConfig.TLSSecretName != "" { @@ -233,7 +256,7 @@ func (p *selfSignedCertProvider) rotateSelfSignedCertificate() error { } if p.shouldRotateCertificate(cert) { klog.InfoS("Generating self-signed cert") - if cert, key, err = generateSelfSignedCertKey(p.caConfig.ServiceName, loopbackAddresses, GetAntreaServerNames(p.caConfig.ServiceName)); err != nil { + if cert, key, err = p.generateSelfSignedCertKey(p.caConfig.ServiceName, loopbackAddresses, GetAntreaServerNames(p.caConfig.ServiceName)); err != nil { return fmt.Errorf("unable to generate self-signed cert: %v", err) } // If Secret is specified, we should save the new certificate and key to it. diff --git a/pkg/apiserver/certificate/selfsignedcert_provider_test.go b/pkg/apiserver/certificate/selfsignedcert_provider_test.go index 95b214636b3..2123f3ab307 100644 --- a/pkg/apiserver/certificate/selfsignedcert_provider_test.go +++ b/pkg/apiserver/certificate/selfsignedcert_provider_test.go @@ -48,7 +48,7 @@ var ( testOneYearCert3, testOneYearKey3, _ = certutil.GenerateSelfSignedCertKeyWithFixtures("localhost", loopbackAddresses, nil, "") ) -func newTestSelfSignedCertProvider(t *testing.T, client *fakeclientset.Clientset, tlsSecretName string, minValidDuration time.Duration) *selfSignedCertProvider { +func newTestSelfSignedCertProvider(t *testing.T, client *fakeclientset.Clientset, tlsSecretName string, minValidDuration time.Duration, options ...providerOption) *selfSignedCertProvider { secureServing := genericoptions.NewSecureServingOptions().WithLoopback() caConfig := &CAConfig{ TLSSecretName: tlsSecretName, @@ -57,7 +57,7 @@ func newTestSelfSignedCertProvider(t *testing.T, client *fakeclientset.Clientset ServiceName: testServiceName, PairName: testPairName, } - p, err := newSelfSignedCertProvider(client, secureServing, caConfig) + p, err := newSelfSignedCertProvider(client, secureServing, caConfig, options...) require.NoError(t, err) return p } @@ -107,8 +107,7 @@ func TestSelfSignedCertProviderRotate(t *testing.T) { defer cancel() client := fakeclientset.NewSimpleClientset() fakeClock := clocktesting.NewFakeClock(time.Now()) - p := newTestSelfSignedCertProvider(t, client, testSecretName, time.Hour*24*90) - p.clock = fakeClock + p := newTestSelfSignedCertProvider(t, client, testSecretName, time.Hour*24*90, withClock(fakeClock)) certInFile, err := os.ReadFile(p.secureServing.ServerCert.CertKey.CertFile) require.NoError(t, err) keyInFile, _ := os.ReadFile(p.secureServing.ServerCert.CertKey.KeyFile) @@ -161,7 +160,7 @@ func TestSelfSignedCertProviderRotate(t *testing.T) { assert.NotEqual(c, map[string][]byte{ corev1.TLSCertKey: testOneYearCert, corev1.TLSPrivateKeyKey: testOneYearKey, - }, gotSecret.Data, "Secret doesn't match") + }, gotSecret.Data, "Secret should not match") }, 2*time.Second, 50*time.Millisecond) } @@ -264,7 +263,6 @@ func TestSelfSignedCertProviderRun(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - defer mockGenerateSelfSignedCertKey(testOneYearCert2, testOneYearKey2)() ctx, cancel := context.WithCancel(context.Background()) defer cancel() var objs []runtime.Object @@ -272,7 +270,11 @@ func TestSelfSignedCertProviderRun(t *testing.T) { objs = append(objs, tt.existingSecret) } client := fakeclientset.NewSimpleClientset(objs...) - p := newTestSelfSignedCertProvider(t, client, tt.tlsSecretName, tt.minValidDuration) + // mock the generateSelfSignedCertKey fuction + generateSelfSignedCertKey := func(_ string, _ []net.IP, _ []string) ([]byte, []byte, error) { + return testOneYearCert2, testOneYearKey2, nil + } + p := newTestSelfSignedCertProvider(t, client, tt.tlsSecretName, tt.minValidDuration, withGenerateSelfSignedCertKeyFn(generateSelfSignedCertKey)) go p.Run(ctx, 1) if tt.updatedSecret != nil { client.CoreV1().Secrets(tt.updatedSecret.Namespace).Update(ctx, tt.updatedSecret, metav1.UpdateOptions{}) @@ -291,13 +293,3 @@ func TestSelfSignedCertProviderRun(t *testing.T) { }) } } - -func mockGenerateSelfSignedCertKey(cert, key []byte) func() { - originalFn := generateSelfSignedCertKey - generateSelfSignedCertKey = func(_ string, _ []net.IP, _ []string) ([]byte, []byte, error) { - return cert, key, nil - } - return func() { - generateSelfSignedCertKey = originalFn - } -}