diff --git a/backend/go.mod b/backend/go.mod index 11f70c51..62bb79d6 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -37,6 +37,7 @@ require ( golang.org/x/net v0.31.0 golang.org/x/sys v0.27.0 golang.org/x/term v0.26.0 + golang.org/x/time v0.7.0 gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 gopkg.in/yaml.v3 v3.0.1 ) @@ -121,7 +122,6 @@ require ( golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/sync v0.9.0 // indirect golang.org/x/text v0.20.0 // indirect - golang.org/x/time v0.7.0 // indirect google.golang.org/protobuf v1.34.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect ) diff --git a/backend/pkg/rate/limit.go b/backend/pkg/rate/limit.go new file mode 100644 index 00000000..848b1c49 --- /dev/null +++ b/backend/pkg/rate/limit.go @@ -0,0 +1,70 @@ +package rate + +import ( + "context" + "time" + + "golang.org/x/time/rate" +) + +// Limiter implements a rate limiting interface based on golang.org/x/time/rate +// but extending the interface with the ability to expose internal errors. +type Limiter interface { + Reserve(ctx context.Context) (Reservation, error) + Tokens(ctx context.Context) (uint64, error) +} + +// Reservation is a point-in-time reservation of a ratelimit token. If the token +// is approved OK return true, otherways Delay will return the duration for next +// token(s) to become available. While Tokens return the number of available +// tokens after the reservation. +type Reservation interface { + OK() bool + Delay() time.Duration + Tokens() uint64 +} + +type limiter rate.Limiter + +func NewLimiter(limit int, interval time.Duration) Limiter { + return (*limiter)(rate.NewLimiter(rate.Every(interval/time.Duration(limit)), limit)) +} + +func (lim *limiter) Reserve(context.Context) (Reservation, error) { + now := time.Now() + goLimiter := (*rate.Limiter)(lim) + res := &reservation{ + reservation: goLimiter.ReserveN(now, 1), + time: now, + } + if res.OK() { + res.tokens = uint64(goLimiter.TokensAt(now)) + } + return res, nil +} + +func (lim *limiter) Tokens(context.Context) (uint64, error) { + goLimiter := (*rate.Limiter)(lim) + if tokens := goLimiter.Tokens(); tokens > 0 { + return uint64(tokens), nil + } + return 0, nil +} + +type reservation struct { + time time.Time + tokens uint64 + reservation *rate.Reservation +} + +func (r *reservation) OK() bool { + return r.Delay() == 0 +} + +func (r *reservation) Delay() time.Duration { + return r.reservation.DelayFrom(r.time) +} + +func (r *reservation) Tokens() uint64 { + return r.tokens +} diff --git a/backend/pkg/rate/limit_test.go b/backend/pkg/rate/limit_test.go new file mode 100644 index 00000000..40875063 --- /dev/null +++ b/backend/pkg/rate/limit_test.go @@ -0,0 +1,47 @@ +package rate + +import ( + "context" + "testing" + "time" +) + +func TestLimiter(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + limiter := NewLimiter(2, time.Hour) + resApproved, err := limiter.Reserve(ctx) + if err != nil { + t.Errorf("unexpected error reserving time slot #1: %s", err.Error()) + t.FailNow() + } + _, err = limiter.Reserve(ctx) + if err != nil { + t.Errorf("unexpected error reserving time slot #2: %s", err.Error()) + t.FailNow() + } + resDenied, err := limiter.Reserve(ctx) + if err != nil { + t.Errorf("unexpected error reserving time slot #3: %s", err.Error()) + t.FailNow() + } + if !resApproved.OK() { + t.Error("expected first reservation to be available") + } else if resApproved.Delay() > 0 { + t.Error("an approved reservation should not have a delay") + } else if resApproved.Tokens() == 0 { + t.Error("expected more tokens to be available after first reservation") + } + if resDenied.OK() { + t.Error("reservation should not be available before 1h has passed") + } else { + if resDenied.Delay() == 0 { + t.Error("a denied reservation should have a non-zero delay") + } + if resDenied.Tokens() > 0 { + t.Errorf("a deneied reservation should not report free tokens, got: %d", resDenied.Tokens()) + } + } +} diff --git a/backend/pkg/redis/ratelimit.go b/backend/pkg/redis/ratelimit.go new file mode 100644 index 00000000..8797d441 --- /dev/null +++ b/backend/pkg/redis/ratelimit.go @@ -0,0 +1,129 @@ +package redis + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/redis/go-redis/v9" + + "github.com/mendersoftware/mender-server/pkg/rate" +) + +func NewFixedWindowRateLimiter( + client Client, + paramsFromContext RatelimitParamsFunc, +) rate.Limiter { + return &fixedWindowRatelimiter{ + client: client, + paramsFunc: paramsFromContext, + nowFunc: time.Now, + } +} + +type RatelimitParams struct { + Burst uint64 + Interval time.Duration + KeyPrefix string +} + +type RatelimitParamsFunc func(context.Context) (*RatelimitParams, error) + +func FixedRatelimitParams(params RatelimitParams) RatelimitParamsFunc { + return func(ctx context.Context) (*RatelimitParams, error) { + return ¶ms, nil + } +} + +type fixedWindowRatelimiter struct { + client Client + paramsFunc RatelimitParamsFunc + nowFunc func() time.Time +} + +type simpleReservation struct { + ok bool + tokens uint64 + delay time.Duration +} + +func (r *simpleReservation) OK() bool { + return r.ok +} + +func (r *simpleReservation) Delay() time.Duration { + return r.delay +} + +func (r *simpleReservation) Tokens() uint64 { + return r.tokens +} + +func epoch(t time.Time, interval time.Duration) int64 { + return t.UnixMilli() / interval.Milliseconds() +} + +func fixedWindowKey(prefix string, epoch int64) string { + if prefix == "" { + prefix = "ratelimit" + } + return fmt.Sprintf("%s:e:%d:c", prefix, epoch) +} + +func (rl *fixedWindowRatelimiter) Reserve(ctx context.Context) (rate.Reservation, error) { + now := rl.nowFunc() + params, err := rl.paramsFunc(ctx) + if err != nil { + return nil, err + } else if params == nil { + return &simpleReservation{ + ok: true, + }, nil + } + epoch := epoch(now, params.Interval) + key := fixedWindowKey(params.KeyPrefix, epoch) + count := uint64(1) + + err = rl.client.SetArgs(ctx, key, count, redis.SetArgs{ + TTL: params.Interval, + Mode: `NX`, + }).Err() + if errors.Is(err, redis.Nil) { + count, err = rl.client.Incr(ctx, key).Uint64() + } + if err != nil { + return nil, fmt.Errorf("redis: error computing rate limit: %w", err) + } + if count <= params.Burst { + return &simpleReservation{ + delay: 0, + ok: true, + tokens: params.Burst - count, + }, nil + } + return &simpleReservation{ + delay: now.Sub(time.UnixMilli((epoch + 1) * + params.Interval.Milliseconds())), + ok: false, + tokens: 0, + }, nil +} + +func (rl *fixedWindowRatelimiter) Tokens(ctx context.Context) (uint64, error) { + params, err := rl.paramsFunc(ctx) + if err != nil { + return 0, err + } + count, err := rl.client.Get(ctx, + fixedWindowKey(params.KeyPrefix, epoch(rl.nowFunc(), params.Interval)), + ).Uint64() + if errors.Is(err, redis.Nil) { + return params.Burst, nil + } else if err != nil { + return 0, fmt.Errorf("redis: error getting free tokens: %w", err) + } else if count > params.Burst { + return 0, nil + } + return params.Burst - count, nil +} diff --git a/backend/pkg/redis/ratelimit_test.go b/backend/pkg/redis/ratelimit_test.go new file mode 100644 index 00000000..20fa74a1 --- /dev/null +++ b/backend/pkg/redis/ratelimit_test.go @@ -0,0 +1,66 @@ +package redis + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/mendersoftware/mender-server/pkg/rate" +) + +func TestFixedWindowRatelimit(t *testing.T) { + requireRedis(t) + t.Parallel() + + ctx := context.Background() + + client, err := ClientFromConnectionString(ctx, RedisURL) + if err != nil { + t.Errorf("could not connect to redis (%s): is redis running?", + RedisURL) + t.FailNow() + } + tMicro := time.Now().UnixMicro() + params := FixedRatelimitParams(RatelimitParams{ + Burst: 1, + Interval: time.Hour, + KeyPrefix: fmt.Sprintf("%s_%x", strings.ToLower(t.Name()), tMicro), + }) + rateLimiter := NewFixedWindowRateLimiter(client, params) + + // Freeze time to avoid time to progress to next window. + nowFrozen := time.Now() + rateLimiter.(*fixedWindowRatelimiter).nowFunc = func() time.Time { return nowFrozen } + + if tokens, _ := rateLimiter.Tokens(ctx); tokens != 1 { + t.Errorf("expected token available after initialization, actual: %d", tokens) + } + + var reservations [2]rate.Reservation + for i := 0; i < len(reservations); i++ { + reservations[i], err = rateLimiter.Reserve(ctx) + if err != nil { + t.Errorf("unexpected error reserving rate limit: %s", err.Error()) + t.FailNow() + } + } + if !reservations[0].OK() { + t.Errorf("expected first event to pass, but didn't") + } + if reservations[1].OK() { + t.Errorf("expected the second event to block, but didn't") + } + if remaining, err := rateLimiter.Tokens(ctx); err != nil { + t.Errorf("unexpected error retrieving remaining tokens: %s", err.Error()) + } else if remaining != 0 { + t.Errorf("expected 0 tokens remaining, actual: %d", remaining) + } + + if reservations[0].Tokens() != 0 { + t.Errorf("there should be no tokens left after first event") + } else if reservations[1].Tokens() != 0 { + t.Errorf("there should be no tokens left after second event") + } +} diff --git a/backend/pkg/redis/redis.go b/backend/pkg/redis/redis.go index fe2dda85..afcc4a7e 100644 --- a/backend/pkg/redis/redis.go +++ b/backend/pkg/redis/redis.go @@ -26,6 +26,8 @@ import ( "github.com/redis/go-redis/v9" ) +type Client = redis.Cmdable + // nolint:lll // NewClient creates a new redis client (Cmdable) from the parameters in the // connectionString URL format: diff --git a/backend/pkg/redis/redis_test.go b/backend/pkg/redis/redis_test.go new file mode 100644 index 00000000..fba9366c --- /dev/null +++ b/backend/pkg/redis/redis_test.go @@ -0,0 +1,18 @@ +package redis + +import ( + "os" + "testing" +) + +const EnvRedisURL = "TEST_REDIS_URL" + +var RedisURL = os.Getenv(EnvRedisURL) + +func requireRedis(t *testing.T) { + if RedisURL == "" { + t.Skipf("skipping test %q due to missing redis URL, "+ + "use environment variable %q to run test", + t.Name(), EnvRedisURL) + } +} diff --git a/backend/services/deviceauth/cache/cache.go b/backend/services/deviceauth/cache/cache.go index fc3b4686..f86d0c23 100644 --- a/backend/services/deviceauth/cache/cache.go +++ b/backend/services/deviceauth/cache/cache.go @@ -60,10 +60,11 @@ import ( "github.com/pkg/errors" "github.com/redis/go-redis/v9" + "github.com/mendersoftware/mender-server/pkg/identity" "github.com/mendersoftware/mender-server/pkg/log" "github.com/mendersoftware/mender-server/pkg/ratelimits" - mredis "github.com/mendersoftware/mender-server/pkg/redis" + "github.com/mendersoftware/mender-server/services/deviceauth/model" "github.com/mendersoftware/mender-server/services/deviceauth/utils" ) @@ -106,6 +107,13 @@ type Cache interface { // DeleteToken deletes the token for 'id' DeleteToken(ctx context.Context, tid, id, idtype string) error + // GetLimit gets a limit from cache (see store.Datastore.GetLimit) + GetLimit(ctx context.Context, name string) (*model.Limit, error) + // SetLimit writes a limit to cache (see store.Datastore.SetLimit) + SetLimit(ctx context.Context, limit *model.Limit) error + // DeleteLimit evicts the limit with the given name from cache + DeleteLimit(ctx context.Context, name string) error + // GetLimits fetches limits for 'id' GetLimits(ctx context.Context, tid, id, idtype string) (*ratelimits.ApiLimits, error) @@ -133,26 +141,22 @@ type RedisCache struct { c redis.Cmdable prefix string LimitsExpireSec int + DefaultExpire time.Duration clock utils.Clock } func NewRedisCache( - ctx context.Context, - connectionString string, + redisClient redis.Cmdable, prefix string, limitsExpireSec int, -) (*RedisCache, error) { - c, err := mredis.ClientFromConnectionString(ctx, connectionString) - if err != nil { - return nil, err - } - +) *RedisCache { return &RedisCache{ - c: c, + c: redisClient, LimitsExpireSec: limitsExpireSec, prefix: prefix, + DefaultExpire: time.Hour * 3, clock: utils.NewClock(), - }, err + } } func (rl *RedisCache) WithClock(c utils.Clock) *RedisCache { @@ -160,6 +164,56 @@ func (rl *RedisCache) WithClock(c utils.Clock) *RedisCache { return rl } +func (rl *RedisCache) keyLimit(tenantID, name string) string { + if tenantID == "" { + tenantID = "default" + } + return fmt.Sprintf("%s:tenant:%s:limit:%s", rl.prefix, tenantID, name) +} + +func (rl *RedisCache) GetLimit(ctx context.Context, name string) (*model.Limit, error) { + var tenantID string + id := identity.FromContext(ctx) + if id != nil { + tenantID = id.Tenant + } + value, err := rl.c.Get(ctx, rl.keyLimit(tenantID, name)).Uint64() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, nil + } + return nil, err + } + return &model.Limit{ + TenantID: tenantID, + Value: value, + Name: name, + }, nil +} + +func (rl *RedisCache) SetLimit(ctx context.Context, limit *model.Limit) error { + if limit == nil { + return nil + } + var tenantID string + id := identity.FromContext(ctx) + if id != nil { + tenantID = id.Tenant + } + key := rl.keyLimit(tenantID, limit.Name) + return rl.c.SetEx(ctx, key, limit.Value, rl.DefaultExpire).Err() +} + +func (rl *RedisCache) DeleteLimit(ctx context.Context, name string) error { + var tenantID string + id := identity.FromContext(ctx) + if id != nil { + tenantID = id.Tenant + } + key := rl.keyLimit(tenantID, name) + return rl.c.Del(ctx, key).Err() +} + func (rl *RedisCache) Throttle( ctx context.Context, rawToken string, @@ -235,7 +289,6 @@ func (rl *RedisCache) pipeToken( func (rl *RedisCache) checkToken(cmd *redis.StringCmd, raw string) (string, error) { err := cmd.Err() - if err != nil { if isErrRedisNil(err) { return "", nil @@ -381,7 +434,6 @@ func (rl *RedisCache) GetLimits( id, idtype string, ) (*ratelimits.ApiLimits, error) { - version, err := rl.getTenantKeyVersion(ctx, tid) if err != nil { return nil, err @@ -440,7 +492,8 @@ func (rl *RedisCache) KeyQuota(tid, id, idtype, intvlNum string, version int) st } func (rl *RedisCache) KeyBurst( - tid, id, idtype, url, action, intvlNum string, version int) string { + tid, id, idtype, url, action, intvlNum string, version int, +) string { return fmt.Sprintf( "%s:tenant:%s:version:%d:%s:%s:burst:%s:%s:%s", rl.prefix, tid, version, idtype, id, url, action, intvlNum) @@ -513,7 +566,6 @@ func (rl *RedisCache) GetCheckInTime( tid, id string, ) (*time.Time, error) { - version, err := rl.getTenantKeyVersion(ctx, tid) if err != nil { return nil, err diff --git a/backend/services/deviceauth/cache/cache_test.go b/backend/services/deviceauth/cache/cache_test.go index da0ec25a..de8b3246 100644 --- a/backend/services/deviceauth/cache/cache_test.go +++ b/backend/services/deviceauth/cache/cache_test.go @@ -20,6 +20,7 @@ import ( "time" "github.com/alicebob/miniredis" + "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "github.com/mendersoftware/mender-server/pkg/ratelimits" @@ -32,15 +33,24 @@ const ( cachePrefix = "deviceauth:v1" ) -func TestRedisCacheThrottleToken(t *testing.T) { +func newRedisClient(t *testing.T) (*miniredis.Miniredis, redis.Cmdable) { r := miniredis.NewMiniRedis() err := r.Start() - assert.NoError(t, err) - defer r.Close() + if !assert.NoError(t, err) { + t.FailNow() + } + t.Cleanup(r.Close) + client := redis.NewClient(&redis.Options{ + Addr: r.Addr(), + }) + return r, client +} + +func TestRedisCacheThrottleToken(t *testing.T) { + r, client := newRedisClient(t) ctx := context.TODO() - rcache, err := NewRedisCache(ctx, "redis://"+r.Addr(), cachePrefix, limitsExpSec) - assert.NoError(t, err) + rcache := NewRedisCache(client, cachePrefix, limitsExpSec) // token not found tok, err := rcache.Throttle(ctx, @@ -153,12 +163,9 @@ func TestRedisCacheThrottleToken(t *testing.T) { func TestRedisCacheTokenDelete(t *testing.T) { ctx := context.TODO() - r := miniredis.NewMiniRedis() - err := r.Start() - assert.NoError(t, err) - defer r.Close() + _, client := newRedisClient(t) - rcache, err := NewRedisCache(ctx, "redis://"+r.Addr(), cachePrefix, limitsExpSec) + rcache := NewRedisCache(client, cachePrefix, limitsExpSec) // cache 2 tokens, remove first one, other one should still be available rcache.CacheToken(ctx, @@ -175,7 +182,7 @@ func TestRedisCacheTokenDelete(t *testing.T) { "tokenstr-2", time.Duration(10*time.Second)) - err = rcache.DeleteToken(ctx, "tenant-foo", "device-1", IdTypeDevice) + err := rcache.DeleteToken(ctx, "tenant-foo", "device-1", IdTypeDevice) assert.NoError(t, err) tok1, err := rcache.Throttle(ctx, @@ -207,13 +214,9 @@ func TestRedisCacheTokenDelete(t *testing.T) { } func TestRedisCacheLimitsQuota(t *testing.T) { - r := miniredis.NewMiniRedis() - err := r.Start() - assert.NoError(t, err) - defer r.Close() + r, client := newRedisClient(t) - rcache, err := NewRedisCache(context.TODO(), "redis://"+r.Addr(), cachePrefix, limitsExpSec) - assert.NoError(t, err) + rcache := NewRedisCache(client, cachePrefix, limitsExpSec) // apply quota l := ratelimits.ApiLimits{ @@ -249,13 +252,9 @@ func TestRedisCacheLimitsQuota(t *testing.T) { } func TestRedisCacheLimitsBurst(t *testing.T) { - r := miniredis.NewMiniRedis() - err := r.Start() - assert.NoError(t, err) - defer r.Close() + r, client := newRedisClient(t) - rcache, err := NewRedisCache(context.TODO(), "redis://"+r.Addr(), cachePrefix, limitsExpSec) - assert.NoError(t, err) + rcache := NewRedisCache(client, cachePrefix, limitsExpSec) clock := utils.NewMockClock(1590105600) rcache = rcache.WithClock(clock) @@ -293,13 +292,8 @@ func TestRedisCacheLimitsBurst(t *testing.T) { } func TestRedisCacheLimitsQuotaBurst(t *testing.T) { - r := miniredis.NewMiniRedis() - err := r.Start() - assert.NoError(t, err) - defer r.Close() - - rcache, err := NewRedisCache(context.TODO(), "redis://"+r.Addr(), cachePrefix, limitsExpSec) - assert.NoError(t, err) + r, client := newRedisClient(t) + rcache := NewRedisCache(client, cachePrefix, limitsExpSec) clock := utils.NewMockClock(1590105600) rcache = rcache.WithClock(clock) @@ -373,15 +367,11 @@ func TestRedisCacheLimitsQuotaBurst(t *testing.T) { } func TestRedisCacheGetSetLimits(t *testing.T) { - r := miniredis.NewMiniRedis() - err := r.Start() - assert.NoError(t, err) - defer r.Close() + r, client := newRedisClient(t) ctx := context.TODO() - rcache, err := NewRedisCache(ctx, "redis://"+r.Addr(), cachePrefix, limitsExpSec) - assert.NoError(t, err) + rcache := NewRedisCache(client, cachePrefix, limitsExpSec) res, err := rcache.GetLimits(ctx, "tenant-foo", "device-bar", IdTypeDevice) @@ -435,15 +425,11 @@ func fastForward(r *miniredis.Miniredis, c utils.Clock, secs int64) { } func TestRedisCacheGetSetCheckInTime(t *testing.T) { - r := miniredis.NewMiniRedis() - err := r.Start() - assert.NoError(t, err) - defer r.Close() + _, client := newRedisClient(t) ctx := context.TODO() - rcache, err := NewRedisCache(ctx, "redis://"+r.Addr(), cachePrefix, limitsExpSec) - assert.NoError(t, err) + rcache := NewRedisCache(client, cachePrefix, limitsExpSec) res, err := rcache.GetCheckInTime(ctx, "tenant-foo", "device-bar") diff --git a/backend/services/deviceauth/cache/mocks/Cache.go b/backend/services/deviceauth/cache/mocks/Cache.go index d2d55c70..b50da044 100644 --- a/backend/services/deviceauth/cache/mocks/Cache.go +++ b/backend/services/deviceauth/cache/mocks/Cache.go @@ -19,9 +19,11 @@ package mocks import ( context "context" - ratelimits "github.com/mendersoftware/mender-server/pkg/ratelimits" + model "github.com/mendersoftware/mender-server/services/deviceauth/model" mock "github.com/stretchr/testify/mock" + ratelimits "github.com/mendersoftware/mender-server/pkg/ratelimits" + time "time" ) @@ -84,6 +86,24 @@ func (_m *Cache) CacheToken(ctx context.Context, tid string, id string, idtype s return r0 } +// DeleteLimit provides a mock function with given fields: ctx, name +func (_m *Cache) DeleteLimit(ctx context.Context, name string) error { + ret := _m.Called(ctx, name) + + if len(ret) == 0 { + panic("no return value specified for DeleteLimit") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, name) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // DeleteToken provides a mock function with given fields: ctx, tid, id, idtype func (_m *Cache) DeleteToken(ctx context.Context, tid string, id string, idtype string) error { ret := _m.Called(ctx, tid, id, idtype) @@ -162,6 +182,36 @@ func (_m *Cache) GetCheckInTimes(ctx context.Context, tid string, ids []string) return r0, r1 } +// GetLimit provides a mock function with given fields: ctx, name +func (_m *Cache) GetLimit(ctx context.Context, name string) (*model.Limit, error) { + ret := _m.Called(ctx, name) + + if len(ret) == 0 { + panic("no return value specified for GetLimit") + } + + var r0 *model.Limit + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*model.Limit, error)); ok { + return rf(ctx, name) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *model.Limit); ok { + r0 = rf(ctx, name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Limit) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetLimits provides a mock function with given fields: ctx, tid, id, idtype func (_m *Cache) GetLimits(ctx context.Context, tid string, id string, idtype string) (*ratelimits.ApiLimits, error) { ret := _m.Called(ctx, tid, id, idtype) @@ -192,6 +242,24 @@ func (_m *Cache) GetLimits(ctx context.Context, tid string, id string, idtype st return r0, r1 } +// SetLimit provides a mock function with given fields: ctx, limit +func (_m *Cache) SetLimit(ctx context.Context, limit *model.Limit) error { + ret := _m.Called(ctx, limit) + + if len(ret) == 0 { + panic("no return value specified for SetLimit") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *model.Limit) error); ok { + r0 = rf(ctx, limit) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // SuspendTenant provides a mock function with given fields: ctx, tid func (_m *Cache) SuspendTenant(ctx context.Context, tid string) error { ret := _m.Called(ctx, tid) diff --git a/backend/services/deviceauth/config.yaml b/backend/services/deviceauth/config.yaml index a15b29ff..d3caba44 100644 --- a/backend/services/deviceauth/config.yaml +++ b/backend/services/deviceauth/config.yaml @@ -147,8 +147,25 @@ # Overwrite with environment variable: DEVICEAUTH_REDIS_LIMITS_EXPIRE_SEC # redis_limits_expire_sec: "1800" - # Enable addon feature restrictions. # Defaults to: false # Overwrite with environment variable: DEVICEAUTH_HAVE_ADDONS # have_addons: false + +# Adaptive ratelimits based on device limits. +# The number of request allowed in a fixed `interval` is given by the device +# limit multiplied by the `quota_default` for the given plan. +# For Mender Enterprise, other quota factors can be specified in `quota_plan`. +# `quota_plan` can either be a YAML map or a string-slice (key=value) for +# overriding quotas for specific plans. Example follows: +# ratelimit_devices: +# interval: 1m +# quota_default: 1.0 +# quota_plan: +# os: 1.0 +# professional: 1.5 +# enterprise: 2.0 +# Using environment variables: +# DEVICEAUTH_RATELIMIT_INTERVAL: "1m" +# DEVICEAUTH_RATELIMIT_DEVICE_QUOTA_DEFAULT: "1.0" +# DEVICEAUTH_RATELMIIT_DEVICE_QUOTA_PLAN: "professional=1.5 enterprise=2.0" diff --git a/backend/services/deviceauth/config/config.go b/backend/services/deviceauth/config/config.go index aeebad0d..127807b7 100644 --- a/backend/services/deviceauth/config/config.go +++ b/backend/services/deviceauth/config/config.go @@ -15,6 +15,8 @@ package config import ( + "time" + "github.com/mendersoftware/mender-server/pkg/config" ) @@ -79,6 +81,15 @@ const ( // Has no effect if not running in multi-tenancy context. SettingHaveAddons = "have_addons" SettingHaveAddonsDefault = false + + // SettingRatelimits* configures adaptive rate limiting based on device limit. + // The `quota` sets the maximum average number of requests per device within + // `interval`. + SettingRatelimitsInterval = "ratelimits.interval" + SettingRatelimitsIntervalDefault = time.Minute + SettingRatelimitsQuotaDefault = "ratelimits.quota_default" + SettingRatelimitsQuotaDefaultDefault = 1.0 + SettingRatelimitsQuotas = "ratelimits.quota_plan" ) var ( @@ -101,5 +112,7 @@ var ( {Key: SettingRedisLimitsExpSec, Value: SettingRedisLimitsExpSecDefault}, {Key: SettingRedisKeyPrefix, Value: SettingRedisKeyPrefixDefault}, {Key: SettingHaveAddons, Value: SettingHaveAddonsDefault}, + {Key: SettingRatelimitsInterval, Value: SettingRatelimitsIntervalDefault}, + {Key: SettingRatelimitsQuotaDefault, Value: SettingRatelimitsQuotaDefaultDefault}, } ) diff --git a/backend/services/deviceauth/devauth/devauth.go b/backend/services/deviceauth/devauth/devauth.go index a11f4be8..b6b9bb55 100644 --- a/backend/services/deviceauth/devauth/devauth.go +++ b/backend/services/deviceauth/devauth/devauth.go @@ -28,6 +28,7 @@ import ( "github.com/mendersoftware/mender-server/pkg/log" "github.com/mendersoftware/mender-server/pkg/mongo/oid" "github.com/mendersoftware/mender-server/pkg/plan" + "github.com/mendersoftware/mender-server/pkg/rate" "github.com/mendersoftware/mender-server/pkg/ratelimits" "github.com/mendersoftware/mender-server/pkg/requestid" @@ -126,6 +127,11 @@ type DevAuth struct { cache cache.Cache clock utils.Clock checker access.Checker + + // deviceRatelimiter is used to limit authenticated device requests + rateLimiter rate.Limiter + rateLimiterWeights map[string]float64 + rateLimiterWeightDefault float64 } type Config struct { @@ -143,7 +149,8 @@ type Config struct { } func NewDevAuth(d store.DataStore, co orchestrator.ClientRunner, - jwt jwt.Handler, config Config) *DevAuth { + jwt jwt.Handler, config Config, +) *DevAuth { // initialize checker using an empty merge (returns nil on validate) checker := access.Merge() if config.HaveAddons { @@ -189,8 +196,8 @@ func (d *DevAuth) setDeviceIdentity(ctx context.Context, dev *model.Device, tena i := 0 for name, value := range dev.IdDataStruct { if name == "status" { - //we have to forbid the client to override attribute status in identity scope - //since it stands for status of a device (as in: accepted, rejected, preauthorized) + // we have to forbid the client to override attribute status in identity scope + // since it stands for status of a device (as in: accepted, rejected, preauthorized) continue } attribute := model.DeviceAttribute{ @@ -284,7 +291,6 @@ func (d *DevAuth) signToken(ctx context.Context) jwt.SignFunc { func (d *DevAuth) doVerifyTenant(ctx context.Context, token string) (*tenant.Tenant, error) { t, err := d.cTenant.VerifyToken(ctx, token) - if err != nil { if tenant.IsErrTokenVerificationFailed(err) { return nil, MakeErrDevAuthUnauthorized(err) @@ -314,7 +320,6 @@ func (d *DevAuth) getTenantWithDefault( // but continue on errors and maybe try the default token if tenantToken != "" { t, err = d.doVerifyTenant(ctx, tenantToken) - if err != nil { l.Errorf("Failed to verify supplied tenant token: %s", err.Error()) } @@ -429,7 +434,6 @@ func (d *DevAuth) SubmitAuthRequest(ctx context.Context, r *model.AuthReq) (stri // no token, return device unauthorized return "", ErrDevAuthUnauthorized - } func (d *DevAuth) handlePreAuthDevice( @@ -518,7 +522,6 @@ func (d *DevAuth) processPreAuthRequest( ctx context.Context, r *model.AuthReq, ) (*model.AuthSet, error) { - _, idDataSha256, err := parseIdData(r.IdData) if err != nil { return nil, MakeErrDevAuthBadRequest(err) @@ -617,7 +620,6 @@ func (d *DevAuth) processAuthRequest( ctx context.Context, r *model.AuthReq, ) (*model.AuthSet, error) { - l := log.FromContext(ctx) // get device associated with given authorization request @@ -748,7 +750,6 @@ func (d *DevAuth) GetDevice(ctx context.Context, devId string) (*model.Device, e // DecommissionDevice deletes device and all its tokens func (d *DevAuth) DecommissionDevice(ctx context.Context, devID string) error { - l := log.FromContext(ctx) l.Warnf("Decommission device with id: %s", devID) @@ -826,7 +827,6 @@ func (d *DevAuth) DeleteDevice(ctx context.Context, devID string) error { // Deletes device authentication set, and optionally the device. func (d *DevAuth) DeleteAuthSet(ctx context.Context, devID string, authId string) error { - l := log.FromContext(ctx) l.Warnf("Delete authentication set with id: "+ @@ -1221,7 +1221,6 @@ func (d *DevAuth) RevokeToken(ctx context.Context, tokenID string) error { } func verifyTenantClaim(ctx context.Context, verifyTenant bool, tenant string) error { - l := log.FromContext(ctx) if verifyTenant { @@ -1284,6 +1283,8 @@ func (d *DevAuth) VerifyToken(ctx context.Context, raw string) error { if err == cache.ErrTooManyRequests { return err + } else if errRate := d.checkRateLimits(ctx); errRate != nil { + return errRate } if cachedToken != "" && raw == cachedToken { @@ -1387,7 +1388,7 @@ func (d *DevAuth) cacheThrottleVerify( origMethod, origUri string, ) (string, error) { - if d.cache == nil { + if d.cache == nil || d.cTenant == nil { return "", nil } @@ -1395,7 +1396,6 @@ func (d *DevAuth) cacheThrottleVerify( limits, err := d.getApiLimits(ctx, token.Claims.Tenant, token.Claims.Subject.String()) - if err != nil { return "", err } @@ -1502,7 +1502,8 @@ func apiLimitsOverride(src, dest ratelimits.ApiLimits) ratelimits.ApiLimits { ratelimits.ApiBurst{ Action: bdest.Action, Uri: bdest.Uri, - MinIntervalSec: bdest.MinIntervalSec}, + MinIntervalSec: bdest.MinIntervalSec, + }, ) } } @@ -1512,16 +1513,35 @@ func apiLimitsOverride(src, dest ratelimits.ApiLimits) ratelimits.ApiLimits { } func (d *DevAuth) GetLimit(ctx context.Context, name string) (*model.Limit, error) { - lim, err := d.db.GetLimit(ctx, name) - - switch err { - case nil: - return lim, nil - case store.ErrLimitNotFound: - return &model.Limit{Name: name, Value: 0}, nil - default: - return nil, err + l := log.FromContext(ctx) + var ( + limit *model.Limit + err error + ) + if d.cache != nil { + limit, err = d.cache.GetLimit(ctx, name) + if err != nil { + l.Warnf("error fetching limit from cache: %s", err.Error()) + } + } + if limit == nil { + limit, err = d.db.GetLimit(ctx, name) + if err != nil { + if errors.Is(err, store.ErrLimitNotFound) { + limit = &model.Limit{Name: name, Value: 0} + err = nil + } else { + return nil, err + } + } + if d.cache != nil { + errCache := d.cache.SetLimit(ctx, limit) + if errCache != nil { + l.Warnf("failed to store limit %q in cache: %s", name, errCache.Error()) + } + } } + return limit, err } func (d *DevAuth) GetTenantLimit( @@ -1555,6 +1575,19 @@ func (d *DevAuth) WithCache(c cache.Cache) *DevAuth { return d } +func (d *DevAuth) WithRatelimits( + rl rate.Limiter, + weights map[string]float64, + defaultQuota float64, +) *DevAuth { + if rl != nil { + d.rateLimiter = rl + d.rateLimiterWeights = weights + d.rateLimiterWeightDefault = defaultQuota + } + return d +} + func (d *DevAuth) WithClock(c utils.Clock) *DevAuth { d.clock = c return d @@ -1575,6 +1608,12 @@ func (d *DevAuth) SetTenantLimit(ctx context.Context, tenant_id string, limit mo return errors.Wrapf(err, "failed to save limit %v for tenant %v to database", limit, tenant_id) } + if d.cache != nil { + errCache := d.cache.SetLimit(ctx, &limit) + if errCache != nil { + l.Warnf("failed to store limit %q in cache: %s", limit.Name, errCache.Error()) + } + } return nil } @@ -1593,6 +1632,12 @@ func (d *DevAuth) DeleteTenantLimit(ctx context.Context, tenant_id string, limit return errors.Wrapf(err, "failed to delete limit %v for tenant %v to database", limit, tenant_id) } + if d.cache != nil { + errCache := d.cache.DeleteLimit(ctx, limit) + if errCache != nil { + l.Warnf("error removing limit %q from cache: %s", limit, errCache.Error()) + } + } return nil } diff --git a/backend/services/deviceauth/devauth/devauth_ratelimits.go b/backend/services/deviceauth/devauth/devauth_ratelimits.go new file mode 100644 index 00000000..04a633b8 --- /dev/null +++ b/backend/services/deviceauth/devauth/devauth_ratelimits.go @@ -0,0 +1,75 @@ +package devauth + +import ( + "context" + "fmt" + "strings" + + ctxhttpheader "github.com/mendersoftware/mender-server/pkg/context/httpheader" + "github.com/mendersoftware/mender-server/pkg/identity" + + "github.com/mendersoftware/mender-server/services/deviceauth/cache" + "github.com/mendersoftware/mender-server/services/deviceauth/model" +) + +func (d *DevAuth) checkRateLimits(ctx context.Context) error { + if d.rateLimiter != nil { + rsp, err := d.rateLimiter.Reserve(ctx) + if err != nil { + return err + } else if !rsp.OK() { + return cache.ErrTooManyRequests + } + } + return nil +} + +const rateLimitMax = int64(1 << 50) + +func fmtEventID(tenantID, event string) string { + return fmt.Sprintf("tenant:%s:ratelimit:%s", tenantID, event) +} + +// rateLimitFromContext returns the burst quota given the context +func (d *DevAuth) RateLimitsFromContext(ctx context.Context) ( + limit int64, + eventID string, + err error, +) { + var tenantID string = "default" + var weight float64 = d.rateLimiterWeightDefault + id := identity.FromContext(ctx) + if id != nil { + tenantID = id.Tenant + plan := id.Plan + if w, ok := d.rateLimiterWeights[plan]; ok { + weight = w + } + } + origUri := ctxhttpheader.FromContext(ctx, "X-Forwarded-Uri") + origUri = purgeUriArgs(origUri) + if splitPath := strings.SplitN(origUri, "/", 5); len(splitPath) == 5 { + // discard `/api/devices/v*/` + origUri = splitPath[4] + } + lim, err := d.GetLimit(ctx, model.LimitMaxDeviceCount) + if err != nil { + return -1, "", err + } else if lim.Value == 0 { + return -1, "", nil + } + var limitf64 float64 + if lim.Value >= uint64(rateLimitMax) { + // overflow protection: 1 << 50 is practically unlimited + limitf64 = float64(rateLimitMax) + } else { + limitf64 = float64(lim.Value) + } + limitf64 *= weight + if limitf64 > float64(rateLimitMax) { + limit = rateLimitMax + } else { + limit = int64(limitf64) + } + return limit, fmtEventID(tenantID, origUri), nil +} diff --git a/backend/services/deviceauth/devauth/devauth_ratelimits_test.go b/backend/services/deviceauth/devauth/devauth_ratelimits_test.go new file mode 100644 index 00000000..8ea8ec53 --- /dev/null +++ b/backend/services/deviceauth/devauth/devauth_ratelimits_test.go @@ -0,0 +1,152 @@ +package devauth + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/mendersoftware/mender-server/pkg/context/httpheader" + "github.com/mendersoftware/mender-server/pkg/identity" + "github.com/mendersoftware/mender-server/pkg/plan" + "github.com/mendersoftware/mender-server/pkg/rate" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/mendersoftware/mender-server/services/deviceauth/cache" + "github.com/mendersoftware/mender-server/services/deviceauth/model" + "github.com/mendersoftware/mender-server/services/deviceauth/store" + mstore "github.com/mendersoftware/mender-server/services/deviceauth/store/mocks" +) + +type errLimiter struct { + rate.Limiter + err error +} + +func (l errLimiter) Reserve(ctx context.Context) (rate.Reservation, error) { + return nil, l.err +} + +func TestCheckRateLimits(t *testing.T) { + t.Parallel() + + t.Run("ok/token bucket", func(t *testing.T) { + rateLimiter := rate.NewLimiter(1, time.Hour) + d := new(DevAuth).WithRatelimits(rateLimiter, make(map[string]float64), 1.0) + ctx := context.Background() + + err := d.checkRateLimits(ctx) + if err != nil { + t.Errorf("unexpected error on first rate limiter event: %s", err.Error()) + } + err = d.checkRateLimits(ctx) + if !errors.Is(err, cache.ErrTooManyRequests) { + if err == nil { + t.Errorf("expected error %q, received none", cache.ErrTooManyRequests.Error()) + } else { + t.Errorf("unexpected error on second rate limiter event: %s", err.Error()) + } + } + }) + t.Run("error/unknown error propagated", func(t *testing.T) { + var errExpected = errors.New("test error") + d := new(DevAuth).WithRatelimits(errLimiter{err: errExpected}, map[string]float64{}, 1.0) + errActual := d.checkRateLimits(context.Background()) + if !errors.Is(errActual, errExpected) { + t.Errorf("unexpected error: %s", errActual.Error()) + } + }) +} + +func TestRateLimitParamsFromContext(t *testing.T) { + t.Parallel() + + type testCase struct { + CTX context.Context + Store func(t *testing.T) store.DataStore + Weights map[string]float64 + + ResultLimit int64 + ResultEventID string + ResultError error + } + contextArgMatcher := mock.MatchedBy(func(context.Context) bool { return true }) + + for name, _tc := range map[string]testCase{ + "ok/no tenant context": testCase{ + CTX: httpheader.WithContext(context.Background(), + http.Header{ + "X-Forwarded-Uri": []string{"/api/devices/v1/foo/bar"}, + }, "X-Forwarded-Uri"), + Store: func(t *testing.T) store.DataStore { + ds := mstore.NewDataStore(t) + ds.On("GetLimit", contextArgMatcher, model.LimitMaxDeviceCount). + Return(&model.Limit{Name: model.LimitMaxDeviceCount, Value: 69}, nil) + return ds + }, + + ResultEventID: fmtEventID("default", "foo/bar"), + ResultLimit: 69, + }, + "ok/with tenant context": testCase{ + CTX: identity.WithContext( + httpheader.WithContext( + context.Background(), + http.Header{ + "X-Forwarded-Uri": []string{"/api/devices/v1/foo/bar"}, + }, "X-Forwarded-Uri"), + &identity.Identity{ + Tenant: "1234", + }), + Store: func(t *testing.T) store.DataStore { + ds := mstore.NewDataStore(t) + ds.On("GetLimit", contextArgMatcher, model.LimitMaxDeviceCount). + Return(&model.Limit{Name: model.LimitMaxDeviceCount, Value: 123}, nil) + return ds + }, + + ResultEventID: fmtEventID("1234", "foo/bar"), + ResultLimit: 123, + }, + "ok/float and int overflow": testCase{ + CTX: identity.WithContext( + httpheader.WithContext( + context.Background(), + http.Header{ + "X-Forwarded-Uri": []string{"/api/devices/v1/foo/bar"}, + }, "X-Forwarded-Uri"), + &identity.Identity{ + Tenant: "1234", + Plan: plan.PlanEnterprise, + }), + Store: func(t *testing.T) store.DataStore { + ds := mstore.NewDataStore(t) + ds.On("GetLimit", contextArgMatcher, model.LimitMaxDeviceCount). + Return(&model.Limit{Name: model.LimitMaxDeviceCount, Value: (1 << 61)}, nil) + return ds + }, + Weights: map[string]float64{ + plan.PlanEnterprise: 10.0, + plan.PlanProfessional: 5.0, + plan.PlanOpenSource: 2.0, + }, + + ResultEventID: fmtEventID("1234", "foo/bar"), + ResultLimit: rateLimitMax, + }, + } { + tc := _tc + t.Run(name, func(t *testing.T) { + ds := tc.Store(t) + devauth := NewDevAuth(ds, nil, nil, Config{}). + WithRatelimits(rate.NewLimiter(1, time.Minute), tc.Weights, 1.0) + limit, eventID, err := devauth.RateLimitsFromContext(tc.CTX) + assert.Equal(t, tc.ResultLimit, limit) + assert.Equal(t, tc.ResultEventID, eventID) + assert.ErrorIs(t, err, tc.ResultError) + }) + } + +} diff --git a/backend/services/deviceauth/main.go b/backend/services/deviceauth/main.go index 8551e218..36a6e571 100644 --- a/backend/services/deviceauth/main.go +++ b/backend/services/deviceauth/main.go @@ -233,6 +233,7 @@ func doMain(args []string) { // Enable setting config values by environment variables config.Config.SetEnvPrefix("DEVICEAUTH") + config.Config.SetEnvKeyReplacer(strings.NewReplacer(".", "_", "-", "_")) config.Config.AutomaticEnv() return nil @@ -267,7 +268,7 @@ func cmdServer(args *cli.Context) error { 3) } - l.Print("Device Authentication Service starting up") + l.Printf("Device Authentication Service %s starting up", args.App.Version) err = RunServer(config.Config) if err != nil { diff --git a/backend/services/deviceauth/server.go b/backend/services/deviceauth/server.go index 3383d576..14fe2f8d 100644 --- a/backend/services/deviceauth/server.go +++ b/backend/services/deviceauth/server.go @@ -16,13 +16,18 @@ package main import ( "context" + "fmt" "net/http" + "reflect" + "strconv" + "strings" "time" "github.com/pkg/errors" "github.com/mendersoftware/mender-server/pkg/config" "github.com/mendersoftware/mender-server/pkg/log" + "github.com/mendersoftware/mender-server/pkg/redis" api_http "github.com/mendersoftware/mender-server/services/deviceauth/api/http" "github.com/mendersoftware/mender-server/services/deviceauth/cache" @@ -105,18 +110,24 @@ func RunServer(c config.Reader) error { if cacheConnStr != "" { l.Infof("setting up redis cache") - cache, err := cache.NewRedisCache( - context.TODO(), - cacheConnStr, - c.GetString(dconfig.SettingRedisKeyPrefix), - c.GetInt(dconfig.SettingRedisLimitsExpSec), - ) - + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + redisClient, err := redis.ClientFromConnectionString(ctx, cacheConnStr) + cancel() if err != nil { - return err + return fmt.Errorf("failed to initialize redis client: %w", err) } + redisKeyPrefix := c.GetString(dconfig.SettingRedisKeyPrefix) + cache := cache.NewRedisCache( + redisClient, + redisKeyPrefix, + c.GetInt(dconfig.SettingRedisLimitsExpSec), + ) devauth = devauth.WithCache(cache) + err = setupRatelimits(c, devauth, redisKeyPrefix, redisClient) + if err != nil { + return fmt.Errorf("error configuring rate limits: %w", err) + } } devauthapi := api_http.NewDevAuthApiHandlers(devauth, db) @@ -131,3 +142,89 @@ func RunServer(c config.Reader) error { return http.ListenAndServe(addr, apiHandler) } + +func setupRatelimits( + c config.Reader, + devauth *devauth.DevAuth, + redisKeyPrefix string, + redisClient redis.Client, +) error { + if !c.IsSet(dconfig.SettingRatelimitsQuotas) { + return nil + } + quotas := make(map[string]float64) + // quotas can be given as either "plan=quota plan2=quota2" + // or as a map of string -> float64 + // Only the former can be backed by environment variables + quotaSlice := c.GetStringSlice(dconfig.SettingRatelimitsQuotas) + if len(quotaSlice) > 0 { + for i, keyValue := range quotaSlice { + key, value, ok := strings.Cut(keyValue, "=") + if !ok { + return fmt.Errorf( + `invalid config %s: value %v item #%d: missing key/value separator '='`, + dconfig.SettingRatelimitsQuotas, quotaSlice, i+1, + ) + } + valueF64, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("error parsing quota value: %w", err) + } + quotas[key] = valueF64 + } + } else { + // Check for map in config file + quotaMap := c.GetStringMap(dconfig.SettingRatelimitsQuotas) + if len(quotaMap) == 0 { + return fmt.Errorf( + "invalid config value %s: cannot be empty", + dconfig.SettingRatelimitsQuotas) + } + for key, valueAny := range quotaMap { + rVal := reflect.ValueOf(valueAny) + if rVal.CanFloat() { + quotas[key] = rVal.Float() + } else if rVal.CanInt() { + quotas[key] = float64(rVal.Int()) + } else if rVal.CanUint() { + quotas[key] = float64(rVal.Uint()) + } else { + return fmt.Errorf( + "invalid config value %s[%s]: not a numeric value", + dconfig.SettingRatelimitsQuotas, key, + ) + } + } + } + for key := range quotas { + if quotas[key] < 0.0 { + return fmt.Errorf("invalid config value %s[%s]: value must be a positive value", + dconfig.SettingRatelimitsQuotas, key) + } + } + log.NewEmpty().Infof("using rate limit quotas: %v", quotas) + + interval := c.GetDuration(dconfig.SettingRatelimitsInterval) + rateLimiter := redis.NewFixedWindowRateLimiter(redisClient, + func(ctx context.Context) (*redis.RatelimitParams, error) { + limit, eventID, err := devauth.RateLimitsFromContext(ctx) + if err != nil { + return nil, err + } else if limit < 0 { + return nil, nil + } + keyPrefix := redisKeyPrefix + ":" + eventID + return &redis.RatelimitParams{ + Burst: uint64(limit), + Interval: interval, + KeyPrefix: keyPrefix, + }, nil + }, + ) + devauth.WithRatelimits( + rateLimiter, + quotas, + c.GetFloat64(dconfig.SettingRatelimitsQuotaDefault), + ) + return nil +} diff --git a/backend/services/deviceauth/server_test.go b/backend/services/deviceauth/server_test.go new file mode 100644 index 00000000..b9c6d62a --- /dev/null +++ b/backend/services/deviceauth/server_test.go @@ -0,0 +1,106 @@ +package main + +import ( + "fmt" + "testing" + + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + + dconfig "github.com/mendersoftware/mender-server/services/deviceauth/config" + "github.com/mendersoftware/mender-server/services/deviceauth/devauth" +) + +func TestSetupRateLimits(t *testing.T) { + t.Parallel() + + type testCase struct { + Config *viper.Viper + + Error error + } + + for name, _tc := range map[string]testCase{ + "ok/slice": { + Config: func() *viper.Viper { + cfg := viper.New() + cfg.Set(dconfig.SettingRatelimitsQuotas, "enterprise=1.5 professional=0.75 os=0.5") + return cfg + }(), + }, + "ok/map": { + Config: func() *viper.Viper { + cfg := viper.New() + cfg.Set(dconfig.SettingRatelimitsQuotas, map[string]any{ + "enterprise": float64(2.5), + "professional": int(2), + "os": uint64(1), + }) + return cfg + }(), + }, + "ok/no limits": { + Config: func() *viper.Viper { + cfg := viper.New() + return cfg + }(), + }, + "error/negative value": { + Config: func() *viper.Viper { + cfg := viper.New() + cfg.Set(dconfig.SettingRatelimitsQuotas, map[string]any{"bad": -1.0}) + return cfg + }(), + Error: fmt.Errorf("invalid config value %s[bad]: value must be a positive value", + dconfig.SettingRatelimitsQuotas), + }, + "error/slice without separator": { + Config: func() *viper.Viper { + cfg := viper.New() + cfg.Set(dconfig.SettingRatelimitsQuotas, "foo bar baz") + return cfg + }(), + Error: fmt.Errorf("invalid config %s: value %v item #1: missing key/value separator '='", + dconfig.SettingRatelimitsQuotas, []string{"foo", "bar", "baz"}), + }, + "error/not convertible to float": { + Config: func() *viper.Viper { + cfg := viper.New() + cfg.Set(dconfig.SettingRatelimitsQuotas, "enterprise=many") + return cfg + }(), + Error: fmt.Errorf("error parsing quota value"), + }, + "error/unexpected config type": { + Config: func() *viper.Viper { + cfg := viper.New() + cfg.Set(dconfig.SettingRatelimitsQuotas, "") + return cfg + }(), + Error: fmt.Errorf("invalid config value %s: cannot be empty", + dconfig.SettingRatelimitsQuotas), + }, + "error/unexpected map type": { + Config: func() *viper.Viper { + cfg := viper.New() + cfg.Set(dconfig.SettingRatelimitsQuotas, map[string]any{"foo": "123"}) + return cfg + }(), + Error: fmt.Errorf("invalid config value %s[foo]: "+ + "not a numeric value", + dconfig.SettingRatelimitsQuotas), + }, + } { + tc := _tc + t.Run(name, func(t *testing.T) { + da := &devauth.DevAuth{} + err := setupRatelimits(tc.Config, da, "n/a", nil) + if tc.Error != nil { + assert.ErrorContains(t, err, tc.Error.Error()) + } else { + assert.NoError(t, err) + } + }) + } + +}