Skip to content

Commit

Permalink
Merge pull request #478 from uselagoon/limits
Browse files Browse the repository at this point in the history
Add limits to logs sessions
  • Loading branch information
smlx authored Oct 15, 2024
2 parents a908924 + 243b039 commit 7c93f5c
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 16 deletions.
19 changes: 11 additions & 8 deletions cmd/ssh-portal/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net"
"os/signal"
"syscall"
"time"

"github.com/nats-io/nats.go"
"github.com/uselagoon/ssh-portal/internal/k8s"
Expand All @@ -21,13 +22,15 @@ const (

// ServeCmd represents the serve command.
type ServeCmd struct {
NATSServer string `kong:"required,env='NATS_URL',help='NATS server URL (nats://... or tls://...)'"`
SSHServerPort uint `kong:"default='2222',env='SSH_SERVER_PORT',help='Port the SSH server will listen on for SSH client connections'"`
HostKeyECDSA string `kong:"env='HOST_KEY_ECDSA',help='PEM encoded ECDSA host key'"`
HostKeyED25519 string `kong:"env='HOST_KEY_ED25519',help='PEM encoded Ed25519 host key'"`
HostKeyRSA string `kong:"env='HOST_KEY_RSA',help='PEM encoded RSA host key'"`
LogAccessEnabled bool `kong:"env='LOG_ACCESS_ENABLED',help='Allow any user who can SSH into a pod to also access its logs'"`
Banner string `kong:"env='BANNER',help='Text sent to remote users before authentication'"`
NATSServer string `kong:"required,env='NATS_URL',help='NATS server URL (nats://... or tls://...)'"`
SSHServerPort uint `kong:"default='2222',env='SSH_SERVER_PORT',help='Port the SSH server will listen on for SSH client connections'"`
HostKeyECDSA string `kong:"env='HOST_KEY_ECDSA',help='PEM encoded ECDSA host key'"`
HostKeyED25519 string `kong:"env='HOST_KEY_ED25519',help='PEM encoded Ed25519 host key'"`
HostKeyRSA string `kong:"env='HOST_KEY_RSA',help='PEM encoded RSA host key'"`
LogAccessEnabled bool `kong:"env='LOG_ACCESS_ENABLED',help='Allow any user who can SSH into a pod to also access its logs'"`
Banner string `kong:"env='BANNER',help='Text sent to remote users before authentication'"`
ConcurrentLogLimit uint `kong:"default='32',env='CONCURRENT_LOG_LIMIT',help='Maximum number of concurrent log sessions'"`
LogTimeLimit time.Duration `kong:"default='4h',env='LOG_TIME_LIMIT',help='Maximum lifetime of each logs session'"`
}

// Run the serve command to handle SSH connection requests.
Expand Down Expand Up @@ -60,7 +63,7 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
}
defer l.Close()
// get kubernetes client
c, err := k8s.NewClient()
c, err := k8s.NewClient(cmd.ConcurrentLogLimit, cmd.LogTimeLimit)
if err != nil {
return fmt.Errorf("couldn't create k8s client: %v", err)
}
Expand Down
11 changes: 8 additions & 3 deletions internal/k8s/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sync"
"time"

"golang.org/x/sync/semaphore"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
)
Expand All @@ -24,10 +25,12 @@ type Client struct {
config *rest.Config
clientset kubernetes.Interface
logStreamIDs sync.Map
logSem *semaphore.Weighted
logTimeLimit time.Duration
}

// NewClient creates a new kubernetes API client.
func NewClient() (*Client, error) {
func NewClient(concurrentLogLimit uint, logTimeLimit time.Duration) (*Client, error) {
// create the in-cluster config
config, err := rest.InClusterConfig()
if err != nil {
Expand All @@ -39,7 +42,9 @@ func NewClient() (*Client, error) {
return nil, err
}
return &Client{
config: config,
clientset: clientset,
config: config,
clientset: clientset,
logSem: semaphore.NewWeighted(int64(concurrentLogLimit)),
logTimeLimit: logTimeLimit,
}, nil
}
2 changes: 1 addition & 1 deletion internal/k8s/exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func TestIdledDeployLabels(t *testing.T) {
t.Run(name, func(tt *testing.T) {
// create fake Kubernetes client with test deploys
c := &Client{
clientset: fake.NewSimpleClientset(tc.deploys),
clientset: fake.NewClientset(tc.deploys),
}
deploys, err := c.idledDeploys(context.Background(), testNS)
assert.NoError(tt, err, name)
Expand Down
38 changes: 34 additions & 4 deletions internal/k8s/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package k8s
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"sync"
Expand All @@ -27,6 +28,13 @@ var (
// limitBytes defines the maximum number of bytes of logs returned from a
// single container
limitBytes int64 = 1 * 1024 * 1024 // 1MiB

// ErrConcurrentLogLimit indicates that the maximum number of concurrent log
// sessions has been reached.
ErrConcurrentLogLimit = errors.New("reached concurrent log limit")
// ErrLogTimeLimit indicates that the maximum log session time has been
// exceeded.
ErrLogTimeLimit = errors.New("exceeded maximum log session time")
)

// linewiseCopy reads strings separated by \n from logStream, and writes them
Expand Down Expand Up @@ -202,11 +210,27 @@ func (c *Client) newPodInformer(ctx context.Context,
// follow=false.
// 2. ctx is cancelled (signalling that the SSH channel was closed).
// 3. An unrecoverable error occurs.
func (c *Client) Logs(ctx context.Context,
namespace, deployment, container string, follow bool, tailLines int64,
stdio io.ReadWriter) error {
//
// If a call to Logs would exceed the configured maximum number of concurrent
// log sessions, ErrConcurrentLogLimit is returned.
//
// If the configured log time limit is exceeded, ErrLogTimeLimit is returned.
func (c *Client) Logs(
ctx context.Context,
namespace,
deployment,
container string,
follow bool,
tailLines int64,
stdio io.ReadWriter,
) error {
// Exit with an error if we have hit the concurrent log limit.
if !c.logSem.TryAcquire(1) {
return ErrConcurrentLogLimit
}
defer c.logSem.Release(1)
// Wrap the context so we can cancel subroutines of this function on error, and so that the session is cut off once the configured log time limit elapses.
childCtx, cancel := context.WithCancel(ctx)
childCtx, cancel := context.WithTimeout(ctx, c.logTimeLimit)
defer cancel()
// Generate a requestID value to uniquely distinguish between multiple calls
// to this function. This requestID is used in readLogs() to distinguish
Expand Down Expand Up @@ -253,6 +277,9 @@ func (c *Client) Logs(ctx context.Context,
return fmt.Errorf("couldn't construct new pod informer: %v", err)
}
podInformer.Run(childCtx.Done())
if errors.Is(childCtx.Err(), context.DeadlineExceeded) {
return ErrLogTimeLimit
}
return nil
})
} else {
Expand Down Expand Up @@ -280,6 +307,9 @@ func (c *Client) Logs(ctx context.Context,
if readLogsErr != nil {
return fmt.Errorf("couldn't read logs on existing pods: %v", readLogsErr)
}
if errors.Is(childCtx.Err(), context.DeadlineExceeded) {
return ErrLogTimeLimit
}
return nil
})
}
Expand Down
105 changes: 105 additions & 0 deletions internal/k8s/logs_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
package k8s

import (
"bytes"
"context"
"io"
"strings"
"testing"
"time"

"github.com/alecthomas/assert/v2"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/fake"
)

func TestLinewiseCopy(t *testing.T) {
Expand Down Expand Up @@ -44,3 +51,101 @@ func TestLinewiseCopy(t *testing.T) {
})
}
}

// TestLogs exercises Client.Logs against a fake clientset that holds a single
// deployment and one pod matching the deployment's label selector.
//
// Each test case starts sessionCount concurrent calls to Logs on a Client
// configured with a concurrent log limit of 2 and a log time limit of one
// second, then checks the first error reported by the group of calls.
func TestLogs(t *testing.T) {
	testNS := "testns"
	testDeploy := "foo"
	// testContainer is the container name passed to Logs. It is also the name
	// recorded in the pod's ContainerStatuses below.
	testContainer := "bar"
	deploys := &appsv1.DeploymentList{
		Items: []appsv1.Deployment{
			{
				ObjectMeta: metav1.ObjectMeta{
					Name:      testDeploy,
					Namespace: testNS,
					Labels: map[string]string{
						"idling.lagoon.sh/watch": "true",
					},
				},
				Spec: appsv1.DeploymentSpec{
					Selector: &metav1.LabelSelector{
						MatchLabels: map[string]string{
							"app.kubernetes.io/name": "foo-app",
						},
					},
				},
			},
		},
	}
	pods := &corev1.PodList{
		Items: []corev1.Pod{
			{
				ObjectMeta: metav1.ObjectMeta{
					Name:      "foo-123xyz",
					Namespace: testNS,
					Labels: map[string]string{
						"app.kubernetes.io/name": "foo-app",
					},
				},
				Status: corev1.PodStatus{
					ContainerStatuses: []corev1.ContainerStatus{
						{
							Name: testContainer,
						},
					},
				},
			},
		},
	}
	var testCases = map[string]struct {
		follow       bool
		sessionCount uint
		// expectedError is the sentinel error the group of sessions should
		// report. A nil value means every session is expected to succeed.
		expectedError error
	}{
		"no follow": {
			sessionCount: 1,
		},
		"no follow two sessions": {
			sessionCount: 2,
		},
		// NOTE(review): this case relies on all three sessions overlapping in
		// time. If one call returns before the third starts, the limit is never
		// hit and the case fails spuriously - confirm this is stable in CI.
		"no follow session count limit exceeded": {
			sessionCount:  3,
			expectedError: ErrConcurrentLogLimit,
		},
		"follow session timeout": {
			follow:        true,
			sessionCount:  1,
			expectedError: ErrLogTimeLimit,
		},
	}
	for name, tc := range testCases {
		t.Run(name, func(tt *testing.T) {
			// create fake Kubernetes client with test deploys
			c := &Client{
				clientset:    fake.NewClientset(deploys, pods),
				logSem:       semaphore.NewWeighted(int64(2)),
				logTimeLimit: time.Second,
			}
			// Give every session its own buffer: bytes.Buffer is not safe for
			// concurrent use, so sharing one across sessions would be a data
			// race.
			bufs := make([]bytes.Buffer, tc.sessionCount)
			var eg errgroup.Group
			ctx := context.Background()
			// execute test
			for i := range tc.sessionCount {
				eg.Go(func() error {
					return c.Logs(
						ctx, testNS, testDeploy, testContainer, tc.follow, 10, &bufs[i])
				})
			}
			// check results
			err := eg.Wait()
			if tc.expectedError != nil {
				// IsError matches with errors.Is, so the check keeps working if
				// the sentinel is ever wrapped, and it sidesteps the
				// (expected, actual) argument order of assert.Equal.
				assert.IsError(tt, err, tc.expectedError, name)
				return
			}
			assert.NoError(tt, err, name)
			for i := range bufs {
				tt.Log(bufs[i].String())
			}
		})
	}
}
14 changes: 14 additions & 0 deletions internal/sshserver/sessionhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ var (
Name: "sshportal_sessions_total",
Help: "The total number of ssh-portal sessions started",
})
execSessions = promauto.NewGauge(prometheus.GaugeOpts{
Name: "sshportal_exec_sessions",
Help: "Current number of ssh-portal exec sessions",
})
logsSessions = promauto.NewGauge(prometheus.GaugeOpts{
Name: "sshportal_logs_sessions",
Help: "Current number of ssh-portal logs sessions",
})
)

// authCtxValues extracts the context values set by the authhandler.
Expand Down Expand Up @@ -246,6 +254,9 @@ func startClientKeepalive(ctx context.Context, cancel context.CancelFunc,

func doLogs(ctx ssh.Context, log *slog.Logger, s ssh.Session, deployment,
container string, follow bool, tailLines int64, c K8SAPIService) {
// update metrics
logsSessions.Inc()
defer logsSessions.Dec()
// Wrap the ssh.Context so we can cancel goroutines started from this
// function without affecting the SSH session.
childCtx, cancel := context.WithCancel(ctx)
Expand Down Expand Up @@ -280,6 +291,9 @@ func doLogs(ctx ssh.Context, log *slog.Logger, s ssh.Session, deployment,
func doExec(ctx ssh.Context, log *slog.Logger, s ssh.Session, deployment,
container string, cmd []string, c K8SAPIService, pty bool,
winch <-chan ssh.Window) {
// update metrics
execSessions.Inc()
defer execSessions.Dec()
err := c.Exec(ctx, s.User(), deployment, container, cmd, s,
s.Stderr(), pty, winch)
if err != nil {
Expand Down

0 comments on commit 7c93f5c

Please sign in to comment.