diff --git a/auth/api/http/keys/endpoint.go b/auth/api/http/keys/endpoint.go index 4c3d1b7ecc..6aa1788b0f 100644 --- a/auth/api/http/keys/endpoint.go +++ b/auth/api/http/keys/endpoint.go @@ -85,3 +85,18 @@ func revokeEndpoint(svc auth.Service) endpoint.Endpoint { return revokeKeyRes{}, nil } } + +func revokeTokenEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(revokeTokenReq) + if err := req.validate(); err != nil { + return nil, err + } + + if err := svc.RevokeToken(ctx, req.token); err != nil { + return nil, err + } + + return revokeKeyRes{}, nil + } +} diff --git a/auth/api/http/keys/endpoint_test.go b/auth/api/http/keys/endpoint_test.go index 4ed62a340d..05cb5b171b 100644 --- a/auth/api/http/keys/endpoint_test.go +++ b/auth/api/http/keys/endpoint_test.go @@ -4,7 +4,6 @@ package keys_test import ( - "context" "encoding/json" "fmt" "io" @@ -16,8 +15,8 @@ import ( "github.com/absmach/magistrala/auth" httpapi "github.com/absmach/magistrala/auth/api/http" - "github.com/absmach/magistrala/auth/jwt" "github.com/absmach/magistrala/auth/mocks" + "github.com/absmach/magistrala/internal/testsutil" mglog "github.com/absmach/magistrala/logger" "github.com/absmach/magistrala/pkg/apiutil" svcerr "github.com/absmach/magistrala/pkg/errors/service" @@ -93,9 +92,7 @@ func toJSON(data interface{}) string { } func TestIssue(t *testing.T) { - svc, krepo := newService() - token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id}) - assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) + svc := new(mocks.Service) ts := newServer(svc) defer ts.Close() @@ -110,11 +107,14 @@ func TestIssue(t *testing.T) { req string ct string token string + resp auth.Token + err error status int }{ { desc: "issue login key with empty token", req: toJSON(lk), + resp: auth.Token{AccessToken: "token"}, ct: contentType, token: "", status: http.StatusUnauthorized, @@ -122,29 +122,30 @@ func TestIssue(t *testing.T) { { desc: "issue API key", req: toJSON(ak), + resp: auth.Token{AccessToken: "token"}, ct: contentType, - token: token.AccessToken, + token: "token", status: http.StatusCreated, }, { desc: "issue recovery key", req: toJSON(rk), ct: contentType, - token: token.AccessToken, + token: "token", status: http.StatusCreated, }, { desc: "issue login key wrong content type", req: toJSON(lk), ct: "", - token: token.AccessToken, + token: "token", status: http.StatusUnsupportedMediaType, }, { desc: "issue recovery key wrong content type", req: toJSON(rk), ct: "", - token: token.AccessToken, + token: "token", status: http.StatusUnsupportedMediaType, }, { @@ -152,6 +153,7 @@ func TestIssue(t *testing.T) { req: toJSON(ak), ct: contentType, token: "wrong", + err: svcerr.ErrAuthentication, status: http.StatusUnauthorized, }, { @@ -159,27 +161,28 @@ func TestIssue(t *testing.T) { req: toJSON(rk), ct: contentType, token: "", + err: svcerr.ErrAuthentication, status: http.StatusUnauthorized, }, { desc: "issue key with invalid request", req: "{", ct: contentType, - token: token.AccessToken, + token: "token", status: http.StatusBadRequest, }, { desc: "issue key with invalid JSON", req: "{invalid}", ct: contentType, - token: token.AccessToken, + token: "token", status: http.StatusBadRequest, }, { desc: "issue key with invalid JSON content", req: `{"Type":{"key":"AccessToken"}}`, ct: contentType, - token: token.AccessToken, + token: "token", status: http.StatusBadRequest, }, } @@ -193,24 +196,16 @@ func TestIssue(t *testing.T) { token: tc.token, body: strings.NewReader(tc.req), } - repocall := krepo.On("Save", mock.Anything, mock.Anything).Return("", nil) + svcCall := svc.On("Issue", mock.Anything, tc.token, mock.Anything).Return(tc.resp, tc.err) res, err := req.make() assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) - repocall.Unset() + svcCall.Unset() } } func TestRetrieve(t *testing.T) { - svc, krepo := newService() - token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id}) - assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) - key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), Subject: id} - - repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil) - k, err := svc.Issue(context.Background(), token.AccessToken, key) - assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) - repocall.Unset() + svc := new(mocks.Service) ts := newServer(svc) defer ts.Close() @@ -226,8 +221,8 @@ func TestRetrieve(t *testing.T) { }{ { desc: "retrieve an existing key", - id: k.AccessToken, - token: token.AccessToken, + id: testsutil.GenerateUUID(t), + token: "token", key: auth.Key{ Subject: id, Type: auth.AccessKey, @@ -240,13 +235,13 @@ func TestRetrieve(t *testing.T) { { desc: "retrieve a non-existing key", id: "non-existing", - token: token.AccessToken, - status: http.StatusBadRequest, + token: "token", + status: http.StatusNotFound, err: svcerr.ErrNotFound, }, { desc: "retrieve a key with an invalid token", - id: k.AccessToken, + id: testsutil.GenerateUUID(t), token: "wrong", status: http.StatusUnauthorized, err: svcerr.ErrAuthentication, @@ -254,7 +249,7 @@ func TestRetrieve(t *testing.T) { { desc: "retrieve a key with an empty token", token: "", - id: k.AccessToken, + id: testsutil.GenerateUUID(t), status: http.StatusUnauthorized, err: svcerr.ErrAuthentication, }, @@ -267,24 +262,16 @@ func TestRetrieve(t *testing.T) { url: fmt.Sprintf("%s/keys/%s", ts.URL, tc.id), token: tc.token, } - repocall := krepo.On("Retrieve", mock.Anything, mock.Anything, mock.Anything).Return(tc.key, tc.err) + svcCall := svc.On("RetrieveKey", mock.Anything, tc.token, tc.id).Return(tc.key, tc.err) res, err := req.make() assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) - repocall.Unset() + svcCall.Unset() } } func TestRevoke(t *testing.T) { - svc, krepo := newService() - token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id}) - assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) - key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), Subject: id} - - repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil) - k, err := svc.Issue(context.Background(), token.AccessToken, key) - assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) - repocall.Unset() + svc := new(mocks.Service) ts := newServer(svc) defer ts.Close() @@ -294,29 +281,31 @@ func TestRevoke(t *testing.T) { desc string id string token string + err error status int }{ { desc: "revoke an existing key", - id: k.AccessToken, - token: token.AccessToken, + id: testsutil.GenerateUUID(t), + token: "token", status: http.StatusNoContent, }, { desc: "revoke a non-existing key", id: "non-existing", - token: token.AccessToken, + token: "token", status: http.StatusNoContent, }, { desc: "revoke key with invalid token", - id: k.AccessToken, + id: testsutil.GenerateUUID(t), token: "wrong", + err: svcerr.ErrAuthentication, status: http.StatusUnauthorized, }, { desc: "revoke key with empty token", - id: k.AccessToken, + id: testsutil.GenerateUUID(t), token: "", status: http.StatusUnauthorized, }, @@ -329,10 +318,63 @@ func TestRevoke(t *testing.T) { url: fmt.Sprintf("%s/keys/%s", ts.URL, tc.id), token: tc.token, } - repocall := krepo.On("Remove", mock.Anything, mock.Anything, mock.Anything).Return(nil) + svcCall := svc.On("Revoke", mock.Anything, tc.token, tc.id).Return(tc.err) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + } +} + +func TestRevokeToken(t *testing.T) { + svc := new(mocks.Service) + + ts := newServer(svc) + defer ts.Close() + client := ts.Client() + + cases := []struct { + desc string + id string + token string + err error + status int + }{ + { + desc: "revoke an existing token", + token: "token", + status: http.StatusNoContent, + }, + { + desc: "revoke a non-existing token", + token: "token", + err: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + }, + { + desc: "revoke invalid token", + token: "wrong", + err: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + }, + { + desc: "revoke empty token", + token: "", + status: http.StatusUnauthorized, + }, + } + + for _, tc := range cases { + req := testRequest{ + client: client, + method: http.MethodDelete, + url: fmt.Sprintf("%s/keys/", ts.URL), + token: tc.token, + } + svcCall := svc.On("RevokeToken", mock.Anything, tc.token).Return(tc.err) res, err := req.make() assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) - repocall.Unset() + svcCall.Unset() } } diff --git a/auth/api/http/keys/requests.go b/auth/api/http/keys/requests.go index 53542c60e6..2500769b26 100644 --- a/auth/api/http/keys/requests.go +++ b/auth/api/http/keys/requests.go @@ -46,3 +46,15 @@ func (req keyReq) validate() error { } return nil } + +type revokeTokenReq struct { + token string +} + +func (req revokeTokenReq) validate() error { + if req.token == "" { + return apiutil.ErrBearerToken + } + + return nil +} diff --git a/auth/api/http/keys/requests_test.go b/auth/api/http/keys/requests_test.go index 6172f24347..2bc5ece50c 100644 --- a/auth/api/http/keys/requests_test.go +++ b/auth/api/http/keys/requests_test.go @@ -86,3 +86,30 @@ func TestKeyReqValidate(t *testing.T) { assert.Equal(t, tc.err, err) } } + +func TestRevokeTokenReqValidate(t *testing.T) { + cases := []struct { + desc string + req revokeTokenReq + err error + }{ + { + desc: "valid request", + req: revokeTokenReq{ + token: valid, + }, + err: nil, + }, + { + desc: "empty token", + req: revokeTokenReq{ + token: "", + }, + err: apiutil.ErrBearerToken, + }, + } + for _, tc := range cases { + err := tc.req.validate() + assert.Equal(t, tc.err, err) + } +} diff --git a/auth/api/http/keys/transport.go b/auth/api/http/keys/transport.go index 9554df3ba1..d09da3ea87 100644 --- a/auth/api/http/keys/transport.go +++ b/auth/api/http/keys/transport.go @@ -33,6 +33,13 @@ func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux { opts..., ).ServeHTTP) + r.Delete("/", kithttp.NewServer( + revokeTokenEndpoint(svc), + decodeRevokeTokenReq, + api.EncodeResponse, + opts..., + ).ServeHTTP) + r.Get("/{id}", kithttp.NewServer( (retrieveEndpoint(svc)), decodeKeyReq, @@ -70,3 +77,11 @@ func decodeKeyReq(_ context.Context, r *http.Request) (interface{}, error) { } return req, nil } + +func decodeRevokeTokenReq(_ context.Context, r *http.Request) (interface{}, error) { + req := revokeTokenReq{ + token: apiutil.ExtractBearerToken(r), + } + + return req, nil +} diff --git a/auth/api/logging.go b/auth/api/logging.go index 30182bb4c4..fe4049b63c 100644 --- a/auth/api/logging.go +++ b/auth/api/logging.go @@ -63,6 +63,22 @@ func (lm *loggingMiddleware) Revoke(ctx context.Context, token, id string) (err return lm.svc.Revoke(ctx, token, id) } +func (lm *loggingMiddleware) RevokeToken(ctx context.Context, token string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Revoke token failed to complete successfully", args...) + return + } + lm.logger.Info("Revoke token completed successfully", args...) + }(time.Now()) + + return lm.svc.RevokeToken(ctx, token) +} + func (lm *loggingMiddleware) RetrieveKey(ctx context.Context, token, id string) (key auth.Key, err error) { defer func(begin time.Time) { args := []any{ diff --git a/auth/api/metrics.go b/auth/api/metrics.go index 1e2befa8d2..b58ce9c086 100644 --- a/auth/api/metrics.go +++ b/auth/api/metrics.go @@ -49,6 +49,15 @@ func (ms *metricsMiddleware) Revoke(ctx context.Context, token, id string) error return ms.svc.Revoke(ctx, token, id) } +func (ms *metricsMiddleware) RevokeToken(ctx context.Context, token string) error { + defer func(begin time.Time) { + ms.counter.With("method", "revoke_token").Add(1) + ms.latency.With("method", "revoke_token").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.RevokeToken(ctx, token) +} + func (ms *metricsMiddleware) RetrieveKey(ctx context.Context, token, id string) (auth.Key, error) { defer func(begin time.Time) { ms.counter.With("method", "retrieve_key").Add(1) diff --git a/auth/cache/policies_test.go b/auth/cache/policies_test.go index 54a65957a4..b82020cbca 100644 --- a/auth/cache/policies_test.go +++ b/auth/cache/policies_test.go @@ -27,7 +27,7 @@ var policy = auth.PolicyReq{ Permission: auth.ViewPermission, } -func setupRedisClient(t *testing.T) auth.Cache { +func setupRedisCacheClient(t *testing.T) auth.Cache { opts, err := redis.ParseURL(redisURL) assert.Nil(t, err, fmt.Sprintf("got unexpected error on parsing redis URL: %s", err)) redisClient := redis.NewClient(opts) @@ -35,7 +35,7 @@ func setupRedisClient(t *testing.T) auth.Cache { } func TestSave(t *testing.T) { - authCache := setupRedisClient(t) + authCache := setupRedisCacheClient(t) cases := []struct { desc string @@ -153,7 +153,7 @@ func TestSave(t *testing.T) { } func TestContains(t *testing.T) { - authCache := setupRedisClient(t) + authCache := setupRedisCacheClient(t) key, val := policy.KV() err := authCache.Save(context.Background(), key, val) @@ -237,7 +237,7 @@ func TestContains(t *testing.T) { } func TestRemove(t *testing.T) { - authCache := setupRedisClient(t) + authCache := setupRedisCacheClient(t) subject := policy.Subject object := policy.Object diff --git a/auth/cache/tokens.go b/auth/cache/tokens.go new file mode 100644 index 0000000000..fa54e1184e --- /dev/null +++ b/auth/cache/tokens.go @@ -0,0 +1,56 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache + +import ( + "context" + "time" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" + "github.com/redis/go-redis/v9" +) + +const defKey = "revoked_tokens" + +var _ auth.Cache = (*tokensCache)(nil) + +type tokensCache struct { + client *redis.Client + keyDuration time.Duration +} + +// NewTokensCache returns redis auth cache implementation. +func NewTokensCache(client *redis.Client, duration time.Duration) auth.Cache { + return &tokensCache{ + client: client, + keyDuration: duration, + } +} + +func (tc *tokensCache) Save(ctx context.Context, _, value string) error { + if err := tc.client.SAdd(ctx, defKey, value, tc.keyDuration).Err(); err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + + return nil +} + +func (tc *tokensCache) Contains(ctx context.Context, _, value string) bool { + ok, err := tc.client.SIsMember(ctx, defKey, value).Result() + if err != nil { + return false + } + + return ok +} + +func (tc *tokensCache) Remove(ctx context.Context, value string) error { + if err := tc.client.SRem(ctx, defKey, value).Err(); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + return nil +} diff --git a/auth/cache/tokens_test.go b/auth/cache/tokens_test.go new file mode 100644 index 0000000000..8f9902073f --- /dev/null +++ b/auth/cache/tokens_test.go @@ -0,0 +1,184 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/auth/cache" + "github.com/absmach/magistrala/internal/testsutil" + "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" +) + +var key = auth.Key{ + ID: testsutil.GenerateUUID(&testing.T{}), +} + +func setupRedisTokensClient(t *testing.T) auth.Cache { + opts, err := redis.ParseURL(redisURL) + assert.Nil(t, err, fmt.Sprintf("got unexpected error on parsing redis URL: %s", err)) + redisClient := redis.NewClient(opts) + return cache.NewPoliciesCache(redisClient, 10*time.Minute) +} + +func TestTokenSave(t *testing.T) { + tokensCache := setupRedisTokensClient(t) + + cases := []struct { + desc string + key auth.Key + err error + }{ + { + desc: "Save token", + key: key, + err: nil, + }, + { + desc: "Save already cached policy", + key: key, + err: nil, + }, + { + desc: "Save another policy", + key: auth.Key{ + ID: testsutil.GenerateUUID(&testing.T{}), + }, + err: nil, + }, + { + desc: "Save policy with long key", + key: auth.Key{ + ID: strings.Repeat("a", 513*1024*1024), + }, + err: repoerr.ErrCreateEntity, + }, + { + desc: "Save policy with empty key", + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tokensCache.Save(context.Background(), "", tc.key.ID) + if err == nil { + ok := tokensCache.Contains(context.Background(), "", tc.key.ID) + assert.True(t, ok) + } + assert.True(t, errors.Contains(err, tc.err)) + }) + } +} + +func TestTokenContains(t *testing.T) { + tokensCache := setupRedisTokensClient(t) + + err := tokensCache.Save(context.Background(), "", key.ID) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + + cases := []struct { + desc string + key auth.Key + ok bool + }{ + { + desc: "Contains existing key", + key: key, + ok: true, + }, + { + desc: "Contains non existing key", + key: auth.Key{ + ID: testsutil.GenerateUUID(&testing.T{}), + }, + }, + { + desc: "Contains key with long id", + key: auth.Key{ + ID: strings.Repeat("a", 513*1024*1024), + }, + }, + { + desc: "Contains key with empty id", + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + ok := tokensCache.Contains(context.Background(), "", tc.key.ID) + assert.Equal(t, tc.ok, ok) + }) + } +} + +func TestTokenRemove(t *testing.T) { + tokensCache := setupRedisTokensClient(t) + + num := 1000 + var ids []string + for i := 0; i < num; i++ { + id := testsutil.GenerateUUID(&testing.T{}) + err := tokensCache.Save(context.Background(), "", id) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + ids = append(ids, id) + } + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "Remove an existing id from cache", + id: ids[0], + err: nil, + }, + { + desc: "Remove multiple existing id from cache", + id: "*", + err: nil, + }, + { + desc: "Remove non existing id from cache", + id: testsutil.GenerateUUID(&testing.T{}), + err: nil, + }, + { + desc: "Remove policy with empty id from cache", + err: nil, + }, + { + desc: "Remove policy with long id from cache", + id: strings.Repeat("a", 513*1024*1024), + err: repoerr.ErrRemoveEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tokensCache.Remove(context.Background(), tc.id) + assert.True(t, errors.Contains(err, tc.err)) + if tc.id == "*" { + for _, id := range ids { + ok := tokensCache.Contains(context.Background(), "", id) + assert.False(t, ok) + } + return + } + if err == nil { + ok := tokensCache.Contains(context.Background(), "", tc.id) + assert.False(t, ok) + } + }) + } +} diff --git a/auth/events/streams.go b/auth/events/streams.go index 702242cfb0..87b6316fc3 100644 --- a/auth/events/streams.go +++ b/auth/events/streams.go @@ -204,6 +204,10 @@ func (es *eventStore) Revoke(ctx context.Context, token, id string) error { return es.svc.Revoke(ctx, token, id) } +func (es *eventStore) RevokeToken(ctx context.Context, token string) error { + return es.svc.RevokeToken(ctx, token) +} + func (es *eventStore) RetrieveKey(ctx context.Context, token, id string) (auth.Key, error) { return es.svc.RetrieveKey(ctx, token, id) } diff --git a/auth/jwt/token_test.go b/auth/jwt/token_test.go index 461adb95be..dafc990489 100644 --- a/auth/jwt/token_test.go +++ b/auth/jwt/token_test.go @@ -4,14 +4,17 @@ package jwt_test import ( + "context" "fmt" "testing" "time" "github.com/absmach/magistrala/auth" authjwt "github.com/absmach/magistrala/auth/jwt" + "github.com/absmach/magistrala/auth/mocks" "github.com/absmach/magistrala/internal/testsutil" "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" svcerr "github.com/absmach/magistrala/pkg/errors/service" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" @@ -55,7 +58,9 @@ func newToken(issuerName string, key auth.Key) string { } func TestIssue(t *testing.T) { - tokenizer := authjwt.New([]byte(secret)) + repo := new(mocks.TokenRepository) + cache := new(mocks.Cache) + tokenizer := authjwt.New([]byte(secret), repo, cache) cases := []struct { desc string @@ -128,7 +133,9 @@ func TestIssue(t *testing.T) { } func TestParse(t *testing.T) { - tokenizer := authjwt.New([]byte(secret)) + repo := new(mocks.TokenRepository) + cache := new(mocks.Cache) + tokenizer := authjwt.New([]byte(secret), repo, cache) token, err := tokenizer.Issue(key()) require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err)) @@ -162,11 +169,19 @@ func TestParse(t *testing.T) { inValidToken := newToken("invalid", key()) + refreshKey := key() + refreshKey.Type = auth.RefreshKey + refreshToken, err := tokenizer.Issue(refreshKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + cases := []struct { - desc string - key auth.Key - token string - err error + desc string + key auth.Key + token string + cacheContains bool + repoContains bool + cacheSave error + err error }{ { desc: "parse valid key", @@ -222,14 +237,191 @@ func TestParse(t *testing.T) { token: emptyToken, err: nil, }, + { + desc: "parse refresh token", + key: refreshKey, + token: refreshToken, + cacheContains: false, + repoContains: false, + err: nil, + }, + { + desc: "parse revoked refresh token in cache", + key: refreshKey, + token: refreshToken, + cacheContains: true, + repoContains: false, + err: svcerr.ErrAuthentication, + }, + { + desc: "parse revoked refresh token not in cache", + key: refreshKey, + token: refreshToken, + cacheContains: false, + repoContains: true, + err: svcerr.ErrAuthentication, + }, + { + desc: "parse revoked refresh token failed to save in cache", + key: refreshKey, + token: refreshToken, + cacheContains: false, + repoContains: true, + cacheSave: repoerr.ErrCreateEntity, + err: svcerr.ErrAuthentication, + }, } for _, tc := range cases { - key, err := tokenizer.Parse(tc.token) + cacheCall := cache.On("Contains", context.Background(), "", tc.key.ID).Return(tc.cacheContains) + repoCall := repo.On("Contains", context.Background(), tc.key.ID).Return(tc.repoContains) + cacheCall1 := cache.On("Save", context.Background(), "", tc.key.ID).Return(tc.cacheSave) + key, err := tokenizer.Parse(context.Background(), tc.token) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err)) if err == nil { assert.Equal(t, tc.key, key, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.key, key)) } + cacheCall.Unset() + repoCall.Unset() + cacheCall1.Unset() + } +} + +func TestRevoke(t *testing.T) { + repo := new(mocks.TokenRepository) + cache := new(mocks.Cache) + tokenizer := authjwt.New([]byte(secret), repo, cache) + + token, err := tokenizer.Issue(key()) + require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err)) + + apiKey := key() + apiKey.Type = auth.APIKey + apiKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second) + apiToken, err := tokenizer.Issue(apiKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + expKey := key() + expKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second) + expToken, err := tokenizer.Issue(expKey) + require.Nil(t, err, fmt.Sprintf("issuing expired key expected to succeed: %s", err)) + + emptyDomainKey := key() + emptyDomainKey.Domain = "" + emptyDomainToken, err := tokenizer.Issue(emptyDomainKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + emptySubjectKey := key() + emptySubjectKey.Subject = "" + emptySubjectToken, err := tokenizer.Issue(emptySubjectKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + emptyKey := key() + emptyKey.Domain = "" + emptyKey.Subject = "" + emptyToken, err := tokenizer.Issue(emptyKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + inValidToken := newToken("invalid", key()) + + refreshKey := key() + refreshKey.Type = auth.RefreshKey + refreshToken, err := tokenizer.Issue(refreshKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + cases := []struct { + desc string + key auth.Key + token string + repoErr error + cacheErr error + err error + }{ + { + desc: "revoke valid key", + key: key(), + token: token, + err: nil, + }, + { + desc: "revoke invalid key", + key: auth.Key{}, + token: "invalid", + err: svcerr.ErrAuthentication, + }, + { + desc: "revoke expired key", + key: auth.Key{}, + token: expToken, + err: auth.ErrExpiry, + }, + { + desc: "revoke expired API key", + key: apiKey, + token: apiToken, + err: auth.ErrExpiry, + }, + { + desc: "revoke token with invalid issuer", + key: auth.Key{}, + token: inValidToken, + err: errInvalidIssuer, + }, + { + desc: "revoke token with invalid content", + key: auth.Key{}, + token: newToken(issuerName, key()), + err: authjwt.ErrJSONHandle, + }, + { + desc: "revoke token with empty domain", + key: emptyDomainKey, + token: emptyDomainToken, + err: nil, + }, + { + desc: "revoke token with empty subject", + key: emptySubjectKey, + token: emptySubjectToken, + err: nil, + }, + { + desc: "revoke token with empty domain and subject", + key: emptyKey, + token: emptyToken, + err: nil, + }, + { + desc: "revoke refresh token", + key: refreshKey, + token: refreshToken, + err: nil, + }, + { + desc: "revoke revoked refresh token failed to save in cache", + key: refreshKey, + token: refreshToken, + repoErr: nil, + cacheErr: repoerr.ErrCreateEntity, + err: svcerr.ErrAuthentication, + }, + { + desc: "revoke revoked refresh token failed to save in cache", + key: refreshKey, + token: refreshToken, + repoErr: repoerr.ErrCreateEntity, + cacheErr: nil, + err: svcerr.ErrAuthentication, + }, + } + + for _, tc := range cases { + repoCall := repo.On("Save", context.Background(), tc.key.ID).Return(tc.repoErr) + cacheCall := cache.On("Save", context.Background(), "", tc.key.ID).Return(tc.cacheErr) + err := tokenizer.Revoke(context.Background(), tc.token) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err)) + cacheCall.Unset() + repoCall.Unset() } } diff --git a/auth/jwt/tokenizer.go b/auth/jwt/tokenizer.go index ad79549016..0c48ac4577 100644 --- a/auth/jwt/tokenizer.go +++ b/auth/jwt/tokenizer.go @@ -26,6 +26,8 @@ var ( ErrValidateJWTToken = errors.New("failed to validate jwt token") // ErrJSONHandle indicates an error in handling JSON. ErrJSONHandle = errors.New("failed to perform operation JSON") + // errRevokedToken indicates that the token is revoked. + errRevokedToken = errors.New("token is revoked") ) const ( @@ -40,14 +42,18 @@ const ( type tokenizer struct { secret []byte + cache auth.Cache + repo auth.TokenRepository } var _ auth.Tokenizer = (*tokenizer)(nil) // NewRepository instantiates an implementation of Token repository. -func New(secret []byte) auth.Tokenizer { +func New(secret []byte, repo auth.TokenRepository, cache auth.Cache) auth.Tokenizer { return &tokenizer{ secret: secret, + repo: repo, + cache: cache, } } @@ -79,7 +85,7 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) { return string(signedTkn), nil } -func (tok *tokenizer) Parse(token string) (auth.Key, error) { +func (tok *tokenizer) Parse(ctx context.Context, token string) (auth.Key, error) { tkn, err := tok.validateToken(token) if err != nil { return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) @@ -90,9 +96,48 @@ func (tok *tokenizer) Parse(token string) (auth.Key, error) { return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) } + if key.Type == auth.RefreshKey { + switch tok.cache.Contains(ctx, "", key.ID) { + case true: + return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, errRevokedToken) + default: + if ok := tok.repo.Contains(ctx, key.ID); ok { + if err := tok.cache.Save(ctx, "", key.ID); err != nil { + return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) + } + + return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, errRevokedToken) + } + } + } + return key, nil } +func (tok *tokenizer) Revoke(ctx context.Context, token string) error { + tkn, err := tok.validateToken(token) + if err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + + key, err := toKey(tkn) + if err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + + if key.Type == auth.RefreshKey { + if err := tok.repo.Save(ctx, key.ID); err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + + if err := tok.cache.Save(ctx, "", key.ID); err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + } + + return nil +} + func (tok *tokenizer) validateToken(token string) (jwt.Token, error) { tkn, err := jwt.Parse( []byte(token), diff --git a/auth/mocks/cache.go b/auth/mocks/cache.go new file mode 100644 index 0000000000..f68b885bdb --- /dev/null +++ b/auth/mocks/cache.go @@ -0,0 +1,84 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// Cache is an autogenerated mock type for the Cache type +type Cache struct { + mock.Mock +} + +// Contains provides a mock function with given fields: ctx, key, value +func (_m *Cache) Contains(ctx context.Context, key string, value string) bool { + ret := _m.Called(ctx, key, value) + + if len(ret) == 0 { + panic("no return value specified for Contains") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(context.Context, string, string) bool); ok { + r0 = rf(ctx, key, value) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Remove provides a mock function with given fields: ctx, key +func (_m *Cache) Remove(ctx context.Context, key string) error { + ret := _m.Called(ctx, key) + + if len(ret) == 0 { + panic("no return value specified for Remove") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, key) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Save provides a mock function with given fields: ctx, key, value +func (_m *Cache) Save(ctx context.Context, key string, value string) error { + ret := _m.Called(ctx, key, value) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, key, value) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewCache creates a new instance of Cache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewCache(t interface { + mock.TestingT + Cleanup(func()) +}) *Cache { + mock := &Cache{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/auth/mocks/service.go b/auth/mocks/service.go index 80ec2714fb..98da88e708 100644 --- a/auth/mocks/service.go +++ b/auth/mocks/service.go @@ -345,6 +345,24 @@ func (_m *Service) Revoke(ctx context.Context, token string, id string) error { return r0 } +// RevokeToken provides a mock function with given fields: ctx, token +func (_m *Service) RevokeToken(ctx context.Context, token string) error { + ret := _m.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for RevokeToken") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, token) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // UnassignUser provides a mock function with given fields: ctx, token, id, userID func (_m *Service) UnassignUser(ctx context.Context, token string, id string, userID string) error { ret := _m.Called(ctx, token, id, userID) diff --git a/auth/mocks/token.go b/auth/mocks/token.go new file mode 100644 index 0000000000..0f4dfb8bf0 --- /dev/null +++ b/auth/mocks/token.go @@ -0,0 +1,66 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// TokenRepository is an autogenerated mock type for the TokenRepository type +type TokenRepository struct { + mock.Mock +} + +// Contains provides a mock function with given fields: ctx, id +func (_m *TokenRepository) Contains(ctx context.Context, id string) bool { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Contains") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(context.Context, string) bool); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Save provides a mock function with given fields: ctx, id +func (_m *TokenRepository) Save(ctx context.Context, id string) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewTokenRepository creates a new instance of TokenRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTokenRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *TokenRepository { + mock := &TokenRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/auth/postgres/init.go b/auth/postgres/init.go index ae69c3a0ca..eca0ad9132 100644 --- a/auth/postgres/init.go +++ b/auth/postgres/init.go @@ -57,6 +57,17 @@ func Migration() *migrate.MemoryMigrationSource { `ALTER TABLE domains ALTER COLUMN alias SET NOT NULL`, }, }, + { + Id: "auth_3", + Up: []string{ + `CREATE TABLE IF NOT EXISTS tokens ( + id VARCHAR(36) PRIMARY KEY + );`, + }, + Down: []string{ + `DROP TABLE IF EXISTS tokens`, + }, + }, }, } } diff --git a/auth/postgres/token.go b/auth/postgres/token.go new file mode 100644 index 0000000000..f63182a50e --- /dev/null +++ b/auth/postgres/token.go @@ -0,0 +1,61 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" + "github.com/absmach/magistrala/pkg/postgres" +) + +var _ auth.TokenRepository = (*tokenRepo)(nil) + +type tokenRepo struct { + db postgres.Database +} + +// NewTokensRepository instantiates a PostgreSQL implementation of tokens repository. +func NewTokensRepository(db postgres.Database) auth.TokenRepository { + return &tokenRepo{ + db: db, + } +} + +func (repo *tokenRepo) Save(ctx context.Context, id string) error { + q := `INSERT INTO tokens (id) VALUES ($1);` + + result, err := repo.db.ExecContext(ctx, q, id) + if err != nil { + return postgres.HandleError(repoerr.ErrCreateEntity, err) + } + if rows, err := result.RowsAffected(); rows == 0 { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + + return nil +} + +func (repo *tokenRepo) Contains(ctx context.Context, id string) bool { + q := `SELECT * FROM tokens WHERE id = $1;` + + rows, err := repo.db.QueryContext(ctx, q, id) + if err != nil { + return false + } + defer rows.Close() + + if rows.Next() { + id := "" + if err = rows.Scan(&id); err != nil { + return false + } + + return true + } + + return false +} diff --git a/auth/service.go b/auth/service.go index 8783e02a83..3d795bbd2c 100644 --- a/auth/service.go +++ b/auth/service.go @@ -63,6 +63,9 @@ type Authn interface { // issued by the user identified by the provided key. Revoke(ctx context.Context, token, id string) error + // RevokeToken revokes the token. + RevokeToken(ctx context.Context, token string) error + // RetrieveKey retrieves data for the Key identified by the provided // ID, that is issued by the user identified by the provided key. RetrieveKey(ctx context.Context, token, id string) (Key, error) @@ -116,6 +119,12 @@ func New(keys KeyRepository, domains DomainsRepository, idp magistrala.IDProvide func (svc service) Issue(ctx context.Context, token string, key Key) (Token, error) { key.IssuedAt = time.Now().UTC() + id, err := svc.idProvider.ID() + if err != nil { + return Token{}, errors.Wrap(errIssueUser, err) + } + key.ID = id + switch key.Type { case APIKey: return svc.userKey(ctx, token, key) @@ -131,7 +140,7 @@ func (svc service) Issue(ctx context.Context, token string, key Key) (Token, err } func (svc service) Revoke(ctx context.Context, token, id string) error { - issuerID, _, err := svc.authenticate(token) + issuerID, _, err := svc.authenticate(ctx, token) if err != nil { return errors.Wrap(errRevoke, err) } @@ -141,8 +150,12 @@ func (svc service) Revoke(ctx context.Context, token, id string) error { return nil } +func (svc service) RevokeToken(ctx context.Context, token string) error { + return svc.tokenizer.Revoke(ctx, token) +} + func (svc service) RetrieveKey(ctx context.Context, token, id string) (Key, error) { - issuerID, _, err := svc.authenticate(token) + issuerID, _, err := svc.authenticate(ctx, token) if err != nil { return Key{}, errors.Wrap(errRetrieve, err) } @@ -155,7 +168,7 @@ func (svc service) RetrieveKey(ctx context.Context, token, id string) (Key, erro } func (svc service) Identify(ctx context.Context, token string) (Key, error) { - key, err := svc.tokenizer.Parse(token) + key, err := svc.tokenizer.Parse(ctx, token) if errors.Contains(err, ErrExpiry) { err = svc.keys.Remove(ctx, key.Issuer, key.ID) return Key{}, errors.Wrap(svcerr.ErrAuthentication, errors.Wrap(ErrKeyExpired, err)) @@ -328,7 +341,7 @@ func (svc service) invitationKey(ctx context.Context, key Key) (Token, error) { } func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token, error) { - k, err := svc.tokenizer.Parse(token) + k, err := svc.tokenizer.Parse(ctx, token) if err != nil { return Token{}, errors.Wrap(errRetrieve, err) } @@ -392,7 +405,7 @@ func (svc service) checkUserDomain(ctx context.Context, key Key) (subject string } func (svc service) userKey(ctx context.Context, token string, key Key) (Token, error) { - id, sub, err := svc.authenticate(token) + id, sub, err := svc.authenticate(ctx, token) if err != nil { return Token{}, errors.Wrap(errIssueUser, err) } @@ -402,12 +415,6 @@ func (svc service) userKey(ctx context.Context, token string, key Key) (Token, e key.Subject = sub } - keyID, err := svc.idProvider.ID() - if err != nil { - return Token{}, errors.Wrap(errIssueUser, err) - } - key.ID = keyID - if _, err := svc.keys.Save(ctx, key); err != nil { return Token{}, errors.Wrap(errIssueUser, err) } @@ -420,8 +427,8 @@ func (svc service) userKey(ctx context.Context, token string, key Key) (Token, e return Token{AccessToken: tkn}, nil } -func (svc service) authenticate(token string) (string, string, error) { - key, err := svc.tokenizer.Parse(token) +func (svc service) authenticate(ctx context.Context, token string) (string, string, error) { + key, err := svc.tokenizer.Parse(ctx, token) if err != nil { return "", "", errors.Wrap(svcerr.ErrAuthentication, err) } diff --git a/auth/service_test.go b/auth/service_test.go index ddc65e22f3..dd488d9614 100644 --- a/auth/service_test.go +++ b/auth/service_test.go @@ -65,14 +65,14 @@ var ( pEvaluator *policymocks.Evaluator ) -func newService() (auth.Service, string) { +func newService() (auth.Service, *mocks.TokenRepository, *mocks.Cache, string) { krepo = new(mocks.KeyRepository) drepo = new(mocks.DomainsRepository) pService = new(policymocks.Service) pEvaluator = new(policymocks.Evaluator) idProvider := uuid.NewMock() - t := jwt.New([]byte(secret)) + t := jwt.New([]byte(secret), trepo, cache) key := auth.Key{ IssuedAt: time.Now(), ExpiresAt: time.Now().Add(refreshDuration), @@ -86,10 +86,22 @@ func newService() (auth.Service, string) { return auth.New(krepo, drepo, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token } -func TestIssue(t *testing.T) { - svc, accessToken := newService() +func newMinimalService() auth.Service { + krepo = new(mocks.KeyRepository) + trepo := new(mocks.TokenRepository) + cache := new(mocks.Cache) + prepo = new(mocks.PolicyAgent) + drepo = new(mocks.DomainsRepository) + idProvider := uuid.NewMock() - n := jwt.New([]byte(secret)) + t := jwt.New([]byte(secret), trepo, cache) + + return auth.New(krepo, drepo, idProvider, t, prepo, loginDuration, refreshDuration, invalidDuration) +} + +func TestIssue(t *testing.T) { + svc, trepo, cache, accessToken := newService() + n := jwt.New([]byte(secret), trepo, cache) apikey := auth.Key{ IssuedAt: time.Now(), @@ -382,6 +394,9 @@ func TestIssue(t *testing.T) { checkDOmainPolicyReq policies.Policy checkPolicyErr error retrieveByIDErr error + cacheContains bool + repoContains bool + cacheSave error err error }{ { @@ -468,15 +483,16 @@ func TestIssue(t *testing.T) { ObjectType: policies.PlatformType, Permission: policies.AdminPermission, }, - token: "", - err: nil, + cacheContains: true, + repoContains: false, + token: refreshToken, + err: svcerr.ErrAuthentication, }, { desc: "issue invitation key with invalid pService", key: auth.Key{ - Type: auth.InvitationKey, + Type: auth.RefreshKey, IssuedAt: time.Now(), - Domain: groupName, }, checkPolicyRequest: policies.Policy{ SubjectType: policies.UserType, @@ -490,10 +506,11 @@ func TestIssue(t *testing.T) { ObjectType: policies.DomainType, Permission: policies.MembershipPermission, }, - token: refreshToken, - checkPolicyErr: svcerr.ErrAuthorization, - retrieveByIDErr: repoerr.ErrNotFound, - err: svcerr.ErrDomainAuthorization, + cacheContains: false, + repoContains: true, + cacheSave: repoerr.ErrCreateEntity, + token: refreshToken, + err: svcerr.ErrAuthentication, }, } for _, tc := range cases4 { @@ -502,14 +519,17 @@ func TestIssue(t *testing.T) { repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkDOmainPolicyReq).Return(tc.checkPolicyErr) _, err := svc.Issue(context.Background(), tc.token, tc.key) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) + cacheCall.Unset() + cacheCall1.Unset() repoCall.Unset() repoCall1.Unset() repoCall2.Unset() + repoCall3.Unset() } } func TestRevoke(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, errIssueUser) secret, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id}) repocall.Unset() @@ -562,7 +582,7 @@ func TestRevoke(t *testing.T) { } func TestRetrieve(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil) secret, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id}) assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) @@ -632,7 +652,7 @@ func TestRetrieve(t *testing.T) { } func TestIdentify(t *testing.T) { - svc, _ := newService() + svc, trepo, cache, _ := newService() repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil) repocall1 := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil) @@ -658,7 +678,7 @@ func TestIdentify(t *testing.T) { assert.Nil(t, err, fmt.Sprintf("Issuing expired login key expected to succeed: %s", err)) repocall4.Unset() - te := jwt.New([]byte(secret)) + te := jwt.New([]byte(secret), trepo, cache) key := auth.Key{ IssuedAt: time.Now(), ExpiresAt: time.Now().Add(refreshDuration), @@ -737,7 +757,7 @@ func TestIdentify(t *testing.T) { } func TestAuthorize(t *testing.T) { - svc, accessToken := newService() + svc, trepo, cache, accessToken := newService() repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil) repocall1 := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil) @@ -758,7 +778,7 @@ func TestAuthorize(t *testing.T) { repocall2.Unset() repocall3.Unset() - te := jwt.New([]byte(secret)) + te := jwt.New([]byte(secret), trepo, cache) key := auth.Key{ IssuedAt: time.Now(), ExpiresAt: time.Now().Add(refreshDuration), @@ -1244,7 +1264,7 @@ func TestSwitchToPermission(t *testing.T) { } func TestCreateDomain(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -1369,7 +1389,7 @@ func TestCreateDomain(t *testing.T) { } func TestRetrieveDomain(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -1429,7 +1449,7 @@ func TestRetrieveDomain(t *testing.T) { } func TestRetrieveDomainPermissions(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -1488,7 +1508,7 @@ func TestRetrieveDomainPermissions(t *testing.T) { } func TestUpdateDomain(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -1568,7 +1588,7 @@ func TestUpdateDomain(t *testing.T) { } func TestChangeDomainStatus(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() disabledStatus := auth.DisabledStatus @@ -1645,7 +1665,7 @@ func TestChangeDomainStatus(t *testing.T) { } func TestListDomains(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -1711,7 +1731,7 @@ func TestListDomains(t *testing.T) { } func TestAssignUsers(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -2028,7 +2048,7 @@ func TestAssignUsers(t *testing.T) { } func TestUnassignUser(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -2255,7 +2275,7 @@ func TestUnassignUser(t *testing.T) { } func TestListUsersDomains(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string diff --git a/auth/tokenizer.go b/auth/tokenizer.go index 1aaed7df4f..991bbdf891 100644 --- a/auth/tokenizer.go +++ b/auth/tokenizer.go @@ -3,11 +3,27 @@ package auth +import "context" + // Tokenizer specifies API for encoding and decoding between string and Key. type Tokenizer interface { // Issue converts API Key to its string representation. Issue(key Key) (token string, err error) // Parse extracts API Key data from string token. - Parse(token string) (key Key, err error) + Parse(ctx context.Context, token string) (key Key, err error) + + // Revoke revokes the token. + Revoke(ctx context.Context, token string) error +} + +// TokenRepository specifies token persistence API. +// +//go:generate mockery --name TokenRepository --output=./mocks --filename token.go --quiet --note "Copyright (c) Abstract Machines" +type TokenRepository interface { + // Save persists the token. + Save(ctx context.Context, id string) (err error) + + // Contains checks if token with provided ID exists. + Contains(ctx context.Context, id string) (ok bool) } diff --git a/auth/tracing/tracing.go b/auth/tracing/tracing.go index 97b5f1790f..4443e1f5dc 100644 --- a/auth/tracing/tracing.go +++ b/auth/tracing/tracing.go @@ -44,6 +44,13 @@ func (tm *tracingMiddleware) Revoke(ctx context.Context, token, id string) error return tm.svc.Revoke(ctx, token, id) } +func (tm *tracingMiddleware) RevokeToken(ctx context.Context, token string) error { + ctx, span := tm.tracer.Start(ctx, "revoke") + defer span.End() + + return tm.svc.RevokeToken(ctx, token) +} + func (tm *tracingMiddleware) RetrieveKey(ctx context.Context, token, id string) (auth.Key, error) { ctx, span := tm.tracer.Start(ctx, "retrieve_key", trace.WithAttributes( attribute.String("id", id), diff --git a/cmd/auth/main.go b/cmd/auth/main.go index dcacee8eb4..f2f472cd44 100644 --- a/cmd/auth/main.go +++ b/cmd/auth/main.go @@ -222,6 +222,7 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch func newService(ctx context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, cacheClient *redis.Client, keyDuration time.Duration, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental) auth.Service { database := postgres.NewDatabase(db, dbConfig, tracer) keysRepo := apostgres.New(database) + tokensRepo := apostgres.NewTokensRepository(database) domainsRepo := apostgres.NewDomainRepository(database) idProvider := uuid.New()