Skip to content

Commit

Permalink
Reduce resource consumption when generating Kubernetes certificates (#…
Browse files Browse the repository at this point in the history
…52109)

Closes #52073.

The requested Kubernetes cluster is now cross-referenced against the
KubeServers in the unified resource cache. This results in a
reduction in CPU usage, memory usage, and cert generation latency.
This also removes some helper functions in lib/kube/utils that
were suboptimal and no longer needed.

The client-side changes here shouldn't have any impact, as the
server performs the same check and returns an error equivalent to
the one the client-side code used to return. This will also cut the
time of `tctl auth sign` in half, as both the client and the server
were performing the same expensive CheckKubeCluster operation.
  • Loading branch information
rosstimothy authored Feb 13, 2025
1 parent 047bfb5 commit 8a5107c
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 267 deletions.
23 changes: 21 additions & 2 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ import (
"github.com/gravitational/teleport/lib/gitlab"
"github.com/gravitational/teleport/lib/inventory"
kubetoken "github.com/gravitational/teleport/lib/kube/token"
kubeutils "github.com/gravitational/teleport/lib/kube/utils"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/loginrule"
"github.com/gravitational/teleport/lib/modules"
Expand Down Expand Up @@ -3317,9 +3316,29 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types.
// If the certificate is targeting a trusted Teleport cluster, it is the
// responsibility of the cluster to ensure its existence.
if req.routeToCluster == clusterName && req.kubernetesCluster != "" {
if err := kubeutils.CheckKubeCluster(a.closeCtx, a, req.kubernetesCluster); err != nil {
found, _, err := a.UnifiedResourceCache.IterateUnifiedResources(a.closeCtx, func(rwl types.ResourceWithLabels) (bool, error) {
if rwl.GetKind() != types.KindKubeServer {
return false, nil
}

ks, ok := rwl.(types.KubeServer)
if !ok {
return false, nil
}

return ks.GetCluster().GetName() == req.kubernetesCluster, nil
}, &proto.ListUnifiedResourcesRequest{
Kinds: []string{types.KindKubeServer},
SortBy: types.SortBy{Field: services.SortByName},
Limit: 1,
})
if err != nil {
return nil, trace.Wrap(err)
}

if len(found) == 0 {
return nil, trace.BadParameter("Kubernetes cluster %q is not registered in this Teleport cluster; you can list registered Kubernetes clusters using 'tsh kube ls'", req.kubernetesCluster)
}
}

// See which database names and users this user is allowed to use.
Expand Down
107 changes: 105 additions & 2 deletions lib/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,25 @@ func newTestPack(
}
p.a.SetLockWatcher(lockWatcher)

// set cluster name
err = p.a.SetClusterName(p.clusterName)
urc, err := services.NewUnifiedResourceCache(ctx, services.UnifiedResourceCacheConfig{
Clock: p.a.clock,
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentAuth,
Client: p.a,
},
ResourceGetter: p.a,
})
if err != nil {
return p, trace.Wrap(err)
}

p.a.SetUnifiedResourcesCache(urc)

// set cluster name
if err := p.a.SetClusterName(p.clusterName); err != nil {
return p, trace.Wrap(err)
}

// set static tokens
staticTokens, err := types.NewStaticTokens(types.StaticTokensSpecV2{
StaticTokens: []types.ProvisionTokenV1{},
Expand Down Expand Up @@ -3000,6 +3013,96 @@ func TestGenerateUserCertWithHardwareKeySupport(t *testing.T) {
}
}

// TestGenerateKubernetesUserCert verifies that generating a user certificate
// for a Kubernetes cluster succeeds only when the requested cluster is
// registered in the local Teleport cluster, and that the Kubernetes cluster
// name is not validated when the certificate is routed to a leaf cluster.
func TestGenerateKubernetesUserCert(t *testing.T) {
	ctx := context.Background()
	p, err := newTestPack(ctx, t.TempDir())
	require.NoError(t, err)

	user, _, err := CreateUserAndRole(p.a, "test-user", []string{}, nil)
	require.NoError(t, err)

	rc, err := types.NewRemoteCluster("leaf")
	require.NoError(t, err)
	_, err = p.a.CreateRemoteCluster(ctx, rc)
	require.NoError(t, err)

	kubeCluster, err := types.NewKubernetesClusterV3(types.Metadata{Name: "kube-cluster"}, types.KubernetesClusterSpecV3{})
	require.NoError(t, err)
	kubeServer, err := types.NewKubernetesServerV3FromCluster(kubeCluster, "foo", "1")
	require.NoError(t, err)
	_, err = p.a.UpsertKubernetesServer(ctx, kubeServer)
	require.NoError(t, err)

	// Wait for cache propagation of the kubernetes resources before proceeding with the tests.
	require.EventuallyWithT(t, func(t *assert.CollectT) {
		found, _, err := p.a.UnifiedResourceCache.IterateUnifiedResources(ctx, func(rwl types.ResourceWithLabels) (bool, error) {
			if rwl.GetKind() != types.KindKubeServer {
				return false, nil
			}

			ks, ok := rwl.(types.KubeServer)
			if !ok {
				return false, nil
			}

			return ks.GetCluster().GetName() == kubeCluster.GetName(), nil
		}, &proto.ListUnifiedResourcesRequest{
			Kinds:  []string{types.KindKubeServer},
			SortBy: types.SortBy{Field: services.SortByName},
			Limit:  1,
		})

		assert.NoError(t, err)
		assert.Len(t, found, 1)
	}, 10*time.Second, 100*time.Millisecond)

	accessInfo := services.AccessInfoFromUserState(user)
	accessChecker, err := services.NewAccessChecker(accessInfo, p.clusterName.GetClusterName(), p.a)
	require.NoError(t, err)

	_, sshPubKey, _, tlsPubKey := newSSHAndTLSKeyPairs(t)

	for _, tt := range []struct {
		name              string
		teleportCluster   string
		kubernetesCluster string
		assertErr         require.ErrorAssertionFunc
	}{
		{
			name:              "leaf clusters not validated",
			teleportCluster:   "leaf",
			kubernetesCluster: "foo",
			assertErr:         require.NoError,
		},
		{
			name:              "kubernetes cluster not registered",
			teleportCluster:   p.clusterName.GetClusterName(),
			kubernetesCluster: "foo",
			assertErr:         require.Error,
		},
		{
			name:              "kubernetes cluster registered",
			teleportCluster:   p.clusterName.GetClusterName(),
			kubernetesCluster: kubeCluster.GetName(),
			assertErr:         require.NoError,
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			certReq := certRequest{
				user:              user,
				checker:           accessChecker,
				sshPublicKey:      sshPubKey,
				tlsPublicKey:      tlsPubKey,
				routeToCluster:    tt.teleportCluster,
				kubernetesCluster: tt.kubernetesCluster,
			}

			// Use a subtest-local err instead of writing to the enclosing
			// test's variable from inside the closure.
			_, err := p.a.generateUserCert(ctx, certReq)
			tt.assertErr(t, err)
		})
	}
}

func TestNewWebSession(t *testing.T) {
t.Parallel()
ctx := context.Background()
Expand Down
59 changes: 0 additions & 59 deletions lib/kube/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"context"
"encoding/hex"
"errors"
"slices"
"strings"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -147,48 +146,6 @@ func EncodeClusterName(clusterName string) string {
return "k" + hex.EncodeToString([]byte(clusterName))
}

// KubeServicesPresence is the minimal interface required to fetch the list of
// registered kubernetes servers. It's a subset of services.Presence, and is
// the dependency accepted by KubeClusterNames, KubeClusters and
// CheckKubeCluster.
type KubeServicesPresence interface {
// GetKubernetesServers returns a list of registered kubernetes servers,
// or an error if the list cannot be retrieved.
GetKubernetesServers(context.Context) ([]types.KubeServer, error)
}

// KubeClusterNames fetches the kubernetes servers registered in p and returns
// the sorted, de-duplicated list of kubernetes cluster names they serve.
// Any error from fetching the servers is returned wrapped.
//
// DELETE IN 11.0.0, replaced by ListKubeClustersWithFilters
func KubeClusterNames(ctx context.Context, p KubeServicesPresence) ([]string, error) {
	servers, err := p.GetKubernetesServers(ctx)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	names := extractAndSortKubeClusterNames(servers)
	return names, nil
}

// extractAndSortKubeClusterNames returns the names of the clusters produced by
// extractAndSortKubeClusters for the given servers, preserving their order.
func extractAndSortKubeClusterNames(kubeServers []types.KubeServer) []string {
	clusters := extractAndSortKubeClusters(kubeServers)

	// Pre-size to the final length so appending never reallocates.
	names := make([]string, 0, len(clusters))
	for _, cluster := range clusters {
		names = append(names, cluster.GetName())
	}

	return names
}

// KubeClusters fetches the kubernetes servers registered in p and returns the
// sorted list of unique kubernetes clusters they serve. Any error from
// fetching the servers is returned wrapped.
//
// DELETE IN 11.0.0, replaced by ListKubeClustersWithFilters
func KubeClusters(ctx context.Context, p KubeServicesPresence) ([]types.KubeCluster, error) {
	servers, err := p.GetKubernetesServers(ctx)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	clusters := extractAndSortKubeClusters(servers)
	return clusters, nil
}

// ListKubeClustersWithFilters returns a sorted list of unique kubernetes clusters
// registered in p.
func ListKubeClustersWithFilters(ctx context.Context, p client.GetResourcesClient, req proto.ListResourcesRequest) ([]types.KubeCluster, error) {
Expand Down Expand Up @@ -244,19 +201,3 @@ func GetKubeAgentVersion(ctx context.Context, pinger Pinger, clusterFeatures pro

return strings.TrimPrefix(agentVersion, "v"), nil
}

// CheckKubeCluster reports whether kubeClusterName is registered with this
// Teleport cluster. It returns nil when the name is found among the clusters
// listed via p, and a BadParameter error when the name is empty or not
// registered. A failure to list the clusters is returned wrapped.
func CheckKubeCluster(ctx context.Context, p KubeServicesPresence, kubeClusterName string) error {
	if kubeClusterName == "" {
		return trace.BadParameter("kube cluster name should not be empty.")
	}

	registered, err := KubeClusterNames(ctx, p)
	if err != nil {
		return trace.Wrap(err, "failed to get list of available Kubernetes clusters.")
	}

	if slices.Contains(registered, kubeClusterName) {
		return nil
	}

	return trace.BadParameter("Kubernetes cluster %q is not registered in this Teleport cluster; you can list registered Kubernetes clusters using 'tsh kube ls'", kubeClusterName)
}
82 changes: 0 additions & 82 deletions lib/kube/utils/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,74 +26,9 @@ import (
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/automaticupgrades"
)

// TestCheckKubeCluster exercises CheckKubeCluster against a mocked presence
// service: a registered name must pass, while an unknown name, an empty set of
// registered clusters, and an empty requested name must all be rejected.
func TestCheckKubeCluster(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	// Four servers, each serving a distinct kubernetes cluster.
	registered := []types.KubeServer{
		kubeServer(t, "k8s-1", "server1", "uuuid"),
		kubeServer(t, "k8s-2", "server1", "uuuid"),
		kubeServer(t, "k8s-3", "server1", "uuuid"),
		kubeServer(t, "k8s-4", "server1", "uuuid"),
	}

	for _, tc := range []struct {
		desc        string
		services    []types.KubeServer
		kubeCluster string
		assertErr   require.ErrorAssertionFunc
	}{
		{
			desc:        "valid cluster name",
			services:    registered,
			kubeCluster: "k8s-4",
			assertErr:   require.NoError,
		},
		{
			desc:        "invalid cluster name",
			services:    registered,
			kubeCluster: "k8s-5",
			assertErr:   require.Error,
		},
		{
			desc:        "no registered clusters",
			services:    []types.KubeServer{},
			kubeCluster: "k8s-1",
			assertErr:   require.Error,
		},
		{
			desc:        "empty cluster provided",
			services:    registered,
			kubeCluster: "",
			assertErr:   require.Error,
		},
	} {
		t.Run(tc.desc, func(t *testing.T) {
			presence := mockKubeServicesPresence(tc.services)
			tc.assertErr(t, CheckKubeCluster(ctx, presence, tc.kubeCluster))
		})
	}
}

// mockKubeServicesPresence is a test double backed by a fixed slice of
// kubernetes servers.
type mockKubeServicesPresence []types.KubeServer

// GetKubernetesServers returns the receiver's servers as-is and never fails;
// the context argument is ignored.
func (p mockKubeServicesPresence) GetKubernetesServers(context.Context) ([]types.KubeServer, error) {
return p, nil
}

// kubeServer is a test helper that builds a types.KubeServer serving a
// kubernetes cluster named kubeCluster, registered under the given hostname
// and hostID. The cluster is created with an empty spec. Any construction
// error fails the test immediately.
func kubeServer(t *testing.T, kubeCluster, hostname, hostID string) types.KubeServer {
	// Mark as a helper so failures are attributed to the calling test line.
	t.Helper()

	cluster, err := types.NewKubernetesClusterV3(types.Metadata{Name: kubeCluster}, types.KubernetesClusterSpecV3{})
	require.NoError(t, err)
	server, err := types.NewKubernetesServerV3FromCluster(cluster, hostname, hostID)
	require.NoError(t, err)
	return server
}

func TestGetAgentVersion(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -161,20 +96,3 @@ type pinger struct {
// Ping delegates to the pinger's pingFn callback, passing ctx through and
// returning its response and error unchanged.
func (p *pinger) Ping(ctx context.Context) (proto.PingResponse, error) {
return p.pingFn(ctx)
}

func TestExtractAndSortKubeClusterNames(t *testing.T) {
t.Parallel()

server1 := kubeServer(t, "watermelon", "server1", "uuuid")

server2 := kubeServer(t, "watermelon", "server1", "uuuid")

server3 := kubeServer(t, "banana", "server2", "uuuid2")

server4 := kubeServer(t, "apple", "server2", "uuuid2")

server5 := kubeServer(t, "pear", "server2", "uuuid2")

names := extractAndSortKubeClusterNames(types.KubeServers{server1, server2, server3, server4, server5})
require.Equal(t, []string{"apple", "banana", "pear", "watermelon"}, names)
}
20 changes: 2 additions & 18 deletions tool/tctl/common/auth_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ import (
"github.com/gravitational/teleport/lib/client/identityfile"
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/defaults"
kubeutils "github.com/gravitational/teleport/lib/kube/utils"
"github.com/gravitational/teleport/lib/service/servicecfg"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -330,7 +329,6 @@ func (a *AuthCommand) GenerateKeys(ctx context.Context, clusterAPI authCommandCl
// certificateSigner is an interface for the methods used by GenerateAndSignKeys
// to sign certificates using the Auth Server.
type certificateSigner interface {
kubeutils.KubeServicesPresence
GenerateDatabaseCert(context.Context, *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error)
GenerateUserCerts(ctx context.Context, req proto.UserCertsRequest) (*proto.Certs, error)
GenerateWindowsDesktopCert(context.Context, *proto.WindowsDesktopCertRequest) (*proto.WindowsDesktopCertResponse, error)
Expand Down Expand Up @@ -944,7 +942,7 @@ func (a *AuthCommand) generateUserKeys(ctx context.Context, clusterAPI certifica
}
keyRing.ClusterName = a.leafCluster

if err := a.checkKubeCluster(ctx, clusterAPI); err != nil {
if err := a.checkKubeCluster(); err != nil {
return trace.Wrap(err)
}

Expand Down Expand Up @@ -1105,7 +1103,7 @@ func (a *AuthCommand) checkLeafCluster(clusterAPI certificateSigner) error {
return trace.BadParameter("couldn't find leaf cluster named %q", a.leafCluster)
}

func (a *AuthCommand) checkKubeCluster(ctx context.Context, clusterAPI certificateSigner) error {
func (a *AuthCommand) checkKubeCluster() error {
if a.kubeCluster == "" {
return nil
}
Expand All @@ -1118,20 +1116,6 @@ func (a *AuthCommand) checkKubeCluster(ctx context.Context, clusterAPI certifica
return nil
}

localCluster, err := clusterAPI.GetClusterName()
if err != nil {
return trace.Wrap(err)
}
if localCluster.GetClusterName() != a.leafCluster {
// Skip validation on remote clusters, since we don't know their
// registered kube clusters.
return nil
}

if err := kubeutils.CheckKubeCluster(ctx, clusterAPI, a.kubeCluster); err != nil {
return trace.Wrap(err)
}

return nil
}

Expand Down
Loading

0 comments on commit 8a5107c

Please sign in to comment.