From 1a91d6369a4e71087290cf3df03d13008116068e Mon Sep 17 00:00:00 2001 From: Markus Wiegand Date: Mon, 14 Aug 2023 18:20:14 +0200 Subject: [PATCH] rework certificate caching Cache only the leaf certificate and reduce the verification of cached certificates to period check. --- middleware/jwt/cert.go | 135 +++++++++++++++++------------------------ middleware/jwt/jwt.go | 29 ++++++--- 2 files changed, 74 insertions(+), 90 deletions(-) diff --git a/middleware/jwt/cert.go b/middleware/jwt/cert.go index 947ab59..68be6e3 100644 --- a/middleware/jwt/cert.go +++ b/middleware/jwt/cert.go @@ -9,7 +9,6 @@ import ( "fmt" "log" "os" - "slices" "sync" "github.com/golang-jwt/jwt/v5" @@ -17,7 +16,7 @@ import ( var store = certStore{ roots: x509.NewCertPool(), - certs: make(map[string]*certChain), + certs: make(map[string]*x509.Certificate), } func init() { @@ -39,31 +38,35 @@ func init() { } } } +} - if path := os.Getenv("JWT_CERTS"); path != "" { - certs, err := parseCertsFromPEM(path) - if err != nil { - log.Printf("failed to parse certificates: %v", err) - os.Exit(2) - } +type certStore struct { + roots *x509.CertPool + certs map[string]*x509.Certificate + sync.RWMutex +} - if len(certs) == 0 { - log.Printf("no certificates") - os.Exit(2) - } +func (s *certStore) get(fingerprint string) (*x509.Certificate, bool) { + s.RLock() + defer s.RUnlock() - lastIdx := len(certs) - 1 + chain, ok := s.certs[fingerprint] - intermediates := x509.NewCertPool() - for _, cert := range certs[:lastIdx] { - intermediates.AddCert(cert) - } + return chain, ok +} - store.add(&certChain{ - intermediates: intermediates, - leaf: certs[lastIdx], - }) - } +func (s *certStore) add(cert *x509.Certificate) { + s.Lock() + defer s.Unlock() + + s.certs[fingerprintBase64(cert)] = cert +} + +func (s *certStore) remove(fingerprint string) { + s.Lock() + defer s.Unlock() + + delete(s.certs, fingerprint) } func parseCertsFromPEM(path string) ([]*x509.Certificate, error) { @@ -80,7 +83,7 @@ func parseCertsFromPEM(path string) ([]*x509.Certificate, error) { break } - if block.Type != "CERTIFICATE" { + if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { continue } @@ -92,44 +95,20 @@ func parseCertsFromPEM(path string) ([]*x509.Certificate, error) { certs = append(certs, cert) } - return certs, nil -} - -type certStore struct { - roots *x509.CertPool - certs map[string]*certChain - sync.RWMutex -} - -func (s *certStore) add(chain *certChain) error { - err := chain.verify(s.roots) - if err != nil { - return err + lastIdx := len(certs) - 1 + for i := lastIdx; i >= 0; i-- { + c := certs[i] + if i == lastIdx && c.IsCA || i < lastIdx && !c.IsCA { + return nil, errors.New("invalid certificate chain") + } } - s.Lock() - defer s.Unlock() - - s.certs[fingerprintBase64(chain.leaf)] = chain - - return nil -} - -func (s *certStore) get(fingerprint string) (*certChain, bool) { - s.RLock() - defer s.RUnlock() - - chain, ok := s.certs[fingerprint] - - return chain, ok -} - -type certChain struct { - intermediates *x509.CertPool - leaf *x509.Certificate + return certs, nil } -func parseTokenCerts(token *jwt.Token) (*certChain, error) { +// Parses the certificate chain from the JWT x5c header defined in RFC7515. +// The returned slice starts with the leaf certificate followed by the intermediates. +func parseCertsFromToken(token *jwt.Token) ([]*x509.Certificate, error) { x5c, ok := token.Header["x5c"].([]interface{}) if !ok { return nil, errors.New("invalid x5c header") @@ -150,44 +129,40 @@ func parseTokenCerts(token *jwt.Token) (*certChain, error) { return nil, fmt.Errorf("failed to parse certificate %d: %v", i, err) } - certs[i] = c - } - - slices.Reverse(certs) - - lastIdx := len(certs) - 1 + if i == 0 && c.IsCA || i > 0 && !c.IsCA { + return nil, errors.New("invalid certificate chain") + } - intermediates := x509.NewCertPool() - for _, cert := range certs[:lastIdx] { - intermediates.AddCert(cert) + certs[i] = c } - return &certChain{ - intermediates: intermediates, - leaf: certs[lastIdx], - }, nil -} - -func (c *certChain) publicKey() interface{} { - return c.leaf.PublicKey + return certs, nil } -func (c *certChain) verify(roots *x509.CertPool) error { +func verifyCert(leaf *x509.Certificate, intermediates []*x509.Certificate, roots *x509.CertPool) error { opts := x509.VerifyOptions{ Roots: roots, - Intermediates: c.intermediates, + Intermediates: x509.NewCertPool(), KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, } - if _, err := c.leaf.Verify(opts); err != nil { - return err + for _, cert := range intermediates { + if !cert.IsCA { + return errors.New("invalid intermediate certificate") + } + opts.Intermediates.AddCert(cert) + } + + _, err := leaf.Verify(opts) + if err != nil { + return fmt.Errorf("failed to verify certificate: %w", err) } - if c.leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 { + if leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 { return errors.New("invalid key usage") } - return nil + return err } func fingerprintBase64(cert *x509.Certificate) string { diff --git a/middleware/jwt/jwt.go b/middleware/jwt/jwt.go index 035f922..9af3043 100644 --- a/middleware/jwt/jwt.go +++ b/middleware/jwt/jwt.go @@ -1,6 +1,7 @@ package jwt import ( + "crypto/x509" "errors" "fmt" "net/http" @@ -227,26 +228,34 @@ func keyFunc(token *jwt.Token) (interface{}, error) { return nil, errors.New("invalid fingerprint") } - chain, ok := store.get(fingerprint) + cert, ok := store.get(fingerprint) if ok { - if err := chain.verify(store.roots); err != nil { - return nil, err + if time.Now().After(cert.NotAfter) { + store.remove(fingerprint) + return nil, errors.New("certificate expired") } } else { - var err error - - chain, err = parseTokenCerts(token) + certs, err := parseCertsFromToken(token) if err != nil { return nil, fmt.Errorf("failed to parse certificate chain: %w", err) } - // The add method verifies the chain - if err := store.add(chain); err != nil { - return nil, fmt.Errorf("failed to add certificate chain: %w", err) + leaf := certs[0] + + var intermediates []*x509.Certificate + if len(certs) > 1 { + intermediates = certs[1:] } + + if err := verifyCert(leaf, intermediates, store.roots); err != nil { + return nil, fmt.Errorf("failed to verify certificate: %w", err) + } + + store.add(leaf) + cert = leaf } - return chain.publicKey(), nil + return cert.PublicKey, nil } // AuhtorizationHandler returns a JWT authorization handler