From ae1b9f1684c1bf96427cd685305c8d3f0d6b5359 Mon Sep 17 00:00:00 2001 From: rverdile Date: Wed, 15 Jan 2025 20:53:45 -0500 Subject: [PATCH] add CertUser interface --- pkg/admin_client/client.go | 51 +---------- pkg/candlepin_client/client.go | 21 +---- pkg/config/certificates.go | 150 +++++++++++++++++++++++++++++++++ pkg/config/config.go | 34 -------- 4 files changed, 154 insertions(+), 102 deletions(-) diff --git a/pkg/admin_client/client.go b/pkg/admin_client/client.go index f836db7c4..cf9fae3ff 100644 --- a/pkg/admin_client/client.go +++ b/pkg/admin_client/client.go @@ -4,11 +4,10 @@ import ( "context" "encoding/json" "fmt" - "github.com/content-services/content-sources-backend/pkg/config" "io" "net/http" - "os" - "time" + + "github.com/content-services/content-sources-backend/pkg/config" ) type AdminClient interface { @@ -20,7 +19,7 @@ type adminClientImpl struct { } func NewAdminClient() (AdminClient, error) { - httpClient, err := getHTTPClient() + httpClient, err := config.GetHTTPClient(&config.SubsAsFeatsCertUser{}) if err != nil { return nil, err } @@ -82,47 +81,3 @@ func (ac adminClientImpl) ListFeatures(ctx context.Context) (FeaturesResponse, i return featResp, statusCode, nil } - -func getHTTPClient() (http.Client, error) { - timeout := 90 * time.Second - - var cert []byte - if config.Get().Clients.SubsAsFeatures.ClientCert != "" { - cert = []byte(config.Get().Clients.SubsAsFeatures.ClientCert) - } else if config.Get().Clients.SubsAsFeatures.ClientCertPath != "" { - file, err := os.ReadFile(config.Get().Clients.SubsAsFeatures.ClientCertPath) - if err != nil { - return http.Client{}, err - } - cert = file - } - - var key []byte - if config.Get().Clients.SubsAsFeatures.ClientKey != "" { - key = []byte(config.Get().Clients.SubsAsFeatures.ClientKey) - } else if config.Get().Clients.SubsAsFeatures.ClientKeyPath != "" { - file, err := os.ReadFile(config.Get().Clients.SubsAsFeatures.ClientKeyPath) - if err != nil { - return http.Client{}, err - } - key = file - } - - var caCert []byte - if config.Get().Clients.SubsAsFeatures.CACert != "" { - caCert = []byte(config.Get().Clients.SubsAsFeatures.CACert) - } else if config.Get().Clients.SubsAsFeatures.CACertPath != "" { - file, err := os.ReadFile(config.Get().Clients.SubsAsFeatures.CACertPath) - if err != nil { - return http.Client{}, err - } - caCert = file - } - - transport, err := config.GetTransport(cert, key, caCert, timeout) - if err != nil { - return http.Client{}, fmt.Errorf("error creating http transport: %w", err) - } - - return http.Client{Transport: transport, Timeout: timeout}, nil -} diff --git a/pkg/candlepin_client/client.go b/pkg/candlepin_client/client.go index 0ccba3948..1ce8d8f4e 100644 --- a/pkg/candlepin_client/client.go +++ b/pkg/candlepin_client/client.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "net/http" - "time" caliri "github.com/content-services/caliri/release/v4" "github.com/content-services/content-sources-backend/pkg/config" @@ -32,24 +31,6 @@ func errorWithResponseBody(message string, httpResp *http.Response, err error) e return err } -func getHTTPClient() (http.Client, error) { - timeout := 90 * time.Second - transport := &http.Transport{ResponseHeaderTimeout: timeout} - var err error - - certStr := config.Get().Clients.Candlepin.ClientCert - keyStr := config.Get().Clients.Candlepin.ClientKey - ca := config.Get().Clients.Candlepin.CACert - - if certStr != "" { - transport, err = config.GetTransport([]byte(certStr), []byte(keyStr), []byte(ca), timeout) - if err != nil { - return http.Client{}, fmt.Errorf("could not create http transport: %w", err) - } - } - return http.Client{Transport: transport, Timeout: timeout}, nil -} - func getCorrelationId(ctx context.Context) string { value := ctx.Value(config.ContextRequestIDKey{}) if value != nil { @@ -68,7 +49,7 @@ func NewCandlepinClient() CandlepinClient { } func getCandlepinClient(ctx context.Context) (context.Context, *caliri.APIClient, error) { - httpClient, err := getHTTPClient() + httpClient, err := config.GetHTTPClient(&config.CandlepinCertUser{}) if err != nil { return nil, nil, err } diff --git a/pkg/config/certificates.go b/pkg/config/certificates.go index d912156be..f2f5d9bac 100644 --- a/pkg/config/certificates.go +++ b/pkg/config/certificates.go @@ -1 +1,151 @@ package config + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "os" + "time" +) + +type CertUser interface { + ClientCert() string + ClientKey() string + CACert() string + ClientCertPath() string + ClientKeyPath() string + CACertPath() string +} + +func GetHTTPClient(certUser CertUser) (http.Client, error) { + timeout := 90 * time.Second + + var cert []byte + if certUser.ClientCert() != "" { + cert = []byte(certUser.ClientCert()) + } else if certUser.ClientCertPath() != "" { + file, err := os.ReadFile(certUser.ClientCertPath()) + if err != nil { + return http.Client{}, err + } + cert = file + } + + var key []byte + if certUser.ClientKey() != "" { + key = []byte(certUser.ClientKey()) + } else if certUser.ClientKeyPath() != "" { + file, err := os.ReadFile(certUser.ClientKeyPath()) + if err != nil { + return http.Client{}, err + } + key = file + } + + var caCert []byte + if certUser.CACert() != "" { + caCert = []byte(certUser.CACert()) + } else if certUser.CACertPath() != "" { + file, err := os.ReadFile(certUser.CACertPath()) + if err != nil { + return http.Client{}, err + } + caCert = file + } + + transport, err := GetTransport(cert, key, caCert, timeout) + if err != nil { + return http.Client{}, fmt.Errorf("error creating http transport: %w", err) + } + + return http.Client{Transport: transport, Timeout: timeout}, nil +} + +func GetTransport(certBytes, keyBytes, caCertBytes []byte, timeout time.Duration) (*http.Transport, error) { + transport := &http.Transport{ResponseHeaderTimeout: timeout} + + if certBytes != nil && keyBytes != nil { + cert, err := tls.X509KeyPair(certBytes, keyBytes) + if err != nil { + return transport, fmt.Errorf("could not load keypair: %w", err) + } + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + } + + if caCertBytes != nil { + pool, err := certPool(caCertBytes) + if err != nil { + return transport, err + } + tlsConfig.RootCAs = pool + } + transport.TLSClientConfig = tlsConfig + } + return transport, nil +} + +func certPool(caCert []byte) (*x509.CertPool, error) { + pool := x509.NewCertPool() + ok := pool.AppendCertsFromPEM(caCert) + if !ok { + return nil, fmt.Errorf("could not parse candlepin ca cert") + } + return pool, nil +} + +type SubsAsFeatsCertUser struct { +} + +func (c *SubsAsFeatsCertUser) ClientCert() string { + return Get().Clients.SubsAsFeatures.ClientCert +} + +func (c *SubsAsFeatsCertUser) ClientKey() string { + return Get().Clients.SubsAsFeatures.ClientKey +} + +func (c *SubsAsFeatsCertUser) CACert() string { + return Get().Clients.SubsAsFeatures.CACert +} + +func (c *SubsAsFeatsCertUser) CACertPath() string { + return Get().Clients.SubsAsFeatures.CACertPath +} + +func (c *SubsAsFeatsCertUser) ClientCertPath() string { + return Get().Clients.SubsAsFeatures.ClientCertPath +} + +func (c *SubsAsFeatsCertUser) ClientKeyPath() string { + return Get().Clients.SubsAsFeatures.ClientKeyPath +} + +type CandlepinCertUser struct { +} + +func (c *CandlepinCertUser) ClientCert() string { + return Get().Clients.Candlepin.ClientCert +} + +func (c *CandlepinCertUser) ClientKey() string { + return Get().Clients.Candlepin.ClientKey +} + +func (c *CandlepinCertUser) CACert() string { + return Get().Clients.Candlepin.CACert +} + +func (c *CandlepinCertUser) CACertPath() string { + return "" +} + +func (c *CandlepinCertUser) ClientCertPath() string { + return "" +} + +func (c *CandlepinCertUser) ClientKeyPath() string { + return "" +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 03c10e483..c1f15deb1 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -491,40 +491,6 @@ func ConfigureCertificate() (*tls.Certificate, *string, error) { return &cert, &certString, nil } -func GetTransport(certBytes, keyBytes, caCertBytes []byte, timeout time.Duration) (*http.Transport, error) { - transport := &http.Transport{ResponseHeaderTimeout: timeout} - - if certBytes != nil && keyBytes != nil { - cert, err := tls.X509KeyPair(certBytes, keyBytes) - if err != nil { - return transport, fmt.Errorf("could not load keypair: %w", err) - } - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS12, - } - - if caCertBytes != nil { - pool, err := certPool(caCertBytes) - if err != nil { - return transport, err - } - tlsConfig.RootCAs = pool - } - transport.TLSClientConfig = tlsConfig - } - return transport, nil -} - -func certPool(caCert []byte) (*x509.CertPool, error) { - pool := x509.NewCertPool() - ok := pool.AppendCertsFromPEM(caCert) - if !ok { - return nil, fmt.Errorf("could not parse candlepin ca cert") - } - return pool, nil -} - func CDNCertDaysTillExpiration() (int, error) { if Get().Certs.CdnCertPair == nil { return 0, nil