Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MEN-7733: Rate limits for devices APIs #202

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ require (
golang.org/x/net v0.31.0
golang.org/x/sys v0.27.0
golang.org/x/term v0.26.0
golang.org/x/time v0.7.0
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637
gopkg.in/yaml.v3 v3.0.1
)
Expand Down Expand Up @@ -121,7 +122,6 @@ require (
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/sync v0.9.0 // indirect
golang.org/x/text v0.20.0 // indirect
golang.org/x/time v0.7.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
)
70 changes: 70 additions & 0 deletions backend/pkg/rate/limit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package rate

import (
"context"
"time"

"golang.org/x/time/rate"
)

// Limiter implements a rate limiting interface based on golang.org/x/time/rate
// but extending the interface with the ability to expose internal errors.
type Limiter interface {
Reserve(ctx context.Context) (Reservation, error)
Tokens(ctx context.Context) (uint64, error)
}

// Reservation is a point-in-time reservation of a ratelimit token. If the token
// is approved OK return true, otherways Delay will return the duration for next
// token(s) to become available. While Tokens return the number of available
// tokens after the reservation.
type Reservation interface {
OK() bool
Delay() time.Duration
Tokens() uint64
}

type limiter rate.Limiter

func NewLimiter(limit int, interval time.Duration) Limiter {
return (*limiter)(rate.NewLimiter(rate.Every(interval/time.Duration(limit)), limit))
}

func (lim *limiter) Reserve(context.Context) (Reservation, error) {
now := time.Now()
goLimiter := (*rate.Limiter)(lim)
res := &reservation{
reservation: goLimiter.ReserveN(now, 1),
time: now,
}
if res.OK() {
res.tokens = uint64(goLimiter.TokensAt(now))
}
return res, nil
}

func (lim *limiter) Tokens(context.Context) (uint64, error) {
goLimiter := (*rate.Limiter)(lim)
if tokens := goLimiter.Tokens(); tokens > 0 {
return uint64(tokens), nil
}
return 0, nil
}

type reservation struct {
time time.Time
tokens uint64
reservation *rate.Reservation
}

func (r *reservation) OK() bool {
return r.Delay() == 0
}

func (r *reservation) Delay() time.Duration {
return r.reservation.DelayFrom(r.time)
}

func (r *reservation) Tokens() uint64 {
return r.tokens
}
47 changes: 47 additions & 0 deletions backend/pkg/rate/limit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package rate

import (
"context"
"testing"
"time"
)

func TestLimiter(t *testing.T) {
t.Parallel()

ctx := context.Background()

limiter := NewLimiter(2, time.Hour)
resApproved, err := limiter.Reserve(ctx)
if err != nil {
t.Errorf("unexpected error reserving time slot #1: %s", err.Error())
t.FailNow()
}
_, err = limiter.Reserve(ctx)
if err != nil {
t.Errorf("unexpected error reserving time slot #2: %s", err.Error())
t.FailNow()
}
resDenied, err := limiter.Reserve(ctx)
if err != nil {
t.Errorf("unexpected error reserving time slot #3: %s", err.Error())
t.FailNow()
}
if !resApproved.OK() {
t.Error("expected first reservation to be available")
} else if resApproved.Delay() > 0 {
t.Error("an approved reservation should not have a delay")
} else if resApproved.Tokens() == 0 {
t.Error("expected more tokens to be available after first reservation")
}
if resDenied.OK() {
t.Error("reservation should not be available before 1h has passed")
} else {
if resDenied.Delay() == 0 {
t.Error("a denied reservation should have a non-zero delay")
}
if resDenied.Tokens() > 0 {
t.Errorf("a deneied reservation should not report free tokens, got: %d", resDenied.Tokens())
}
}
}
129 changes: 129 additions & 0 deletions backend/pkg/redis/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package redis

import (
"context"
"errors"
"fmt"
"time"

"github.com/redis/go-redis/v9"

"github.com/mendersoftware/mender-server/pkg/rate"
)

func NewFixedWindowRateLimiter(
client Client,
paramsFromContext RatelimitParamsFunc,
) rate.Limiter {
return &fixedWindowRatelimiter{
client: client,
paramsFunc: paramsFromContext,
nowFunc: time.Now,
}
}

type RatelimitParams struct {
Burst uint64
Interval time.Duration
KeyPrefix string
}

type RatelimitParamsFunc func(context.Context) (*RatelimitParams, error)

func FixedRatelimitParams(params RatelimitParams) RatelimitParamsFunc {
return func(ctx context.Context) (*RatelimitParams, error) {
return &params, nil
}
}

type fixedWindowRatelimiter struct {
client Client
paramsFunc RatelimitParamsFunc
nowFunc func() time.Time
}

type simpleReservation struct {
ok bool
tokens uint64
delay time.Duration
}

func (r *simpleReservation) OK() bool {
return r.ok
}

func (r *simpleReservation) Delay() time.Duration {
return r.delay
}

func (r *simpleReservation) Tokens() uint64 {
return r.tokens
}

func epoch(t time.Time, interval time.Duration) int64 {
return t.UnixMilli() / interval.Milliseconds()
}

func fixedWindowKey(prefix string, epoch int64) string {
if prefix == "" {
prefix = "ratelimit"
}
return fmt.Sprintf("%s:e:%d:c", prefix, epoch)
}

func (rl *fixedWindowRatelimiter) Reserve(ctx context.Context) (rate.Reservation, error) {
now := rl.nowFunc()
params, err := rl.paramsFunc(ctx)
if err != nil {
return nil, err
} else if params == nil {
return &simpleReservation{
ok: true,
}, nil
}
epoch := epoch(now, params.Interval)
kjaskiewiczz marked this conversation as resolved.
Show resolved Hide resolved
key := fixedWindowKey(params.KeyPrefix, epoch)
count := uint64(1)

err = rl.client.SetArgs(ctx, key, count, redis.SetArgs{
TTL: params.Interval,
Mode: `NX`,
}).Err()
if errors.Is(err, redis.Nil) {
count, err = rl.client.Incr(ctx, key).Uint64()
}
if err != nil {
return nil, fmt.Errorf("redis: error computing rate limit: %w", err)
}
if count <= params.Burst {
return &simpleReservation{
delay: 0,
ok: true,
tokens: params.Burst - count,
}, nil
}
return &simpleReservation{
delay: now.Sub(time.UnixMilli((epoch + 1) *
params.Interval.Milliseconds())),
ok: false,
tokens: 0,
}, nil
}

func (rl *fixedWindowRatelimiter) Tokens(ctx context.Context) (uint64, error) {
params, err := rl.paramsFunc(ctx)
if err != nil {
return 0, err
}
count, err := rl.client.Get(ctx,
fixedWindowKey(params.KeyPrefix, epoch(rl.nowFunc(), params.Interval)),
).Uint64()
if errors.Is(err, redis.Nil) {
return params.Burst, nil
} else if err != nil {
return 0, fmt.Errorf("redis: error getting free tokens: %w", err)
} else if count > params.Burst {
return 0, nil
}
return params.Burst - count, nil
}
66 changes: 66 additions & 0 deletions backend/pkg/redis/ratelimit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package redis

import (
"context"
"fmt"
"strings"
"testing"
"time"

"github.com/mendersoftware/mender-server/pkg/rate"
)

func TestFixedWindowRatelimit(t *testing.T) {
requireRedis(t)
t.Parallel()

ctx := context.Background()

client, err := ClientFromConnectionString(ctx, RedisURL)
if err != nil {
t.Errorf("could not connect to redis (%s): is redis running?",
RedisURL)
t.FailNow()
}
tMicro := time.Now().UnixMicro()
params := FixedRatelimitParams(RatelimitParams{
Burst: 1,
Interval: time.Hour,
KeyPrefix: fmt.Sprintf("%s_%x", strings.ToLower(t.Name()), tMicro),
})
rateLimiter := NewFixedWindowRateLimiter(client, params)

// Freeze time to avoid time to progress to next window.
nowFrozen := time.Now()
rateLimiter.(*fixedWindowRatelimiter).nowFunc = func() time.Time { return nowFrozen }

if tokens, _ := rateLimiter.Tokens(ctx); tokens != 1 {
t.Errorf("expected token available after initialization, actual: %d", tokens)
}

var reservations [2]rate.Reservation
for i := 0; i < len(reservations); i++ {
reservations[i], err = rateLimiter.Reserve(ctx)
if err != nil {
t.Errorf("unexpected error reserving rate limit: %s", err.Error())
t.FailNow()
}
}
if !reservations[0].OK() {
t.Errorf("expected first event to pass, but didn't")
}
if reservations[1].OK() {
t.Errorf("expected the second event to block, but didn't")
}
if remaining, err := rateLimiter.Tokens(ctx); err != nil {
t.Errorf("unexpected error retrieving remaining tokens: %s", err.Error())
} else if remaining != 0 {
t.Errorf("expected 0 tokens remaining, actual: %d", remaining)
}

if reservations[0].Tokens() != 0 {
t.Errorf("there should be no tokens left after first event")
} else if reservations[1].Tokens() != 0 {
t.Errorf("there should be no tokens left after second event")
}
}
2 changes: 2 additions & 0 deletions backend/pkg/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"github.com/redis/go-redis/v9"
)

type Client = redis.Cmdable

// nolint:lll
// NewClient creates a new redis client (Cmdable) from the parameters in the
// connectionString URL format:
Expand Down
18 changes: 18 additions & 0 deletions backend/pkg/redis/redis_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package redis

import (
"os"
"testing"
)

const EnvRedisURL = "TEST_REDIS_URL"

var RedisURL = os.Getenv(EnvRedisURL)

func requireRedis(t *testing.T) {
if RedisURL == "" {
t.Skipf("skipping test %q due to missing redis URL, "+
"use environment variable %q to run test",
t.Name(), EnvRedisURL)
}
}
Loading