Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement subgroup support #476

Merged
merged 7 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/keycloak-debug/dumpgroups.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (cmd *DumpGroupsCmd) Run(log *slog.Logger) error {
if err != nil {
return fmt.Errorf("couldn't init keycloak client: %v", err)
}
groupMap, err := k.GroupNameGroupIDMap(ctx)
groupMap, err := k.TopLevelGroupNameGroupIDMap(ctx)
if err != nil {
return fmt.Errorf("couldn't get keycloak group map: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/keycloak-debug/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
// CLI represents the command-line interface.
type CLI struct {
Debug bool `kong:"env='DEBUG',help='Enable debug logging'"`
DumpGroups DumpGroupsCmd `kong:"cmd,default=1,help='(default) Serve ssh-portal-api requests'"`
DumpGroups DumpGroupsCmd `kong:"cmd,default=1,help='(default) Dump top-level Keycloak groups to stdout'"`
}

func main() {
Expand Down
18 changes: 9 additions & 9 deletions cmd/ssh-portal-api/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,14 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
// get main process context, which cancels on SIGTERM
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM)
defer stop()
// init RBAC permission engine
var p *rbac.Permission
if cmd.BlockDeveloperSSH {
p = rbac.NewPermission(rbac.BlockDeveloperSSH())
} else {
p = rbac.NewPermission()
}
// init lagoon DB client
dbConf := mysql.NewConfig()
dbConf.Addr = cmd.APIDBAddress
dbConf.DBName = cmd.APIDBDatabase
dbConf.Net = "tcp"
dbConf.Passwd = cmd.APIDBPassword
dbConf.User = cmd.APIDBUsername
l, err := lagoondb.NewClient(ctx, dbConf.FormatDSN())
ldb, err := lagoondb.NewClient(ctx, dbConf.FormatDSN())
if err != nil {
return fmt.Errorf("couldn't init lagoondb client: %v", err)
}
Expand All @@ -66,14 +59,21 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
if err != nil {
return fmt.Errorf("couldn't init keycloak client: %v", err)
}
// init RBAC permission engine
var p *rbac.Permission
if cmd.BlockDeveloperSSH {
p = rbac.NewPermission(k, ldb, rbac.BlockDeveloperSSH())
} else {
p = rbac.NewPermission(k, ldb)
}
// set up goroutine handler
eg, ctx := errgroup.WithContext(ctx)
// start the metrics server
metrics.Serve(ctx, eg, metricsPort)
// start serving SSH token requests
eg.Go(func() error {
// start serving NATS requests
return sshportalapi.ServeNATS(ctx, stop, log, p, l, k, cmd.NATSURL)
return sshportalapi.ServeNATS(ctx, stop, log, p, ldb, cmd.NATSURL)
})
return eg.Wait()
}
17 changes: 8 additions & 9 deletions cmd/ssh-token/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,6 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
// get main process context, which cancels on SIGTERM
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM)
defer stop()
// init RBAC permission engine
var p *rbac.Permission
if cmd.BlockDeveloperSSH {
p = rbac.NewPermission(rbac.BlockDeveloperSSH())
} else {
p = rbac.NewPermission()
}
// init lagoon DB client
dbConf := mysql.NewConfig()
dbConf.Addr = cmd.APIDBAddress
Expand Down Expand Up @@ -81,6 +74,13 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
if err != nil {
return fmt.Errorf("couldn't init keycloak permission client: %v", err)
}
// init RBAC permission engine
var p *rbac.Permission
if cmd.BlockDeveloperSSH {
p = rbac.NewPermission(keycloakPermission, ldb, rbac.BlockDeveloperSSH())
} else {
p = rbac.NewPermission(keycloakPermission, ldb)
}
// start listening on TCP port
l, err := net.Listen("tcp", fmt.Sprintf(":%d", cmd.SSHServerPort))
if err != nil {
Expand All @@ -101,8 +101,7 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
metrics.Serve(ctx, eg, metricsPort)
// start serving SSH token requests
eg.Go(func() error {
return sshtoken.Serve(ctx, log, l, p, ldb, keycloakToken, keycloakPermission,
hostkeys)
return sshtoken.Serve(ctx, log, l, p, ldb, keycloakToken, hostkeys)
})
return eg.Wait()
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ require (
github.com/gliderlabs/ssh v0.3.7
github.com/go-sql-driver/mysql v1.8.1
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0
github.com/google/uuid v1.6.1-0.20240806143717-0e97ed3b5379
github.com/jmoiron/sqlx v1.4.0
github.com/moby/spdystream v0.5.0
github.com/nats-io/nats.go v1.37.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af h1:kmjWCqn2qkEml422C2Rrd27c3VGxi6a/6HNq8QmHRKM=
github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.1-0.20240806143717-0e97ed3b5379 h1:9pvPp/2VCtCB2xdSUCaKe1VKCzVHMR+GGgIAVLfQxIs=
github.com/google/uuid v1.6.1-0.20240806143717-0e97ed3b5379/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
Expand Down
24 changes: 12 additions & 12 deletions internal/cache/cache.go → internal/cache/any.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,28 @@ const (
defaultTTL = time.Minute
)

// Cache is a generic, thread-safe, in-memory cache that stores a value with a
// Any is a generic, thread-safe, in-memory cache that stores a value with a
// TTL, after which the cache expires.
type Cache[T any] struct {
type Any[T any] struct {
data T
expiry time.Time
ttl time.Duration
mu sync.Mutex
}

// Option is a functional option argument to NewCache().
type Option[T any] func(*Cache[T])
// AnyOption is a functional option argument to NewCache().
type AnyOption[T any] func(*Any[T])

// WithTTL sets the the Cache time-to-live to ttl.
func WithTTL[T any](ttl time.Duration) Option[T] {
return func(c *Cache[T]) {
// AnyWithTTL sets the the Cache time-to-live to ttl.
func AnyWithTTL[T any](ttl time.Duration) AnyOption[T] {
return func(c *Any[T]) {
c.ttl = ttl
}
}

// NewCache instantiates a Cache for type T with a default TTL of 1 minute.
func NewCache[T any](options ...Option[T]) *Cache[T] {
c := Cache[T]{
// NewAny instantiates an Any cache for type T with a default TTL of 1 minute.
func NewAny[T any](options ...AnyOption[T]) *Any[T] {
c := Any[T]{
ttl: defaultTTL,
}
for _, option := range options {
Expand All @@ -41,7 +41,7 @@ func NewCache[T any](options ...Option[T]) *Cache[T] {
}

// Set updates the value in the cache and sets the expiry to now+TTL.
func (c *Cache[T]) Set(value T) {
func (c *Any[T]) Set(value T) {
c.mu.Lock()
defer c.mu.Unlock()
c.data = value
Expand All @@ -50,7 +50,7 @@ func (c *Cache[T]) Set(value T) {

// Get retrieves the value from the cache. If cache has expired, the second
// return value will be false.
func (c *Cache[T]) Get() (T, bool) {
func (c *Any[T]) Get() (T, bool) {
c.mu.Lock()
defer c.mu.Unlock()
if time.Now().After(c.expiry) {
Expand Down
26 changes: 14 additions & 12 deletions internal/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestIntCache(t *testing.T) {
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
c := cache.NewCache[int](cache.WithTTL[int](time.Second))
c := cache.NewAny[int](cache.AnyWithTTL[int](time.Second))
c.Set(tc.input)
if tc.expired {
time.Sleep(2 * time.Second)
Expand All @@ -36,33 +36,35 @@ func TestIntCache(t *testing.T) {

func TestMapCache(t *testing.T) {
var testCases = map[string]struct {
input map[string]string
expect map[string]string
key string
value string
expired bool
}{
"expired": {
input: map[string]string{"foo": "bar"},
key: "foo",
value: "bar",
expired: true,
},
"not expired": {
input: map[string]string{"foo": "bar"},
expect: map[string]string{"foo": "bar"},
key: "foo",
value: "bar",
},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
c := cache.NewCache[map[string]string](
cache.WithTTL[map[string]string](time.Second),
c := cache.NewMap[string, string](
cache.MapWithTTL[string, string](time.Second),
)
c.Set(tc.input)
c.Set(tc.key, tc.value)
if tc.expired {
time.Sleep(2 * time.Second)
_, ok := c.Get()
value, ok := c.Get(tc.key)
assert.False(tt, ok, name)
assert.Equal(tt, "", value, name)
} else {
value, ok := c.Get()
value, ok := c.Get(tc.key)
assert.True(tt, ok, name)
assert.Equal(tt, tc.expect, value, name)
assert.Equal(tt, tc.value, value, name)
}
})
}
Expand Down
69 changes: 69 additions & 0 deletions internal/cache/map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package cache

import (
"sync"
"time"
)

type mapValue[T any] struct {
data T
expiry time.Time
}

// Map is a generic, thread-safe, in-memory cache map that stores a key-value
// pairs with a TTL, after which the cache expires.
type Map[K comparable, V any] struct {
data map[K]mapValue[V]
ttl time.Duration
mu sync.Mutex
}

// MapOption is a functional option argument to NewCache().
type MapOption[K comparable, V any] func(*Map[K, V])

// MapWithTTL sets the the Cache time-to-live to ttl.
func MapWithTTL[K comparable, V any](ttl time.Duration) MapOption[K, V] {
return func(c *Map[K, V]) {
c.ttl = ttl
}
}

// NewMap instantiates a Map for key type K and value type V with a default TTL
// of 1 minute.
func NewMap[K comparable, V any](options ...MapOption[K, V]) *Map[K, V] {
c := Map[K, V]{
data: map[K]mapValue[V]{},
ttl: defaultTTL,
}
for _, option := range options {
option(&c)
}
return &c
}

// Set updates the value in the cache and sets the expiry to now+TTL.
func (c *Map[K, V]) Set(key K, data V) {
c.mu.Lock()
defer c.mu.Unlock()
c.data[key] = mapValue[V]{
data: data,
expiry: time.Now().Add(c.ttl),
}
}

// Get retrieves the value from the cache. If the value doesn't exist in the
// cache, or if the cache has expired, the second return value will be false.
func (c *Map[K, V]) Get(key K) (V, bool) {
c.mu.Lock()
defer c.mu.Unlock()
var zero mapValue[V]
value, ok := c.data[key]
if !ok {
return zero.data, false
}
if time.Now().After(value.expiry) {
delete(c.data, key)
return zero.data, false
}
return value.data, true
}
Loading
Loading