diff --git a/pkg/queue/certificate/watcher.go b/pkg/queue/certificate/watcher.go
new file mode 100644
index 000000000000..cbd27ea9d863
--- /dev/null
+++ b/pkg/queue/certificate/watcher.go
@@ -0,0 +1,159 @@
+/*
+Copyright 2023 The Knative Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package certificate
+
+import (
+	"crypto/sha256"
+	"crypto/tls"
+	"fmt"
+	"os"
+	"path/filepath"
+	"sync"
+	"time"
+
+	"go.uber.org/zap"
+)
+
+// CertWatcher watches certificate and key files and reloads them if they change on disk.
+// Changes are detected by periodically re-reading both files and comparing SHA-256 checksums.
+type CertWatcher struct {
+	// Paths of the polled files and the checksums of the contents that
+	// produced the currently served certificate.
+	certPath     string
+	certChecksum [sha256.Size]byte
+	keyPath      string
+	keyChecksum  [sha256.Size]byte
+
+	// certificate is the most recently loaded key pair; guarded by mux.
+	certificate *tls.Certificate
+
+	logger *zap.SugaredLogger
+	ticker *time.Ticker
+	stop   chan struct{}
+	mux    sync.RWMutex
+
+	// stopOnce makes Stop safe to call more than once.
+	stopOnce sync.Once
+}
+
+// NewCertWatcher creates a CertWatcher and watches
+// the certificate and key files. It reloads the contents on file change.
+// The files are re-read every reloadInterval, which must be positive.
+// An error is returned if the initial load fails.
+// Make sure to stop the CertWatcher using Stop() upon destroy.
+func NewCertWatcher(certPath, keyPath string, reloadInterval time.Duration, logger *zap.SugaredLogger) (*CertWatcher, error) {
+	// time.NewTicker panics on a non-positive duration; report it as an error instead.
+	if reloadInterval <= 0 {
+		return nil, fmt.Errorf("reload interval must be positive, got %s", reloadInterval)
+	}
+
+	cw := &CertWatcher{
+		certPath: certPath,
+		keyPath:  keyPath,
+		logger:   logger,
+		stop:     make(chan struct{}),
+	}
+
+	certDir := filepath.Dir(cw.certPath)
+	keyDir := filepath.Dir(cw.keyPath)
+
+	cw.logger.Infow("Starting to watch the following directories for changes",
+		zap.String("certDir", certDir), zap.String("keyDir", keyDir))
+
+	// initial load
+	if err := cw.loadCert(); err != nil {
+		return nil, err
+	}
+
+	// Start the ticker only after the initial load succeeded so that
+	// nothing has to be cleaned up on the error path above.
+	cw.ticker = time.NewTicker(reloadInterval)
+	go cw.watch()
+
+	return cw, nil
+}
+
+// Stop shuts down the CertWatcher. Use this with `defer`.
+// Calling Stop more than once is safe.
+func (cw *CertWatcher) Stop() {
+	cw.stopOnce.Do(func() {
+		cw.logger.Info("Stopping file watcher")
+		close(cw.stop)
+		cw.ticker.Stop()
+	})
+}
+
+// GetCertificate returns the server certificate for a client-hello request.
+// It returns the most recently loaded certificate and a nil error.
+func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
+	cw.mux.RLock()
+	defer cw.mux.RUnlock()
+	return cw.certificate, nil
+}
+
+// watch reloads the certificate on every tick until Stop is called.
+func (cw *CertWatcher) watch() {
+	for {
+		select {
+		case <-cw.stop:
+			return
+
+		case <-cw.ticker.C:
+			// On error, we do not want to stop trying
+			if err := cw.loadCert(); err != nil {
+				cw.logger.Error(err)
+			}
+		}
+	}
+}
+
+// loadCert reads the certificate and key files and, if either of them changed
+// since the last successful load, parses them and swaps in the new key pair.
+// On error the previously loaded certificate stays in place.
+func (cw *CertWatcher) loadCert() error {
+	certFile, err := os.ReadFile(cw.certPath)
+	if err != nil {
+		return fmt.Errorf("failed to load certificate file in %s: %w", cw.certPath, err)
+	}
+	keyFile, err := os.ReadFile(cw.keyPath)
+	if err != nil {
+		return fmt.Errorf("failed to load key file in %s: %w", cw.keyPath, err)
+	}
+
+	certChecksum := sha256.Sum256(certFile)
+	keyChecksum := sha256.Sum256(keyFile)
+	if certChecksum == cw.certChecksum && keyChecksum == cw.keyChecksum {
+		return nil
+	}
+
+	// Parse the bytes that were just checksummed instead of reading the files
+	// again, so the stored checksums always describe the served certificate.
+	keyPair, err := tls.X509KeyPair(certFile, keyFile)
+	if err != nil {
+		return fmt.Errorf("failed to parse certificate: %w", err)
+	}
+
+	cw.mux.Lock()
+	cw.certificate = &keyPair
+	cw.certChecksum = certChecksum
+	cw.keyChecksum = keyChecksum
+	cw.mux.Unlock()
+
+	cw.logger.Info("Certificate and/or key have changed on disk and were reloaded.")
+
+	return nil
+}
diff --git a/pkg/queue/certificate/watcher_test.go b/pkg/queue/certificate/watcher_test.go
new file mode 100644
index 000000000000..36af949018d0
--- /dev/null
+++ b/pkg/queue/certificate/watcher_test.go
@@ -0,0 +1,137 @@
+/*
+Copyright 2023 The Knative Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package certificate
+
+import (
+	"crypto/tls"
+	"crypto/x509"
+	"errors"
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+
+	"k8s.io/apimachinery/pkg/util/wait"
+	"knative.dev/networking/pkg/certificates"
+	ktesting "knative.dev/pkg/logging/testing"
+)
+
+const (
+	initialSAN = "initial.knative"
+	updatedSAN = "updated.knative"
+)
+
+// TestCertificateRotation verifies that CertWatcher serves the certificate found
+// on disk at start-up and picks up a replacement written afterwards.
+func TestCertificateRotation(t *testing.T) {
+	// Create initial certificate and key on disk
+	dir := t.TempDir()
+
+	err := createAndSaveCertificate(initialSAN, dir)
+	if err != nil {
+		t.Fatal("failed to create and save initial certificate", err)
+	}
+
+	// Watch the certificate files
+	certPath := filepath.Join(dir, certificates.CertName)
+	keyPath := filepath.Join(dir, certificates.PrivateKeyName)
+	cw, err := NewCertWatcher(certPath, keyPath, 1*time.Second, ktesting.TestLogger(t))
+	if err != nil {
+		t.Fatal("failed to create CertWatcher", err)
+	}
+	// Stop the polling goroutine when the test finishes.
+	defer cw.Stop()
+
+	// CertWatcher should return the expected certificate
+	c, err := cw.GetCertificate(nil)
+	if err != nil {
+		t.Fatal("failed to call GetCertificate on CertWatcher", err)
+	}
+	san, err := getSAN(c)
+	if err != nil {
+		t.Fatal("failed to parse SAN of certificate", err)
+	}
+	if san != initialSAN {
+		t.Errorf("CertWatcher did not return the expected certificate. want: %s, got: %s", initialSAN, san)
+	}
+
+	// Update the certificate and key on disk
+	err = createAndSaveCertificate(updatedSAN, dir)
+	if err != nil {
+		t.Fatal("failed to create and save updated certificate", err)
+	}
+
+	// CertWatcher should return the new certificate
+	// Give CertWatcher some time to update the certificate
+	if err := wait.Poll(1*time.Second, 30*time.Second, func() (bool, error) {
+		c, err := cw.GetCertificate(nil)
+		if err != nil {
+			return false, err
+		}
+
+		san, err = getSAN(c)
+		if err != nil {
+			return false, err
+		}
+
+		// A mismatch only means the reload has not happened yet. Returning an
+		// error here would abort polling on the first check, so keep waiting.
+		return san == updatedSAN, nil
+	}); err != nil {
+		t.Fatalf("CertWatcher did not return the expected certificate. want: %s, got: %s: %v", updatedSAN, san, err)
+	}
+}
+
+// createAndSaveCertificate creates a CA via certificates.CreateCACerts, uses it to
+// create a certificate for san, and writes the certificate and its key into dir.
+func createAndSaveCertificate(san, dir string) error {
+	ca, err := certificates.CreateCACerts(1 * time.Hour)
+	if err != nil {
+		return err
+	}
+
+	caCert, caKey, err := ca.Parse()
+	if err != nil {
+		return err
+	}
+
+	cert, err := certificates.CreateCert(caKey, caCert, 1*time.Hour, san)
+	if err != nil {
+		return err
+	}
+
+	if err := os.WriteFile(filepath.Join(dir, certificates.CertName), cert.CertBytes(), 0644); err != nil {
+		return err
+	}
+
+	return os.WriteFile(filepath.Join(dir, certificates.PrivateKeyName), cert.PrivateKeyBytes(), 0644)
+}
+
+// getSAN returns the first DNS subject alternative name of the first certificate in c.
+func getSAN(c *tls.Certificate) (string, error) {
+	if c == nil || len(c.Certificate) == 0 {
+		return "", errors.New("no certificate available")
+	}
+	parsed, err := x509.ParseCertificate(c.Certificate[0])
+	if err != nil {
+		return "", err
+	}
+	if len(parsed.DNSNames) == 0 {
+		return "", errors.New("certificate has no DNS SAN")
+	}
+	return parsed.DNSNames[0], nil
+}
diff --git a/pkg/queue/sharedmain/main.go b/pkg/queue/sharedmain/main.go
index 0400d416457a..6a38b12d8488 100644
--- a/pkg/queue/sharedmain/main.go
+++ b/pkg/queue/sharedmain/main.go
@@ -18,6 +18,7 @@ package sharedmain
 
 import (
 	"context"
+	"crypto/tls"
 	"errors"
 	"fmt"
 	"net/http"
@@ -29,6 +30,7 @@ import (
 	"go.opencensus.io/plugin/ochttp"
 	"go.uber.org/automaxprocs/maxprocs"
 	"go.uber.org/zap"
+	"knative.dev/serving/pkg/queue/certificate"
 
 	"k8s.io/apimachinery/pkg/types"
 
@@ -245,16 +247,22 @@ func Main(opts ...Option) error {
 		httpServers["profile"] = profiling.NewServer(profiling.NewHandler(logger, true))
 	}
 
-	tlsServers := map[string]*http.Server{
-		"main":  mainServer(":"+env.QueueServingTLSPort, mainHandler),
-		"admin": adminServer(":"+strconv.Itoa(networking.QueueAdminPort), adminHandler),
-	}
+	tlsServers := make(map[string]*http.Server)
+	var certWatcher *certificate.CertWatcher
+	var err error
 
 	if tlsEnabled {
+		tlsServers["main"] = mainServer(":"+env.QueueServingTLSPort, mainHandler)
+		tlsServers["admin"] = adminServer(":"+strconv.Itoa(networking.QueueAdminPort), adminHandler)
+
+		certWatcher, err = certificate.NewCertWatcher(certPath, keyPath, 1*time.Minute, logger)
+		if err != nil {
+			logger.Fatalw("failed to create certWatcher", zap.Error(err))
+		}
+		defer certWatcher.Stop()
+
 		// Drop admin http server since the admin TLS server is listening on the same port
 		delete(httpServers, "admin")
-	} else {
-		tlsServers = map[string]*http.Server{}
 	}
 
 	logger.Info("Starting queue-proxy")
@@ -271,9 +279,13 @@ func Main(opts ...Option) error {
 	}
 	for name, server := range tlsServers {
 		go func(name string, s *http.Server) {
-			// Don't forward ErrServerClosed as that indicates we're already shutting down.
 			logger.Info("Starting tls server ", name, s.Addr)
-			if err := s.ListenAndServeTLS(certPath, keyPath); err != nil && !errors.Is(err, http.ErrServerClosed) {
+			s.TLSConfig = &tls.Config{
+				GetCertificate: certWatcher.GetCertificate,
+				MinVersion:     tls.VersionTLS13,
+			}
+			// Don't forward ErrServerClosed as that indicates we're already shutting down.
+			if err := s.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
 				errCh <- fmt.Errorf("%s server failed to serve: %w", name, err)
 			}
 		}(name, server)
@@ -303,6 +315,7 @@ func Main(opts ...Option) error {
 				logger.Errorw("Failed to shutdown server", zap.String("server", name), zap.Error(err))
 			}
 		}
+		logger.Info("Shutdown complete, exiting...")
 	}
 	return nil