diff --git a/internal/k8s/exec.go b/internal/k8s/exec.go index 079f3697..3cd296ad 100644 --- a/internal/k8s/exec.go +++ b/internal/k8s/exec.go @@ -18,6 +18,10 @@ import ( "k8s.io/client-go/tools/remotecommand" ) +const ( + idleAnnotation = "idling.amazee.io/unidle-replicas" +) + // podContainer returns the first pod and first container inside that pod for // the given namespace and deployment. func (c *Client) podContainer(ctx context.Context, namespace, @@ -68,7 +72,7 @@ func (c *Client) hasRunningPod(ctx context.Context, // replicas to restore. If the label cannot be read or parsed, 1 is returned. // The return value is clamped to the interval [1,16]. func unidleReplicas(deploy appsv1.Deployment) int { - rs, ok := deploy.Annotations["idling.amazee.io/unidle-replicas"] + rs, ok := deploy.Annotations[idleAnnotation] if !ok { return 1 } diff --git a/internal/k8s/exec_test.go b/internal/k8s/exec_test.go new file mode 100644 index 00000000..3dba5d04 --- /dev/null +++ b/internal/k8s/exec_test.go @@ -0,0 +1,37 @@ +package k8s + +import ( + "testing" + + "github.com/alecthomas/assert/v2" + appsv1 "k8s.io/api/apps/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestUnidleReplicas(t *testing.T) { + var testCases = map[string]struct { + input string + expect int + }{ + "simple": {input: "4", expect: 4}, + "high edge": {input: "16", expect: 16}, + "low edge": {input: "1", expect: 1}, + "zero": {input: "0", expect: 1}, + "too high": {input: "17", expect: 16}, + "way too high": {input: "17000000", expect: 16}, + "overflow too high": {input: "9223372036854775808", expect: 1}, + "too low": {input: "-1", expect: 1}, + "way too low": {input: "-17000000", expect: 1}, + "overflow too low": {input: "-9223372036854775808", expect: 1}, + } + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + deploy := appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{idleAnnotation: tc.input}, + }, + } + assert.Equal(tt, tc.expect, unidleReplicas(deploy), name) + }) + } +} diff --git a/internal/k8s/logs_test.go b/internal/k8s/logs_test.go new file mode 100644 index 00000000..010b9ebf --- /dev/null +++ b/internal/k8s/logs_test.go @@ -0,0 +1,46 @@ +package k8s + +import ( + "context" + "io" + "strings" + "testing" + "time" + + "github.com/alecthomas/assert/v2" +) + +func TestLinewiseCopy(t *testing.T) { + var testCases = map[string]struct { + input string + expect []string + prefix string + }{ + "logs": { + input: "foo\nbar\nbaz\n", + expect: []string{"test: foo", "test: bar", "test: baz"}, + prefix: "test:", + }, + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + out := make(chan string, 1) + in := io.NopCloser(strings.NewReader(tc.input)) + go linewiseCopy(ctx, tc.prefix, out, in) + timer := time.NewTimer(500 * time.Millisecond) + var lines []string + loop: + for { + select { + case <-timer.C: + break loop + case line := <-out: + lines = append(lines, line) + } + } + assert.Equal(tt, tc.expect, lines, name) + }) + } +} diff --git a/internal/k8s/namespacedetails_test.go b/internal/k8s/namespacedetails_test.go new file mode 100644 index 00000000..33b8e1d0 --- /dev/null +++ b/internal/k8s/namespacedetails_test.go @@ -0,0 +1,41 @@ +package k8s + +import ( + "testing" + + "github.com/alecthomas/assert/v2" +) + +func TestIntFromLabel(t *testing.T) { + labels := map[string]string{ + "foo": "1", + "bar": "hello", + "baz": "true", + "negative": "-1", + "max": "9223372036854775807", + "overflow": "9223372036854775808", + } + var testCases = map[string]struct { + target string + expect int + expectErr bool + }{ + "foo": {target: "foo", expect: 1}, + "bar": {target: "bar", expectErr: true}, + "baz": {target: "baz", expectErr: true}, + "negative": {target: "negative", expect: -1}, + "max": {target: "max", expect: 9223372036854775807}, + "overflow": {target: "overflow", expectErr: true}, + } + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + result, err := intFromLabel(labels, tc.target) + if tc.expectErr { + assert.Error(tt, err, name) + } else { + assert.NoError(tt, err, name) + assert.Equal(tt, tc.expect, result, name) + } + }) + } +} diff --git a/internal/k8s/spin_test.go b/internal/k8s/spin_test.go new file mode 100644 index 00000000..619aaecf --- /dev/null +++ b/internal/k8s/spin_test.go @@ -0,0 +1,37 @@ +package k8s + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/alecthomas/assert/v2" +) + +func TestSpinAfter(t *testing.T) { + wait := 500 * time.Millisecond + var testCases = map[string]struct { + connectTime time.Duration + expectSpinner bool + }{ + "spinner": {connectTime: 600 * time.Millisecond, expectSpinner: true}, + "no spinner": {connectTime: 400 * time.Millisecond, expectSpinner: false}, + } + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + var buf strings.Builder + // start the spinner with a given connect time + ctx, cancel := context.WithTimeout(context.Background(), tc.connectTime) + wg := spinAfter(ctx, &buf, wait) + wg.Wait() + cancel() + // check if the builder has spinner animations + if tc.expectSpinner { + assert.NotZero(tt, buf.Len(), name) + } else { + assert.Zero(tt, buf.Len(), name) + } + }) + } +} diff --git a/internal/k8s/termsizequeue_test.go b/internal/k8s/termsizequeue_test.go new file mode 100644 index 00000000..448082e7 --- /dev/null +++ b/internal/k8s/termsizequeue_test.go @@ -0,0 +1,39 @@ +package k8s + +import ( + "context" + "testing" + + "github.com/alecthomas/assert/v2" + "github.com/gliderlabs/ssh" + "k8s.io/client-go/tools/remotecommand" +) + +func TestTermSizeQueue(t *testing.T) { + var testCases = map[string]struct { + input ssh.Window + expect remotecommand.TerminalSize + }{ + "term size change": { + input: ssh.Window{ + Width: 100, + Height: 200, + }, + expect: remotecommand.TerminalSize{ + Width: 100, + Height: 200, + }, + }, + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + in := make(chan ssh.Window, 1) + tsq := newTermSizeQueue(ctx, in) + in <- tc.input + output := tsq.Next() + assert.Equal(tt, tc.expect, *output, name) + }) + } +} diff --git a/internal/k8s/validate_test.go b/internal/k8s/validate_test.go new file mode 100644 index 00000000..5545db7b --- /dev/null +++ b/internal/k8s/validate_test.go @@ -0,0 +1,27 @@ +package k8s_test + +import ( + "testing" + + "github.com/alecthomas/assert/v2" + "github.com/uselagoon/ssh-portal/internal/k8s" +) + +func TestValidateLabelValues(t *testing.T) { + var testCases = map[string]struct { + input string + expectError bool + }{ + "valid": {input: "foo", expectError: false}, + "invalid": {input: "naïve", expectError: true}, + } + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + if tc.expectError { + assert.Error(tt, k8s.ValidateLabelValue(tc.input), name) + } else { + assert.NoError(tt, k8s.ValidateLabelValue(tc.input), name) + } + }) + } +} diff --git a/internal/rbac/usercansshtoenvironment.go b/internal/rbac/usercansshtoenvironment.go index 7a0d3a64..691ebb46 100644 --- a/internal/rbac/usercansshtoenvironment.go +++ b/internal/rbac/usercansshtoenvironment.go @@ -37,8 +37,13 @@ var defaultEnvTypeRoleCanSSH = map[lagoon.EnvironmentType][]lagoon.UserRole{ // UserCanSSHToEnvironment returns true if the given environment can be // connected to via SSH by the user with the given realm roles and user groups, // and false otherwise. -func (p *Permission) UserCanSSHToEnvironment(ctx context.Context, env *lagoondb.Environment, - realmRoles, userGroups []string, groupProjectIDs map[string][]int) bool { +func (p *Permission) UserCanSSHToEnvironment( + ctx context.Context, + env *lagoondb.Environment, + realmRoles, + userGroups []string, + groupProjectIDs map[string][]int, +) bool { // set up tracing _, span := otel.Tracer(pkgName).Start(ctx, "UserCanSSHToEnvironment") defer span.End() diff --git a/internal/sshserver/serve_test.go b/internal/sshserver/serve_test.go new file mode 100644 index 00000000..a727dfe1 --- /dev/null +++ b/internal/sshserver/serve_test.go @@ -0,0 +1,24 @@ +package sshserver + +import ( + "slices" + "testing" + + "github.com/alecthomas/assert/v2" +) + +func TestDisableSHA1Kex(t *testing.T) { + var testCases = map[string]struct { + input string + expect bool + }{ + "no sha1": {input: "diffie-hellman-group14-sha1", expect: false}, + } + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + conf := disableSHA1Kex(nil) + assert.Equal(tt, tc.expect, + slices.Contains(conf.Config.KeyExchanges, tc.input), name) + }) + } +} diff --git a/internal/sshserver/sessionhandler.go b/internal/sshserver/sessionhandler.go index ad028deb..da208296 100644 --- a/internal/sshserver/sessionhandler.go +++ b/internal/sshserver/sessionhandler.go @@ -30,8 +30,42 @@ var ( }) ) +// authCtxValues extracts the context values set by the authhandler. +func authCtxValues(ctx ssh.Context) (int, string, int, string, string, error) { + var ok bool + var eid, pid int + var ename, pname, fingerprint string + eid, ok = ctx.Value(environmentIDKey).(int) + if !ok { + return eid, ename, pid, pname, fingerprint, + fmt.Errorf("couldn't extract environment ID from session context") + } + ename, ok = ctx.Value(environmentNameKey).(string) + if !ok { + return eid, ename, pid, pname, fingerprint, + fmt.Errorf("couldn't extract environment name from session context") + } + pid, ok = ctx.Value(projectIDKey).(int) + if !ok { + return eid, ename, pid, pname, fingerprint, + fmt.Errorf("couldn't extract project ID from session context") + } + pname, ok = ctx.Value(projectNameKey).(string) + if !ok { + return eid, ename, pid, pname, fingerprint, + fmt.Errorf("couldn't extract project name from session context") + } + fingerprint, ok = ctx.Value(sshFingerprint).(string) + if !ok { + return eid, ename, pid, pname, fingerprint, + fmt.Errorf("couldn't extract SSH key fingerprint from session context") + } + return eid, ename, pid, pname, fingerprint, nil +} + // getSSHIntent analyses the SFTP flag and the raw command strings to determine -// if the command should be wrapped. +// if the command should be wrapped, and returns the given cmd wrapped +// appropriately. func getSSHIntent(sftp bool, cmd []string) []string { // if this is an sftp session we ignore any commands if sftp { @@ -104,25 +138,16 @@ func sessionHandler(log *slog.Logger, c K8SAPIService, return } // extract info passed through the context by the authhandler - eid, ok := ctx.Value(environmentIDKey).(int) - if !ok { - log.Warn("couldn't extract environment ID from session context") - } - ename, ok := ctx.Value(environmentNameKey).(string) - if !ok { - log.Warn("couldn't extract environment name from session context") - } - pid, ok := ctx.Value(projectIDKey).(int) - if !ok { - log.Warn("couldn't extract project ID from session context") - } - pname, ok := ctx.Value(projectNameKey).(string) - if !ok { - log.Warn("couldn't extract project name from session context") - } - fingerprint, ok := ctx.Value(sshFingerprint).(string) - if !ok { - log.Warn("couldn't extract SSH key fingerprint from session context") + eid, ename, pid, pname, fingerprint, err := authCtxValues(ctx) + if err != nil { + log.Error("couldn't extract auth values from context", + slog.Any("error", err)) + _, err = fmt.Fprintf(s.Stderr(), "error executing command. SID: %s\r\n", + ctx.SessionID()) + if err != nil { + log.Debug("couldn't write to session stream", slog.Any("error", err)) + } + return } if len(logs) != 0 { if !logAccessEnabled {