Skip to content

Commit

Permalink
verify cluster name of TLS peer certificates (#52130)
Browse files Browse the repository at this point in the history
  • Loading branch information
capnspacehook authored Feb 13, 2025
1 parent f22cc97 commit 31281ba
Show file tree
Hide file tree
Showing 18 changed files with 640 additions and 212 deletions.
14 changes: 14 additions & 0 deletions api/types/authority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,17 @@ func TestRotationZero(t *testing.T) {
require.Equal(t, tt.z, tt.r.IsZero(), tt.d)
}
}

// Test that the spec cluster name name will be set to match the resource name
func TestCheckAndSetDefaults(t *testing.T) {
ca := CertAuthorityV2{
Metadata: Metadata{Name: "caName"},
Spec: CertAuthoritySpecV2{
ClusterName: "clusterName",
Type: HostCA,
},
}
err := ca.CheckAndSetDefaults()
require.NoError(t, err)
require.Equal(t, ca.Metadata.Name, ca.Spec.ClusterName)
}
103 changes: 63 additions & 40 deletions integration/helpers/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,15 @@ type InstanceSecrets struct {
// PrivKey is instance private key
PrivKey []byte `json:"priv"`
// Cert is SSH host certificate
Cert []byte `json:"cert"`
// TLSCACert is the certificate of the trusted certificate authority
TLSCACert []byte `json:"tls_ca_cert"`
// TLSCert is client TLS X509 certificate
TLSCert []byte `json:"tls_cert"`
SSHHostCert []byte `json:"cert"`
// TLSHostCACert is the certificate of the trusted host certificate authority
TLSHostCACert []byte `json:"tls_host_ca_cert"`
// TLSCert is client TLS host X509 certificate
TLSHostCert []byte `json:"tls_host_cert"`
// TLSUserCACert is the certificate of the trusted user certificate authority
TLSUserCACert []byte `json:"tls_user_ca_cert"`
// TLSUserCert is client TLS user X509 certificate
TLSUserCert []byte `json:"tls_user_cert"`
// TunnelAddr is a reverse tunnel listening port, allowing
// other sites to connect to i instance. Set to empty
// string if i instance is not allowing incoming tunnels
Expand Down Expand Up @@ -138,9 +142,7 @@ func (s *InstanceSecrets) GetRoles(t *testing.T) []types.Role {
return roles
}

// GetCAs return an array of CAs stored by the secrets object. In i
// case we always return hard-coded userCA + hostCA (and they share keys
// for simplicity)
// GetCAs return an array of CAs stored by the secrets object
func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) {
hostCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{
Type: types.HostCA,
Expand All @@ -154,7 +156,7 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) {
TLS: []*types.TLSKeyPair{{
Key: s.PrivKey,
KeyType: types.PrivateKeyType_RAW,
Cert: s.TLSCACert,
Cert: s.TLSHostCACert,
}},
},
})
Expand All @@ -174,7 +176,7 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) {
TLS: []*types.TLSKeyPair{{
Key: s.PrivKey,
KeyType: types.PrivateKeyType_RAW,
Cert: s.TLSCACert,
Cert: s.TLSUserCACert,
}},
},
Roles: []string{services.RoleNameForCertAuthority(s.SiteName)},
Expand All @@ -190,7 +192,7 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) {
TLS: []*types.TLSKeyPair{{
Key: s.PrivKey,
KeyType: types.PrivateKeyType_RAW,
Cert: s.TLSCACert,
Cert: s.TLSHostCACert,
}},
},
})
Expand All @@ -205,7 +207,7 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) {
TLS: []*types.TLSKeyPair{{
Key: s.PrivKey,
KeyType: types.PrivateKeyType_RAW,
Cert: s.TLSCACert,
Cert: s.TLSHostCACert,
}},
},
})
Expand Down Expand Up @@ -262,9 +264,9 @@ func (s *InstanceSecrets) AsSlice() []*InstanceSecrets {

func (s *InstanceSecrets) GetIdentity() *state.Identity {
i, err := state.ReadIdentityFromKeyPair(s.PrivKey, &clientproto.Certs{
SSH: s.Cert,
TLS: s.TLSCert,
TLSCACerts: [][]byte{s.TLSCACert},
SSH: s.SSHHostCert,
TLS: s.TLSHostCert,
TLSCACerts: [][]byte{s.TLSHostCACert},
})
fatalIf(err)
return i
Expand Down Expand Up @@ -356,17 +358,11 @@ func NewInstance(t *testing.T, cfg InstanceConfig) *TeleInstance {
key, err := keys.ParsePrivateKey(cfg.Priv)
fatalIf(err)

tlsCACert, err := tlsca.GenerateSelfSignedCAWithSigner(key, pkix.Name{
CommonName: cfg.ClusterName,
Organization: []string{cfg.ClusterName},
}, nil, defaults.CATTL)
fatalIf(err)

sshSigner, err := ssh.NewSignerFromSigner(key)
fatalIf(err)

keygen := keygen.New(context.TODO())
cert, err := keygen.GenerateHostCert(sshca.HostCertificateRequest{
hostCert, err := keygen.GenerateHostCert(sshca.HostCertificateRequest{
CASigner: sshSigner,
PublicHostKey: cfg.Pub,
HostID: cfg.HostID,
Expand All @@ -378,23 +374,48 @@ func NewInstance(t *testing.T, cfg InstanceConfig) *TeleInstance {
},
})
fatalIf(err)
tlsCA, err := tlsca.FromKeys(tlsCACert, cfg.Priv)
fatalIf(err)
cryptoPubKey, err := sshutils.CryptoPublicKey(cfg.Pub)
fatalIf(err)
identity := tlsca.Identity{
Username: fmt.Sprintf("%v.%v", cfg.HostID, cfg.ClusterName),
Groups: []string{string(types.RoleAdmin)},
}

clock := cfg.Clock
if clock == nil {
clock = clockwork.NewRealClock()
}

identity := tlsca.Identity{
Username: fmt.Sprintf("%v.%v", cfg.HostID, cfg.ClusterName),
Groups: []string{string(types.RoleAdmin)},
}
subject, err := identity.Subject()
fatalIf(err)
tlsCert, err := tlsCA.GenerateCertificate(tlsca.CertificateRequest{

tlsCAHostCert, err := tlsca.GenerateSelfSignedCAWithSigner(key, pkix.Name{
CommonName: cfg.ClusterName,
Organization: []string{cfg.ClusterName},
}, nil, defaults.CATTL)
fatalIf(err)
tlsHostCA, err := tlsca.FromKeys(tlsCAHostCert, cfg.Priv)
fatalIf(err)
hostCryptoPubKey, err := sshutils.CryptoPublicKey(cfg.Pub)
fatalIf(err)
tlsHostCert, err := tlsHostCA.GenerateCertificate(tlsca.CertificateRequest{
Clock: clock,
PublicKey: hostCryptoPubKey,
Subject: subject,
NotAfter: clock.Now().UTC().Add(time.Hour * 24),
})
fatalIf(err)

tlsCAUserCert, err := tlsca.GenerateSelfSignedCAWithSigner(key, pkix.Name{
CommonName: cfg.ClusterName,
Organization: []string{cfg.ClusterName},
}, nil, defaults.CATTL)
fatalIf(err)
tlsUserCA, err := tlsca.FromKeys(tlsCAHostCert, cfg.Priv)
fatalIf(err)
userCryptoPubKey, err := sshutils.CryptoPublicKey(cfg.Pub)
fatalIf(err)
tlsUserCert, err := tlsUserCA.GenerateCertificate(tlsca.CertificateRequest{
Clock: clock,
PublicKey: cryptoPubKey,
PublicKey: userCryptoPubKey,
Subject: subject,
NotAfter: clock.Now().UTC().Add(time.Hour * 24),
})
Expand All @@ -409,14 +430,16 @@ func NewInstance(t *testing.T, cfg InstanceConfig) *TeleInstance {
}

secrets := InstanceSecrets{
SiteName: cfg.ClusterName,
PrivKey: cfg.Priv,
PubKey: cfg.Pub,
Cert: cert,
TLSCACert: tlsCACert,
TLSCert: tlsCert,
TunnelAddr: i.ReverseTunnel,
Users: make(map[string]*User),
SiteName: cfg.ClusterName,
PrivKey: cfg.Priv,
PubKey: cfg.Pub,
SSHHostCert: hostCert,
TLSHostCACert: tlsCAHostCert,
TLSHostCert: tlsHostCert,
TLSUserCACert: tlsCAUserCert,
TLSUserCert: tlsUserCert,
TunnelAddr: i.ReverseTunnel,
Users: make(map[string]*User),
}

i.Secrets = secrets
Expand Down
108 changes: 81 additions & 27 deletions lib/auth/authclient/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,59 +39,113 @@ type CAGetter interface {
GetCertAuthorities(ctx context.Context, caType types.CertAuthType, loadKeys bool) ([]types.CertAuthority, error)
}

// ClientCertPool returns trusted x509 certificate authority pool with CAs provided as caTypes.
// HostAndUserCAInfo is a map of CA raw subjects and type info for Host
// and User CAs. The key is the RawSubject of the X.509 certificate authority
// (so it's ASN.1 data, not printable).
type HostAndUserCAInfo = map[string]CATypeInfo

// CATypeInfo indicates whether the CA is a host or user CA, or both.
type CATypeInfo struct {
IsHostCA bool
IsUserCA bool
}

// ClientCertPool returns trusted x509 certificate authority pool with CAs provided as caType.
// In addition, it returns the total length of all subjects added to the cert pool, allowing
// the caller to validate that the pool doesn't exceed the maximum 2-byte length prefix before
// using it.
func ClientCertPool(ctx context.Context, client CAGetter, clusterName string, caTypes ...types.CertAuthType) (*x509.CertPool, int64, error) {
if len(caTypes) == 0 {
return nil, 0, trace.BadParameter("at least one CA type is required")
func ClientCertPool(ctx context.Context, client CAGetter, clusterName string, caType types.CertAuthType) (*x509.CertPool, int64, error) {
authorities, err := getCACerts(ctx, client, clusterName, caType)
if err != nil {
return nil, 0, trace.Wrap(err)
}

pool := x509.NewCertPool()
var authorities []types.CertAuthority
if clusterName == "" {
for _, caType := range caTypes {
cas, err := client.GetCertAuthorities(ctx, caType, false)
if err != nil {
return nil, 0, trace.Wrap(err)
}
authorities = append(authorities, cas...)
}
} else {
for _, caType := range caTypes {
ca, err := client.GetCertAuthority(
ctx,
types.CertAuthID{Type: caType, DomainName: clusterName},
false)
var totalSubjectsLen int64
for _, auth := range authorities {
for _, keyPair := range auth.GetTrustedTLSKeyPairs() {
cert, err := tlsca.ParseCertificatePEM(keyPair.Cert)
if err != nil {
return nil, 0, trace.Wrap(err)
}
pool.AddCert(cert)

authorities = append(authorities, ca)
// Each subject in the list gets a separate 2-byte length prefix.
totalSubjectsLen += 2
totalSubjectsLen += int64(len(cert.RawSubject))
}
}
return pool, totalSubjectsLen, nil
}

// DefaultClientCertPool returns default trusted x509 certificate authority pool.
func DefaultClientCertPool(ctx context.Context, client CAGetter, clusterName string) (*x509.CertPool, HostAndUserCAInfo, int64, error) {
authorities, err := getCACerts(ctx, client, clusterName, types.HostCA, types.UserCA)
if err != nil {
return nil, nil, 0, trace.Wrap(err)
}

pool := x509.NewCertPool()
caInfos := make(HostAndUserCAInfo, len(authorities))
var totalSubjectsLen int64
for _, auth := range authorities {
for _, keyPair := range auth.GetTrustedTLSKeyPairs() {
cert, err := tlsca.ParseCertificatePEM(keyPair.Cert)
if err != nil {
return nil, 0, trace.Wrap(err)
return nil, nil, 0, trace.Wrap(err)
}
pool.AddCert(cert)

caType := auth.GetType()
caInfo := caInfos[string(cert.RawSubject)]
switch caType {
case types.HostCA:
caInfo.IsHostCA = true
case types.UserCA:
caInfo.IsUserCA = true
default:
return nil, nil, 0, trace.BadParameter("unexpected CA type %q", caType)
}
caInfos[string(cert.RawSubject)] = caInfo

// Each subject in the list gets a separate 2-byte length prefix.
totalSubjectsLen += 2
totalSubjectsLen += int64(len(cert.RawSubject))
}
}
return pool, totalSubjectsLen, nil

return pool, caInfos, totalSubjectsLen, nil
}

// DefaultClientCertPool returns default trusted x509 certificate authority pool.
func DefaultClientCertPool(ctx context.Context, client CAGetter, clusterName string) (*x509.CertPool, int64, error) {
return ClientCertPool(ctx, client, clusterName, types.HostCA, types.UserCA)
func getCACerts(ctx context.Context, client CAGetter, clusterName string, caTypes ...types.CertAuthType) ([]types.CertAuthority, error) {
if len(caTypes) == 0 {
return nil, trace.BadParameter("at least one CA type is required")
}

var authorities []types.CertAuthority
if clusterName == "" {
for _, caType := range caTypes {
cas, err := client.GetCertAuthorities(ctx, caType, false)
if err != nil {
return nil, trace.Wrap(err)
}
authorities = append(authorities, cas...)
}
} else {
for _, caType := range caTypes {
ca, err := client.GetCertAuthority(
ctx,
types.CertAuthID{Type: caType, DomainName: clusterName},
false)
if err != nil {
return nil, trace.Wrap(err)
}

authorities = append(authorities, ca)
}
}

return authorities, nil
}

// WithClusterCAs returns a TLS hello callback that returns a copy of the provided
Expand All @@ -110,7 +164,7 @@ func WithClusterCAs(tlsConfig *tls.Config, ap CAGetter, currentClusterName strin
}
}
}
pool, totalSubjectsLen, err := DefaultClientCertPool(info.Context(), ap, clusterName)
pool, _, totalSubjectsLen, err := DefaultClientCertPool(info.Context(), ap, clusterName)
if err != nil {
log.WithError(err).Errorf("Failed to retrieve client pool for %q.", clusterName)
// this falls back to the default config
Expand All @@ -132,7 +186,7 @@ func WithClusterCAs(tlsConfig *tls.Config, ap CAGetter, currentClusterName strin
if totalSubjectsLen >= int64(math.MaxUint16) {
log.Debugf("Number of CAs in client cert pool is too large and cannot be encoded in a TLS handshake; this is due to a large number of trusted clusters; will use only the CA of the current cluster to validate.")

pool, _, err = DefaultClientCertPool(info.Context(), ap, currentClusterName)
pool, _, _, err = DefaultClientCertPool(info.Context(), ap, currentClusterName)
if err != nil {
log.WithError(err).Errorf("Failed to retrieve client pool for %q.", currentClusterName)
// this falls back to the default config
Expand Down
Loading

0 comments on commit 31281ba

Please sign in to comment.