From 386a39ef3275760003aae661a629f628129ffcf6 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Tue, 4 Feb 2025 17:04:14 +0300 Subject: [PATCH] fix tests Signed-off-by: nyagamunene --- auth/api/http/keys/endpoint_test.go | 3 +- auth/api/http/pats/endpoint.go | 7 +- auth/bolt/doc.go | 6 - auth/bolt/init.go | 21 - auth/bolt/pat.go | 812 -------------------------- auth/cache/pat.go | 2 +- auth/mocks/cache.go | 96 +++ auth/pat.go | 2 +- auth/postgres/init.go | 8 +- auth/postgres/pat.go | 77 +-- auth/postgres/repo.go | 161 +++-- auth/service.go | 12 +- auth/service_test.go | 4 +- go.mod | 1 - go.sum | 2 - internal/clients/bolt/bolt.go | 83 --- internal/clients/bolt/doc.go | 9 - pkg/groups/events/consumer/streams.go | 1 - 18 files changed, 214 insertions(+), 1093 deletions(-) delete mode 100644 auth/bolt/doc.go delete mode 100644 auth/bolt/init.go delete mode 100644 auth/bolt/pat.go create mode 100644 auth/mocks/cache.go delete mode 100644 internal/clients/bolt/bolt.go delete mode 100644 internal/clients/bolt/doc.go diff --git a/auth/api/http/keys/endpoint_test.go b/auth/api/http/keys/endpoint_test.go index 5b2693f751..52801c7d40 100644 --- a/auth/api/http/keys/endpoint_test.go +++ b/auth/api/http/keys/endpoint_test.go @@ -70,13 +70,14 @@ func (tr testRequest) make() (*http.Response, error) { func newService() (auth.Service, *mocks.KeyRepository) { krepo := new(mocks.KeyRepository) pRepo := new(mocks.PATSRepository) + cache := new(mocks.Cache) hash := new(mocks.Hasher) idProvider := uuid.NewMock() pService := new(policymocks.Service) pEvaluator := new(policymocks.Evaluator) t := jwt.New([]byte(secret)) - return auth.New(krepo, pRepo, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), krepo + return auth.New(krepo, pRepo, cache, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), krepo } func newServer(svc auth.Service) *httptest.Server { diff --git a/auth/api/http/pats/endpoint.go b/auth/api/http/pats/endpoint.go index 45e6b3c607..006d7d4955 100644 --- a/auth/api/http/pats/endpoint.go +++ b/auth/api/http/pats/endpoint.go @@ -33,7 +33,12 @@ func retrievePATEndpoint(svc auth.Service) endpoint.Endpoint { return nil, err } - pat, err := svc.RetrievePAT(ctx, req.token, req.id) + key, err := svc.Identify(ctx, req.token) + if err != nil { + return nil, err + } + + pat, err := svc.RetrievePAT(ctx, key.User, req.id) if err != nil { return nil, err } diff --git a/auth/bolt/doc.go b/auth/bolt/doc.go deleted file mode 100644 index dcd06ac566..0000000000 --- a/auth/bolt/doc.go +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package bolt contains PAT repository implementations using -// bolt as the underlying database. -package bolt diff --git a/auth/bolt/init.go b/auth/bolt/init.go deleted file mode 100644 index 490443e08a..0000000000 --- a/auth/bolt/init.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package bolt contains PAT repository implementations using -// bolt as the underlying database. -package bolt - -// import ( -// "github.com/absmach/supermq/pkg/errors" -// bolt "go.etcd.io/bbolt" -// ) - -// var errInit = errors.New("failed to initialize BoltDB") - -// func Init(tx *bolt.Tx, bucket string) error { -// _, err := tx.CreateBucketIfNotExists([]byte(bucket)) -// if err != nil { -// return errors.Wrap(errInit, err) -// } -// return nil -// } diff --git a/auth/bolt/pat.go b/auth/bolt/pat.go deleted file mode 100644 index 05d4bb5061..0000000000 --- a/auth/bolt/pat.go +++ /dev/null @@ -1,812 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package bolt - -// import ( -// "bytes" -// "context" -// "encoding/binary" -// "fmt" -// "strings" -// "time" - -// "github.com/absmach/supermq/auth" -// "github.com/absmach/supermq/pkg/errors" -// repoerr "github.com/absmach/supermq/pkg/errors/repository" -// bolt "go.etcd.io/bbolt" -// ) - -// const ( -// idKey = "id" -// userKey = "user" -// nameKey = "name" -// descriptionKey = "description" -// secretKey = "secret_key" -// scopeKey = "scope" -// issuedAtKey = "issued_at" -// expiresAtKey = "expires_at" -// updatedAtKey = "updated_at" -// lastUsedAtKey = "last_used_at" -// revokedKey = "revoked" -// revokedAtKey = "revoked_at" -// platformEntitiesKey = "platform_entities" -// patKey = "pat" - -// keySeparator = ":" -// anyID = "*" -// ) - -// var ( -// activateValue = []byte{0x00} -// revokedValue = []byte{0x01} -// entityValue = []byte{0x02} -// anyIDValue = []byte{0x03} -// selectedIDsValue = []byte{0x04} -// ) - -// type patRepo struct { -// db *bolt.DB -// bucketName string -// } - -// // NewPATSRepository instantiates a bolt -// // implementation of PAT repository. -// func NewPATSRepository(db *bolt.DB, bucketName string) auth.PATSRepository { -// return &patRepo{ -// db: db, -// bucketName: bucketName, -// } -// } - -// func (pr *patRepo) Save(ctx context.Context, pat auth.PAT) error { -// idxKey := []byte(pat.User + keySeparator + patKey + keySeparator + pat.ID) -// kv, err := patToKeyValue(pat) -// if err != nil { -// return err -// } -// return pr.db.Update(func(tx *bolt.Tx) error { -// rootBucket, err := pr.retrieveRootBucket(tx) -// if err != nil { -// return errors.Wrap(repoerr.ErrCreateEntity, err) -// } -// b, err := pr.createUserBucket(rootBucket, pat.User) -// if err != nil { -// return errors.Wrap(repoerr.ErrCreateEntity, err) -// } -// for key, value := range kv { -// fullKey := []byte(pat.ID + keySeparator + key) -// if err := b.Put(fullKey, value); err != nil { -// return errors.Wrap(repoerr.ErrCreateEntity, err) -// } -// } -// if err := rootBucket.Put(idxKey, []byte(pat.ID)); err != nil { -// return errors.Wrap(repoerr.ErrCreateEntity, err) -// } -// return nil -// }) -// } - -// func (pr *patRepo) Retrieve(ctx context.Context, userID, patID string) (auth.PAT, error) { -// prefix := []byte(patID + keySeparator) -// kv := map[string][]byte{} -// if err := pr.db.View(func(tx *bolt.Tx) error { -// b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity) -// if err != nil { -// return err -// } -// c := b.Cursor() -// for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { -// kv[string(k)] = v -// } -// return nil -// }); err != nil { -// return auth.PAT{}, err -// } - -// return keyValueToPAT(kv) -// } - -// func (pr *patRepo) RetrieveSecretAndRevokeStatus(ctx context.Context, userID, patID string) (string, bool, bool, error) { -// revoked := true -// expired := false -// keySecret := patID + keySeparator + secretKey -// keyRevoked := patID + keySeparator + revokedKey -// keyExpiresAt := patID + keySeparator + expiresAtKey -// var secretHash string -// if err := pr.db.View(func(tx *bolt.Tx) error { -// b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity) -// if err != nil { -// return err -// } -// secretHash = string(b.Get([]byte(keySecret))) -// revoked = bytesToBoolean(b.Get([]byte(keyRevoked))) -// expiresAt := bytesToTime(b.Get([]byte(keyExpiresAt))) -// expired = time.Now().After(expiresAt) -// return nil -// }); err != nil { -// return "", true, true, err -// } -// return secretHash, revoked, expired, nil -// } - -// func (pr *patRepo) UpdateName(ctx context.Context, userID, patID, name string) (auth.PAT, error) { -// return pr.updatePATField(ctx, userID, patID, nameKey, []byte(name)) -// } - -// func (pr *patRepo) UpdateDescription(ctx context.Context, userID, patID, description string) (auth.PAT, error) { -// return pr.updatePATField(ctx, userID, patID, descriptionKey, []byte(description)) -// } - -// func (pr *patRepo) UpdateTokenHash(ctx context.Context, userID, patID, tokenHash string, expiryAt time.Time) (auth.PAT, error) { -// prefix := []byte(patID + keySeparator) -// kv := map[string][]byte{} -// if err := pr.db.Update(func(tx *bolt.Tx) error { -// b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) -// if err != nil { -// return err -// } -// if err := b.Put([]byte(patID+keySeparator+secretKey), []byte(tokenHash)); err != nil { -// return errors.Wrap(repoerr.ErrUpdateEntity, err) -// } -// if err := b.Put([]byte(patID+keySeparator+expiresAtKey), timeToBytes(expiryAt)); err != nil { -// return errors.Wrap(repoerr.ErrUpdateEntity, err) -// } -// if err := b.Put([]byte(patID+keySeparator+updatedAtKey), timeToBytes(time.Now())); err != nil { -// return errors.Wrap(repoerr.ErrUpdateEntity, err) -// } -// c := b.Cursor() -// for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { -// kv[string(k)] = v -// } -// return nil -// }); err != nil { -// return auth.PAT{}, err -// } -// return keyValueToPAT(kv) -// } - -// func (pr *patRepo) RetrieveAll(ctx context.Context, userID string, pm auth.PATSPageMeta) (auth.PATSPage, error) { -// prefix := []byte(userID + keySeparator + patKey + keySeparator) - -// patIDs := []string{} -// if err := pr.db.View(func(tx *bolt.Tx) error { -// b, err := pr.retrieveRootBucket(tx) -// if err != nil { -// return errors.Wrap(repoerr.ErrViewEntity, err) -// } -// c := b.Cursor() -// for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { -// if v != nil { -// patIDs = append(patIDs, string(v)) -// } -// } -// return nil -// }); err != nil { -// return auth.PATSPage{}, err -// } - -// total := len(patIDs) - -// var pats []auth.PAT - -// patsPage := auth.PATSPage{ -// Total: uint64(total), -// Limit: pm.Limit, -// Offset: pm.Offset, -// PATS: pats, -// } - -// if int(pm.Offset) >= total { -// return patsPage, nil -// } - -// aLimit := pm.Limit -// if rLimit := total - int(pm.Offset); int(pm.Limit) > rLimit { -// aLimit = uint64(rLimit) -// } - -// for i := pm.Offset; i < pm.Offset+aLimit; i++ { -// if int(i) < total { -// pat, err := pr.Retrieve(ctx, userID, patIDs[i]) -// if err != nil { -// return patsPage, err -// } -// patsPage.PATS = append(patsPage.PATS, pat) -// } -// } - -// return patsPage, nil -// } - -// func (pr *patRepo) Revoke(ctx context.Context, userID, patID string) error { -// if err := pr.db.Update(func(tx *bolt.Tx) error { -// b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) -// if err != nil { -// return err -// } -// if err := b.Put([]byte(patID+keySeparator+revokedKey), revokedValue); err != nil { -// return errors.Wrap(repoerr.ErrUpdateEntity, err) -// } -// if err := b.Put([]byte(patID+keySeparator+revokedAtKey), timeToBytes(time.Now())); err != nil { -// return errors.Wrap(repoerr.ErrUpdateEntity, err) -// } -// return nil -// }); err != nil { -// return err -// } -// return nil -// } - -// func (pr *patRepo) Reactivate(ctx context.Context, userID, patID string) error { -// if err := pr.db.Update(func(tx *bolt.Tx) error { -// b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) -// if err != nil { -// return err -// } -// if err := b.Put([]byte(patID+keySeparator+revokedKey), activateValue); err != nil { -// return errors.Wrap(repoerr.ErrUpdateEntity, err) -// } -// if err := b.Put([]byte(patID+keySeparator+revokedAtKey), []byte{}); err != nil { -// return errors.Wrap(repoerr.ErrUpdateEntity, err) -// } -// return nil -// }); err != nil { -// return err -// } -// return nil -// } - -// func (pr *patRepo) Remove(ctx context.Context, userID, patID string) error { -// prefix := []byte(patID + keySeparator) -// idxKey := []byte(userID + keySeparator + patKey + keySeparator + patID) -// if err := pr.db.Update(func(tx *bolt.Tx) error { -// b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrRemoveEntity) -// if err != nil { -// return err -// } -// c := b.Cursor() -// for k, _ := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, _ = c.Next() { -// if err := b.Delete(k); err != nil { -// return errors.Wrap(repoerr.ErrRemoveEntity, err) -// } -// } -// rb, err := pr.retrieveRootBucket(tx) -// if err != nil { -// return err -// } -// if err := rb.Delete(idxKey); err != nil { -// return errors.Wrap(repoerr.ErrRemoveEntity, err) -// } -// return nil -// }); err != nil { -// return err -// } - -// return nil -// } - -// func (pr *patRepo) AddScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { -// prefix := []byte(patID + keySeparator + scopeKey) -// rKV := make(map[string][]byte) -// if err := pr.db.Update(func(tx *bolt.Tx) error { -// b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrCreateEntity) -// if err != nil { -// return err -// } -// kv, err := scopeEntryToKeyValue(platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) -// if err != nil { -// return err -// } -// for key, value := range kv { -// fullKey := []byte(patID + keySeparator + key) -// if err := b.Put(fullKey, value); err != nil { -// return errors.Wrap(repoerr.ErrCreateEntity, err) -// } -// } - -// c := b.Cursor() -// for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { -// rKV[string(k)] = v -// } -// return nil -// }); err != nil { -// return auth.Scope{}, err -// } - -// return parseKeyValueToScope(rKV) -// } - -// func (pr *patRepo) RemoveScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { -// if len(entityIDs) == 0 { -// return auth.Scope{}, repoerr.ErrMalformedEntity -// } -// prefix := []byte(patID + keySeparator + scopeKey) -// rKV := make(map[string][]byte) -// if err := pr.db.Update(func(tx *bolt.Tx) error { -// b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrRemoveEntity) -// if err != nil { -// return err -// } -// kv, err := scopeEntryToKeyValue(platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) -// if err != nil { -// return err -// } -// for key := range kv { -// fullKey := []byte(patID + keySeparator + key) -// if err := b.Delete(fullKey); err != nil { -// return errors.Wrap(repoerr.ErrRemoveEntity, err) -// } -// } -// c := b.Cursor() -// for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { -// rKV[string(k)] = v -// } -// return nil -// }); err != nil { -// return auth.Scope{}, err -// } -// return parseKeyValueToScope(rKV) -// } - -// func (pr *patRepo) CheckScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { -// return pr.db.Update(func(tx *bolt.Tx) error { -// b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity) -// if err != nil { -// return errors.Wrap(repoerr.ErrViewEntity, err) -// } -// srootKey, err := scopeRootKey(platformEntityType, optionalDomainID, optionalDomainEntityType, operation) -// if err != nil { -// return errors.Wrap(repoerr.ErrViewEntity, err) -// } - -// rootKey := patID + keySeparator + srootKey -// if value := b.Get([]byte(rootKey)); bytes.Equal(value, anyIDValue) { -// return nil -// } -// for _, entity := range entityIDs { -// value := b.Get([]byte(rootKey + keySeparator + entity)) -// if !bytes.Equal(value, entityValue) { -// return repoerr.ErrNotFound -// } -// } -// return nil -// }) -// } - -// func (pr *patRepo) RemoveAllScopeEntry(ctx context.Context, userID, patID string) error { -// return nil -// } - -// func (pr *patRepo) updatePATField(_ context.Context, userID, patID, key string, value []byte) (auth.PAT, error) { -// prefix := []byte(patID + keySeparator) -// kv := map[string][]byte{} -// if err := pr.db.Update(func(tx *bolt.Tx) error { -// b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) -// if err != nil { -// return err -// } -// if err := b.Put([]byte(patID+keySeparator+key), value); err != nil { -// return errors.Wrap(repoerr.ErrUpdateEntity, err) -// } -// if err := b.Put([]byte(patID+keySeparator+updatedAtKey), timeToBytes(time.Now())); err != nil { -// return errors.Wrap(repoerr.ErrUpdateEntity, err) -// } -// c := b.Cursor() -// for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { -// kv[string(k)] = v -// } -// return nil -// }); err != nil { -// return auth.PAT{}, err -// } -// return keyValueToPAT(kv) -// } - -// func (pr *patRepo) createUserBucket(rootBucket *bolt.Bucket, userID string) (*bolt.Bucket, error) { -// userBucket, err := rootBucket.CreateBucketIfNotExists([]byte(userID)) -// if err != nil { -// return nil, errors.Wrap(repoerr.ErrCreateEntity, fmt.Errorf("failed to retrieve or create bucket for user %s : %w", userID, err)) -// } - -// return userBucket, nil -// } - -// func (pr *patRepo) retrieveUserBucket(tx *bolt.Tx, userID, patID string, wrap error) (*bolt.Bucket, error) { -// rootBucket, err := pr.retrieveRootBucket(tx) -// if err != nil { -// return nil, errors.Wrap(wrap, err) -// } - -// vPatID := rootBucket.Get([]byte(userID + keySeparator + patKey + keySeparator + patID)) -// if vPatID == nil { -// return nil, repoerr.ErrNotFound -// } - -// userBucket := rootBucket.Bucket([]byte(userID)) -// if userBucket == nil { -// return nil, errors.Wrap(wrap, fmt.Errorf("user %s not found", userID)) -// } -// return userBucket, nil -// } - -// func (pr *patRepo) retrieveRootBucket(tx *bolt.Tx) (*bolt.Bucket, error) { -// rootBucket := tx.Bucket([]byte(pr.bucketName)) -// if rootBucket == nil { -// return nil, fmt.Errorf("bucket %s not found", pr.bucketName) -// } -// return rootBucket, nil -// } - -// func patToKeyValue(pat auth.PAT) (map[string][]byte, error) { -// kv := map[string][]byte{ -// idKey: []byte(pat.ID), -// userKey: []byte(pat.User), -// nameKey: []byte(pat.Name), -// descriptionKey: []byte(pat.Description), -// secretKey: []byte(pat.Secret), -// issuedAtKey: timeToBytes(pat.IssuedAt), -// expiresAtKey: timeToBytes(pat.ExpiresAt), -// updatedAtKey: timeToBytes(pat.UpdatedAt), -// lastUsedAtKey: timeToBytes(pat.LastUsedAt), -// revokedKey: booleanToBytes(pat.Revoked), -// revokedAtKey: timeToBytes(pat.RevokedAt), -// } -// scopeKV, err := scopeToKeyValue(pat.Scope) -// if err != nil { -// return nil, err -// } -// for k, v := range scopeKV { -// kv[k] = v -// } -// return kv, nil -// } - -// func scopeToKeyValue(scope auth.Scope) (map[string][]byte, error) { -// kv := map[string][]byte{} -// for opType, scopeValue := range scope.Users { -// tempKV, err := scopeEntryToKeyValue(auth.PlatformUsersScope, "", auth.DomainNullScope, opType, scopeValue.Values()...) -// if err != nil { -// return nil, err -// } -// for k, v := range tempKV { -// kv[k] = v -// } -// } -// for opType, scopeValue := range scope.Dashboard { -// tempKV, err := scopeEntryToKeyValue(auth.PlatformDashBoardScope, "", auth.DomainNullScope, opType, scopeValue.Values()...) -// if err != nil { -// return nil, err -// } -// for k, v := range tempKV { -// kv[k] = v -// } -// } -// for opType, scopeValue := range scope.Messaging { -// tempKV, err := scopeEntryToKeyValue(auth.PlatformMesagingScope, "", auth.DomainNullScope, opType, scopeValue.Values()...) -// if err != nil { -// return nil, err -// } -// for k, v := range tempKV { -// kv[k] = v -// } -// } -// for domainID, domainScope := range scope.Domains { -// for opType, scopeValue := range domainScope.DomainManagement { -// tempKV, err := scopeEntryToKeyValue(auth.PlatformDomainsScope, domainID, auth.DomainManagementScope, opType, scopeValue.Values()...) -// if err != nil { -// return nil, errors.Wrap(repoerr.ErrCreateEntity, err) -// } -// for k, v := range tempKV { -// kv[k] = v -// } -// } -// for entityType, scope := range domainScope.Entities { -// for opType, scopeValue := range scope { -// tempKV, err := scopeEntryToKeyValue(auth.PlatformDomainsScope, domainID, entityType, opType, scopeValue.Values()...) -// if err != nil { -// return nil, errors.Wrap(repoerr.ErrCreateEntity, err) -// } -// for k, v := range tempKV { -// kv[k] = v -// } -// } -// } -// } -// return kv, nil -// } - -// func scopeEntryToKeyValue(platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (map[string][]byte, error) { -// if len(entityIDs) == 0 { -// return nil, repoerr.ErrMalformedEntity -// } - -// rootKey, err := scopeRootKey(platformEntityType, optionalDomainID, optionalDomainEntityType, operation) -// if err != nil { -// return nil, err -// } -// if len(entityIDs) == 1 && entityIDs[0] == anyID { -// return map[string][]byte{rootKey: anyIDValue}, nil -// } - -// kv := map[string][]byte{rootKey: selectedIDsValue} - -// for _, entryID := range entityIDs { -// if entryID == anyID { -// return nil, repoerr.ErrMalformedEntity -// } -// kv[rootKey+keySeparator+entryID] = entityValue -// } - -// return kv, nil -// } - -// func scopeRootKey(platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType) (string, error) { -// op, err := operation.ValidString() -// if err != nil { -// return "", errors.Wrap(repoerr.ErrMalformedEntity, err) -// } - -// var rootKey strings.Builder - -// rootKey.WriteString(scopeKey) -// rootKey.WriteString(keySeparator) -// rootKey.WriteString(platformEntityType.String()) -// rootKey.WriteString(keySeparator) - -// switch platformEntityType { -// case auth.PlatformUsersScope: -// rootKey.WriteString(op) -// case auth.PlatformDashBoardScope: -// rootKey.WriteString(op) -// case auth.PlatformMesagingScope: -// rootKey.WriteString(op) -// case auth.PlatformDomainsScope: -// if optionalDomainID == "" { -// return "", fmt.Errorf("failed to add platform %s scope: invalid domain id", platformEntityType.String()) -// } -// odet, err := optionalDomainEntityType.ValidString() -// if err != nil { -// return "", errors.Wrap(repoerr.ErrMalformedEntity, err) -// } -// rootKey.WriteString(optionalDomainID) -// rootKey.WriteString(keySeparator) -// rootKey.WriteString(odet) -// rootKey.WriteString(keySeparator) -// rootKey.WriteString(op) -// default: -// return "", errors.Wrap(repoerr.ErrMalformedEntity, fmt.Errorf("invalid platform entity type %s", platformEntityType.String())) -// } - -// return rootKey.String(), nil -// } - -// func keyValueToBasicPAT(kv map[string][]byte) auth.PAT { -// var pat auth.PAT -// for k, v := range kv { -// switch { -// case strings.HasSuffix(k, keySeparator+idKey): -// pat.ID = string(v) -// case strings.HasSuffix(k, keySeparator+userKey): -// pat.User = string(v) -// case strings.HasSuffix(k, keySeparator+nameKey): -// pat.Name = string(v) -// case strings.HasSuffix(k, keySeparator+descriptionKey): -// pat.Description = string(v) -// case strings.HasSuffix(k, keySeparator+issuedAtKey): -// pat.IssuedAt = bytesToTime(v) -// case strings.HasSuffix(k, keySeparator+expiresAtKey): -// pat.ExpiresAt = bytesToTime(v) -// case strings.HasSuffix(k, keySeparator+updatedAtKey): -// pat.UpdatedAt = bytesToTime(v) -// case strings.HasSuffix(k, keySeparator+lastUsedAtKey): -// pat.LastUsedAt = bytesToTime(v) -// case strings.HasSuffix(k, keySeparator+revokedKey): -// pat.Revoked = bytesToBoolean(v) -// case strings.HasSuffix(k, keySeparator+revokedAtKey): -// pat.RevokedAt = bytesToTime(v) -// } -// } -// return pat -// } - -// func keyValueToPAT(kv map[string][]byte) (auth.PAT, error) { -// pat := keyValueToBasicPAT(kv) -// scope, err := parseKeyValueToScope(kv) -// if err != nil { -// return auth.PAT{}, err -// } -// pat.Scope = scope -// return pat, nil -// } - -// func parseKeyValueToScope(kv map[string][]byte) (auth.Scope, error) { -// scope := auth.Scope{ -// Domains: make(map[string]auth.DomainScope), -// } -// for key, value := range kv { -// if strings.Index(key, keySeparator+scopeKey+keySeparator) > 0 { -// keyParts := strings.Split(key, keySeparator) - -// platformEntityType, err := auth.ParsePlatformEntityType(keyParts[2]) -// if err != nil { -// return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) -// } - -// switch platformEntityType { -// case auth.PlatformUsersScope: -// scope.Users, err = parseOperation(platformEntityType, scope.Users, key, keyParts, value) -// if err != nil { -// return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) -// } - -// case auth.PlatformDashBoardScope: -// scope.Dashboard, err = parseOperation(platformEntityType, scope.Dashboard, key, keyParts, value) -// if err != nil { -// return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) -// } - -// case auth.PlatformMesagingScope: -// scope.Messaging, err = parseOperation(platformEntityType, scope.Messaging, key, keyParts, value) -// if err != nil { -// return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) -// } - -// case auth.PlatformDomainsScope: -// if len(keyParts) < 6 { -// return auth.Scope{}, fmt.Errorf("invalid scope key format: %s", key) -// } -// domainID := keyParts[3] -// if scope.Domains == nil { -// scope.Domains = make(map[string]auth.DomainScope) -// } -// if _, ok := scope.Domains[domainID]; !ok { -// scope.Domains[domainID] = auth.DomainScope{} -// } -// domainScope := scope.Domains[domainID] - -// entityType := keyParts[4] - -// switch entityType { -// case auth.DomainManagementScope.String(): -// domainScope.DomainManagement, err = parseOperation(platformEntityType, domainScope.DomainManagement, key, keyParts, value) -// if err != nil { -// return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) -// } -// default: -// etype, err := auth.ParseDomainEntityType(entityType) -// if err != nil { -// return auth.Scope{}, fmt.Errorf("key %s invalid entity type %s : %w", key, entityType, err) -// } -// if domainScope.Entities == nil { -// domainScope.Entities = make(map[auth.DomainEntityType]auth.OperationScope) -// } -// if _, ok := domainScope.Entities[etype]; !ok { -// domainScope.Entities[etype] = auth.OperationScope{} -// } -// entityOperationScope := domainScope.Entities[etype] -// entityOperationScope, err = parseOperation(platformEntityType, entityOperationScope, key, keyParts, value) -// if err != nil { -// return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) -// } -// domainScope.Entities[etype] = entityOperationScope -// } -// scope.Domains[domainID] = domainScope -// default: -// return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, fmt.Errorf("invalid platform entity type : %s", platformEntityType.String())) -// } -// } -// } -// return scope, nil -// } - -// func parseOperation(platformEntityType auth.PlatformEntityType, opScope auth.OperationScope, key string, keyParts []string, value []byte) (auth.OperationScope, error) { -// if opScope == nil { -// opScope = make(map[auth.OperationType]auth.ScopeValue) -// } - -// if err := validateOperation(platformEntityType, opScope, key, keyParts, value); err != nil { -// return auth.OperationScope{}, err -// } - -// switch string(value) { -// case string(entityValue): -// opType, err := auth.ParseOperationType(keyParts[len(keyParts)-2]) -// if err != nil { -// return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err) -// } -// entityID := keyParts[len(keyParts)-1] - -// if _, oValueExists := opScope[opType]; !oValueExists { -// opScope[opType] = &auth.SelectedIDs{} -// } -// oValue := opScope[opType] -// if err := oValue.AddValues(entityID); err != nil { -// return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity value %v : %w", key, entityID, err) -// } -// opScope[opType] = oValue -// case string(anyIDValue): -// opType, err := auth.ParseOperationType(keyParts[len(keyParts)-1]) -// if err != nil { -// return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err) -// } -// if oValue, oValueExists := opScope[opType]; oValueExists && oValue != nil { -// if _, ok := oValue.(*auth.AnyIDs); !ok { -// return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity anyIDs scope value : key already initialized with different type", key) -// } -// } -// opScope[opType] = &auth.AnyIDs{} -// case string(selectedIDsValue): -// opType, err := auth.ParseOperationType(keyParts[len(keyParts)-1]) -// if err != nil { -// return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err) -// } -// oValue, oValueExists := opScope[opType] -// if oValueExists && oValue != nil { -// if _, ok := oValue.(*auth.SelectedIDs); !ok { -// return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity selectedIDs scope value : key already initialized with different type", key) -// } -// } -// if !oValueExists { -// opScope[opType] = &auth.SelectedIDs{} -// } -// default: -// return auth.OperationScope{}, fmt.Errorf("key %s have invalid value %v", key, value) -// } -// return opScope, nil -// } - -// func validateOperation(platformEntityType auth.PlatformEntityType, opScope auth.OperationScope, key string, keyParts []string, value []byte) error { -// expectedKeyPartsLength := 0 -// switch string(value) { -// case string(entityValue): -// switch platformEntityType { -// case auth.PlatformDomainsScope: -// expectedKeyPartsLength = 7 -// case auth.PlatformUsersScope, auth.PlatformDashBoardScope, auth.PlatformMesagingScope: -// expectedKeyPartsLength = 5 -// default: -// return fmt.Errorf("invalid platform entity type : %s", platformEntityType.String()) -// } -// case string(selectedIDsValue), string(anyIDValue): -// switch platformEntityType { -// case auth.PlatformDomainsScope: -// expectedKeyPartsLength = 6 -// case auth.PlatformUsersScope, auth.PlatformDashBoardScope, auth.PlatformMesagingScope: -// expectedKeyPartsLength = 4 -// default: -// return fmt.Errorf("invalid platform entity type : %s", platformEntityType.String()) -// } -// default: -// return fmt.Errorf("key %s have invalid value %v", key, value) -// } -// if len(keyParts) != expectedKeyPartsLength { -// return fmt.Errorf("invalid scope key format: %s", key) -// } -// return nil -// } - -// func timeToBytes(t time.Time) []byte { -// timeBytes := make([]byte, 8) -// binary.BigEndian.PutUint64(timeBytes, uint64(t.Unix())) -// return timeBytes -// } - -// func bytesToTime(b []byte) time.Time { -// timeAtSeconds := binary.BigEndian.Uint64(b) -// return time.Unix(int64(timeAtSeconds), 0) -// } - -// func booleanToBytes(b bool) []byte { -// if b { -// return []byte{1} -// } -// return []byte{0} -// } - -// func bytesToBoolean(b []byte) bool { -// if len(b) > 1 || b[0] != activateValue[0] { -// return true -// } -// return false -// } diff --git a/auth/cache/pat.go b/auth/cache/pat.go index 122c23da68..b0b9f2ecb5 100644 --- a/auth/cache/pat.go +++ b/auth/cache/pat.go @@ -25,7 +25,7 @@ func NewPatsCache(client *redis.Client, duration time.Duration) auth.Cache { } } -func (pc *patCache) Save(ctx context.Context, patSecret, patID string, pat auth.PAT) error { +func (pc *patCache) Save(ctx context.Context, patID string, pat auth.PAT) error { if err := pc.client.Set(ctx, patID, pat, pc.duration).Err(); err != nil { return errors.Wrap(repoerr.ErrCreateEntity, err) } diff --git a/auth/mocks/cache.go b/auth/mocks/cache.go new file mode 100644 index 0000000000..ae1665772d --- /dev/null +++ b/auth/mocks/cache.go @@ -0,0 +1,96 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import ( + context "context" + + auth "github.com/absmach/supermq/auth" + + mock "github.com/stretchr/testify/mock" +) + +// Cache is an autogenerated mock type for the Cache type +type Cache struct { + mock.Mock +} + +// ID provides a mock function with given fields: ctx, patID +func (_m *Cache) ID(ctx context.Context, patID string) (auth.PAT, error) { + ret := _m.Called(ctx, patID) + + if len(ret) == 0 { + panic("no return value specified for ID") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (auth.PAT, error)); ok { + return rf(ctx, patID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) auth.PAT); ok { + r0 = rf(ctx, patID) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, patID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Remove provides a mock function with given fields: ctx, patID +func (_m *Cache) Remove(ctx context.Context, patID string) error { + ret := _m.Called(ctx, patID) + + if len(ret) == 0 { + panic("no return value specified for Remove") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Save provides a mock function with given fields: ctx, patID, scope +func (_m *Cache) Save(ctx context.Context, patID string, scope auth.PAT) error { + ret := _m.Called(ctx, patID, scope) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, auth.PAT) error); ok { + r0 = rf(ctx, patID, scope) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewCache creates a new instance of Cache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewCache(t interface { + mock.TestingT + Cleanup(func()) +}) *Cache { + mock := &Cache{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/auth/pat.go b/auth/pat.go index b56df6be48..74ba6fa51b 100644 --- a/auth/pat.go +++ b/auth/pat.go @@ -676,7 +676,7 @@ func (s *Scope) String() string { // PAT represents Personal Access Token. type PAT struct { ID string `json:"id,omitempty"` - User string `json:"user,omitempty"` + User string `json:"user_id,omitempty"` Name string `json:"name,omitempty"` Description string `json:"description,omitempty"` Secret string `json:"secret,omitempty"` diff --git a/auth/postgres/init.go b/auth/postgres/init.go index 557be74856..90e2fb726d 100644 --- a/auth/postgres/init.go +++ b/auth/postgres/init.go @@ -71,7 +71,7 @@ func Migration() *migrate.MemoryMigrationSource { `CREATE TABLE IF NOT EXISTS pats ( id VARCHAR(36) PRIMARY KEY, name VARCHAR(254), - user_id VARCHAR(36), + user_id VARCHAR(36), description TEXT, secret TEXT, issued_at TIMESTAMP, @@ -81,11 +81,7 @@ func Migration() *migrate.MemoryMigrationSource { revoked_at TIMESTAMP, entity_type TEXT, last_used_at TIMESTAMP, - scopes_data TEXT, - allowed_ops TEXT[], - entity_ids TEXT[], - domains TEXT[], - entity_types TEXT[], + scopes_data TEXT )`, }, Down: []string{ diff --git a/auth/postgres/pat.go b/auth/postgres/pat.go index f0c760cc17..6fa03f827d 100644 --- a/auth/postgres/pat.go +++ b/auth/postgres/pat.go @@ -23,15 +23,7 @@ type dbPat struct { LastUsedAt time.Time `db:"last_used_at,omitempty"` Revoked bool `db:"revoked,omitempty"` RevokedAt time.Time `db:"revoked_at,omitempty"` - - // Scopes data stored as JSON - ScopesData string `db:"scopes_data,omitempty"` - - // Aggregated scope fields for querying - AllowedOps []string `db:"allowed_ops,omitempty"` // Combined list of all allowed operations - EntityIDs []string `db:"entity_ids,omitempty"` // Combined list of all entity IDs - Domains []string `db:"domains,omitempty"` // List of all domain IDs - EntityTypes []string `db:"entity_types,omitempty"` // List of all entity types + ScopesData string `db:"scopes_data,omitempty"` } type dbAuthPage struct { @@ -101,33 +93,16 @@ func patToDBRecords(pat auth.PAT) (dbPat, error) { DomainScopes: make(map[string]DomainScopeData), } - var allOps []string - var allEntityIDs []string - var allDomains []string - var allEntityTypes []string - if len(pat.Scope.Users) > 0 { scopeData.Users = pat.Scope.Users - ops, ids := extractOpsAndIDs(pat.Scope.Users) - allOps = append(allOps, ops...) - allEntityIDs = append(allEntityIDs, ids...) - allEntityTypes = append(allEntityTypes, "users") } if len(pat.Scope.Dashboard) > 0 { scopeData.Dashboard = pat.Scope.Dashboard - ops, ids := extractOpsAndIDs(pat.Scope.Dashboard) - allOps = append(allOps, ops...) - allEntityIDs = append(allEntityIDs, ids...) - allEntityTypes = append(allEntityTypes, "dashboard") } if len(pat.Scope.Messaging) > 0 { scopeData.Messaging = pat.Scope.Messaging - ops, ids := extractOpsAndIDs(pat.Scope.Messaging) - allOps = append(allOps, ops...) - allEntityIDs = append(allEntityIDs, ids...) - allEntityTypes = append(allEntityTypes, "messaging") } for domainID, domainScope := range pat.Scope.Domains { @@ -136,35 +111,17 @@ func patToDBRecords(pat auth.PAT) (dbPat, error) { Entities: make(map[string]auth.OperationScope), } - if len(domainScope.DomainManagement) > 0 { - ops, ids := extractOpsAndIDs(domainScope.DomainManagement) - allOps = append(allOps, ops...) - allEntityIDs = append(allEntityIDs, ids...) - allEntityTypes = append(allEntityTypes, "domain_management") - } - for entityType, ops := range domainScope.Entities { entityTypeStr, err := entityType.ValidString() if err != nil { return dbPat{}, fmt.Errorf("invalid entity type: %w", err) } dsd.Entities[entityTypeStr] = ops - - extractedOps, ids := extractOpsAndIDs(ops) - allOps = append(allOps, extractedOps...) - allEntityIDs = append(allEntityIDs, ids...) - allEntityTypes = append(allEntityTypes, entityTypeStr) } scopeData.DomainScopes[domainID] = dsd - allDomains = append(allDomains, domainID) } - allOps = uniqueStrings(allOps) - allEntityIDs = uniqueStrings(allEntityIDs) - allDomains = uniqueStrings(allDomains) - allEntityTypes = uniqueStrings(allEntityTypes) - scopesJSON, err := json.Marshal(scopeData) if err != nil { return dbPat{}, fmt.Errorf("failed to marshal scopes data: %w", err) @@ -183,10 +140,6 @@ func patToDBRecords(pat auth.PAT) (dbPat, error) { Revoked: pat.Revoked, RevokedAt: pat.RevokedAt, ScopesData: string(scopesJSON), - AllowedOps: allOps, - EntityIDs: allEntityIDs, - Domains: allDomains, - EntityTypes: allEntityTypes, }, nil } @@ -195,34 +148,6 @@ type DomainScopeData struct { Entities map[string]auth.OperationScope `json:"entities,omitempty"` } -func extractOpsAndIDs(ops auth.OperationScope) ([]string, []string) { - var operations []string - var entityIDs []string - - for op, scopeValue := range ops { - opStr, err := op.ValidString() - if err != nil { - continue - } - operations = append(operations, opStr) - entityIDs = append(entityIDs, scopeValue.Values()...) - } - - return operations, entityIDs -} - -func uniqueStrings(strs []string) []string { - seen := make(map[string]struct{}) - var result []string - for _, str := range strs { - if _, exists := seen[str]; !exists { - seen[str] = struct{}{} - result = append(result, str) - } - } - return result -} - func toDBAuthPage(user string, pm auth.PATSPageMeta) dbAuthPage { return dbAuthPage{ Limit: pm.Limit, diff --git a/auth/postgres/repo.go b/auth/postgres/repo.go index 52c5e0e49e..f50ba9e699 100644 --- a/auth/postgres/repo.go +++ b/auth/postgres/repo.go @@ -5,7 +5,6 @@ package postgres import ( "context" - "fmt" "time" "github.com/absmach/supermq/auth" @@ -38,11 +37,11 @@ func (pr *patRepo) Save(ctx context.Context, pat auth.PAT) error { INSERT INTO pats ( id, user_id, name, description, secret, issued_at, expires_at, updated_at, last_used_at, revoked, revoked_at, - scopes_data, allowed_ops, entity_ids, domains, entity_types, metadata + scopes_data ) VALUES ( :id, :user_id, :name, :description, :secret, :issued_at, :expires_at, :updated_at, :last_used_at, :revoked, :revoked_at, - :scopes_data, :allowed_ops, :entity_ids, :domains, :entity_types, :metadata + :scopes_data )` row, err := pr.db.NamedQueryContext(ctx, q, record) @@ -51,7 +50,7 @@ func (pr *patRepo) Save(ctx context.Context, pat auth.PAT) error { } defer row.Close() - if err := pr.cache.Save(ctx, pat.Secret, pat.ID, pat); err != nil { + if err := pr.cache.Save(ctx, pat.ID, pat); err != nil { return errors.Wrap(repoerr.ErrCreateEntity, err) } @@ -63,45 +62,23 @@ func (pr *patRepo) Retrieve(ctx context.Context, userID, patID string) (auth.PAT return pat, nil } - q := ` - SELECT - id, user_id, name, description, secret, issued_at, expires_at, - updated_at, last_used_at, revoked, revoked_at, - scopes_data,allowed_ops, entity_ids, domains, entity_types, metadata - FROM pats WHERE user_id = $1 AND id = $2` - - rows, err := pr.db.QueryContext(ctx, q, userID, patID) + pat, err := pr.retrieveFromDB(ctx, userID, patID) if err != nil { return auth.PAT{}, errors.Wrap(repoerr.ErrViewEntity, err) } - defer rows.Close() - - var record dbPat - if rows.Next() { - if err := rows.Scan(&record); err != nil { - return auth.PAT{}, errors.Wrap(repoerr.ErrViewEntity, err) - } - pat, err := toAuthPat(record) - if err != nil { - return auth.PAT{}, err - } - // Save to cache - if err := pr.cache.Save(ctx, pat.Secret, pat.ID, pat); err != nil { - return auth.PAT{}, errors.Wrap(repoerr.ErrViewEntity, err) - } - return pat, nil + if err := pr.cache.Save(ctx, pat.ID, pat); err != nil { + return auth.PAT{}, err } - return auth.PAT{}, repoerr.ErrNotFound + return pat, nil } func (pr *patRepo) RetrieveAll(ctx context.Context, userID string, pm auth.PATSPageMeta) (auth.PATSPage, error) { q := ` SELECT - p.id, p.user_id, p.name, p.description, p.secret, p.issued_at, p.expires_at, - p.updated_at, p.last_used_at, p.revoked, p.revoked_at, - p.scopes_data, p.allowed_ops, p.entity_ids, p.domains, p.entity_types, p.metadata - FROM pats p WHERE user_id :user_id + p.id, p.user_id, p.name, p.description, p.issued_at, p.expires_at, + p.updated_at, p.revoked, p.revoked_at + FROM pats p WHERE user_id = :user_id ORDER BY issued_at DESC LIMIT :limit OFFSET :offset` @@ -113,20 +90,38 @@ func (pr *patRepo) RetrieveAll(ctx context.Context, userID string, pm auth.PATSP } defer rows.Close() + type data struct { + ID string `db:"id,omitempty"` + User string `db:"user_id,omitempty"` + Name string `db:"name,omitempty"` + Description string `db:"description,omitempty"` + IssuedAt time.Time `db:"issued_at,omitempty"` + ExpiresAt time.Time `db:"expires_at,omitempty"` + UpdatedAt time.Time `db:"updated_at,omitempty"` + Revoked bool `db:"revoked,omitempty"` + RevokedAt time.Time `db:"revoked_at,omitempty"` + } + var items []auth.PAT for rows.Next() { - var pat auth.PAT + var pat data if err := rows.StructScan(&pat); err != nil { return auth.PATSPage{}, errors.Wrap(repoerr.ErrViewEntity, err) } - items = append(items, pat) + items = append(items, auth.PAT{ + ID: pat.ID, + User: pat.User, + Name: pat.Name, + Description: pat.Description, + IssuedAt: pat.IssuedAt, + ExpiresAt: pat.ExpiresAt, + UpdatedAt: pat.UpdatedAt, + Revoked: pat.Revoked, + RevokedAt: pat.RevokedAt, + }) } - cq := fmt.Sprintf(`SELECT COUNT(*) AS total_count - FROM ( - SELECT DISTINCT p.id FROM pats p %s - ) AS subquery; - `, q) + cq := `SELECT COUNT(*) FROM pats p WHERE user_id = :user_id` total, err := postgres.Total(ctx, pr.db, cq, dbPage) if err != nil { @@ -194,7 +189,16 @@ func (pr *patRepo) UpdateName(ctx context.Context, userID, patID, name string) ( return auth.PAT{}, repoerr.ErrNotFound } - return pr.Retrieve(ctx, userID, patID) + pat, err := pr.retrieveFromDB(ctx, userID, patID) + if err != nil { + return auth.PAT{}, err + } + + if err := pr.cache.Save(ctx, patID, pat); err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + return pat, nil } func (pr *patRepo) UpdateDescription(ctx context.Context, userID, patID, description string) (auth.PAT, error) { @@ -216,7 +220,16 @@ func (pr *patRepo) UpdateDescription(ctx context.Context, userID, patID, descrip return auth.PAT{}, repoerr.ErrNotFound } - return pr.Retrieve(ctx, userID, patID) + pat, err := pr.retrieveFromDB(ctx, userID, patID) + if err != nil { + return auth.PAT{}, err + } + + if err := pr.cache.Save(ctx, patID, pat); err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + return pat, nil } func (pr *patRepo) UpdateTokenHash(ctx context.Context, userID, patID, tokenHash string, expiryAt time.Time) (auth.PAT, error) { @@ -238,12 +251,12 @@ func (pr *patRepo) UpdateTokenHash(ctx context.Context, userID, patID, tokenHash return auth.PAT{}, repoerr.ErrNotFound } - pat, err := pr.Retrieve(ctx, userID, patID) + pat, err := pr.retrieveFromDB(ctx, userID, patID) if err != nil { return auth.PAT{}, err } - if err := pr.cache.Save(ctx, tokenHash, patID, pat); err != nil { + if err := pr.cache.Save(ctx, patID, pat); err != nil { return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err) } @@ -300,7 +313,7 @@ func (pr *patRepo) Reactivate(ctx context.Context, userID, patID string) error { return err } - if err := pr.cache.Save(ctx, pat.Secret, patID, pat); err != nil { + if err := pr.cache.Save(ctx, patID, pat); err != nil { return errors.Wrap(repoerr.ErrUpdateEntity, err) } @@ -331,7 +344,7 @@ func (pr *patRepo) Remove(ctx context.Context, userID, patID string) error { } func (pr *patRepo) AddScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { - pat, err := pr.Retrieve(ctx, userID, patID) + pat, err := pr.retrieveFromDB(ctx, userID, patID) if err != nil { return auth.Scope{}, err } @@ -348,10 +361,6 @@ func (pr *patRepo) AddScopeEntry(ctx context.Context, userID, patID string, plat q := ` UPDATE pats SET scopes_data = :scopes_data, - allowed_ops = :allowed_ops, - entity_ids = :entity_ids, - domains = :domains, - entity_types = :entity_types, updated_at = :updated_at WHERE user_id = :user_id AND id = :id` @@ -368,7 +377,7 @@ func (pr *patRepo) AddScopeEntry(ctx context.Context, userID, patID string, plat return auth.Scope{}, repoerr.ErrNotFound } - if err := pr.cache.Save(ctx, pat.Secret, pat.ID, pat); err != nil { + if err := pr.cache.Save(ctx, pat.ID, pat); err != nil { return auth.Scope{}, errors.Wrap(repoerr.ErrUpdateEntity, err) } @@ -393,10 +402,6 @@ func (pr *patRepo) RemoveScopeEntry(ctx context.Context, userID, patID string, p q := ` UPDATE pats SET scopes_data = :scopes_data, - allowed_ops = :allowed_ops, - entity_ids = :entity_ids, - domains = :domains, - entity_types = :entity_types, updated_at = :updated_at WHERE user_id = :user_id AND id = :id` @@ -413,7 +418,7 @@ func (pr *patRepo) RemoveScopeEntry(ctx context.Context, userID, patID string, p return auth.Scope{}, repoerr.ErrNotFound } - if err := pr.cache.Save(ctx, pat.Secret, pat.ID, pat); err != nil { + if err := pr.cache.Save(ctx, pat.ID, pat); err != nil { return auth.Scope{}, errors.Wrap(repoerr.ErrUpdateEntity, err) } @@ -455,10 +460,6 @@ func (pr *patRepo) RemoveAllScopeEntry(ctx context.Context, userID, patID string q := ` UPDATE pats SET scopes_data = :scopes_data, - allowed_ops = :allowed_ops, - entity_ids = :entity_ids, - domains = :domains, - entity_types = :entity_types, updated_at = :updated_at WHERE user_id = :user_id AND id = :id` @@ -475,9 +476,47 @@ func (pr *patRepo) RemoveAllScopeEntry(ctx context.Context, userID, patID string return repoerr.ErrNotFound } - if err := pr.cache.Save(ctx, pat.Secret, pat.ID, pat); err != nil { + if err := pr.cache.Save(ctx, pat.ID, pat); err != nil { return errors.Wrap(repoerr.ErrUpdateEntity, err) } return nil } + +func (pr *patRepo) retrieveFromDB(ctx context.Context, userID, patID string) (auth.PAT, error) { + q := ` + SELECT + id, user_id, name, description, secret, issued_at, expires_at, + updated_at, last_used_at, revoked, revoked_at, + scopes_data + FROM pats WHERE user_id = $1 AND id = $2` + + rows, err := pr.db.QueryContext(ctx, q, userID, patID) + if err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + defer rows.Close() + + var record dbPat + if rows.Next() { + if err := rows.Scan( + &record.ID, + &record.User, + &record.Name, + &record.Description, + &record.Secret, + &record.IssuedAt, + &record.ExpiresAt, + &record.UpdatedAt, + &record.LastUsedAt, + &record.Revoked, + &record.RevokedAt, + &record.ScopesData, + ); err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + return toAuthPat(record) + } + + return auth.PAT{}, repoerr.ErrNotFound +} diff --git a/auth/service.go b/auth/service.go index 4f57a7846a..9217eaa5cb 100644 --- a/auth/service.go +++ b/auth/service.go @@ -6,7 +6,6 @@ package auth import ( "context" "encoding/base64" - "fmt" "math/rand" "strings" "time" @@ -98,7 +97,7 @@ type Service interface { //go:generate mockery --name Cache --output=./mocks --filename cache.go --quiet --note "Copyright (c) Abstract Machines" type Cache interface { - Save(ctx context.Context, patSecret, patID string, scope PAT) error + Save(ctx context.Context, patID string, scope PAT) error ID(ctx context.Context, patID string) (PAT, error) @@ -490,13 +489,12 @@ func (svc service) CreatePAT(ctx context.Context, token, name, description strin User: key.User, Name: name, Description: description, - Secret: hash, + Secret: patPrefix + patSecretSeparator + hash, IssuedAt: now, ExpiresAt: now.Add(duration), Scope: scope, } if err := svc.pats.Save(ctx, pat); err != nil { - fmt.Printf("err is %+v\n", err) return PAT{}, errors.Wrap(errCreatePAT, err) } pat.Secret = secret @@ -530,9 +528,6 @@ func (svc service) UpdatePATDescription(ctx context.Context, token, patID, descr func (svc service) RetrievePAT(ctx context.Context, userID, patID string) (PAT, error) { pat, err := svc.pats.Retrieve(ctx, userID, patID) if err != nil { - fmt.Printf("userID is %+v\n", userID) - fmt.Printf("patID is %+v\n", patID) - fmt.Printf("error is %+v\n", err) return PAT{}, errors.Wrap(errRetrievePAT, err) } return pat, nil @@ -647,9 +642,6 @@ func (svc service) IdentifyPAT(ctx context.Context, secret string) (PAT, error) if err != nil { return PAT{}, errors.Wrap(svcerr.ErrAuthentication, err) } - fmt.Printf("revoked is %+v\n", revoked) - fmt.Printf("secretHash is %+v\n", secretHash) - fmt.Printf("expired is %+v\n", expired) if revoked { return PAT{}, errors.Wrap(svcerr.ErrAuthentication, errRevokedPAT) } diff --git a/auth/service_test.go b/auth/service_test.go index 18115fd4c7..f3c8b0d97b 100644 --- a/auth/service_test.go +++ b/auth/service_test.go @@ -50,11 +50,13 @@ var ( pService *policymocks.Service pEvaluator *policymocks.Evaluator patsrepo *mocks.PATSRepository + cache *mocks.Cache hasher *mocks.Hasher ) func newService() (auth.Service, string) { krepo = new(mocks.KeyRepository) + cache = new(mocks.Cache) pService = new(policymocks.Service) pEvaluator = new(policymocks.Evaluator) patsrepo = new(mocks.PATSRepository) @@ -72,7 +74,7 @@ func newService() (auth.Service, string) { } token, _ := t.Issue(key) - return auth.New(krepo, patsrepo, hasher, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token + return auth.New(krepo, patsrepo, cache, hasher, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token } func TestIssue(t *testing.T) { diff --git a/go.mod b/go.mod index 6d781de1e5..e27e2f7fc0 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,6 @@ require ( github.com/spf13/cobra v1.8.1 github.com/sqids/sqids-go v0.4.1 github.com/stretchr/testify v1.10.0 - go.etcd.io/bbolt v1.3.11 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 go.opentelemetry.io/otel v1.34.0 diff --git a/go.sum b/go.sum index e655a94303..00a2573877 100644 --- a/go.sum +++ b/go.sum @@ -427,8 +427,6 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= -go.etcd.io/bbolt v1.3.11 h1:yGEzV1wPz2yVCLsD8ZAiGHhHVlczyC9d1rP43/VCRJ0= -go.etcd.io/bbolt v1.3.11/go.mod h1:dksAq7YMXoljX0xu6VF5DMZGbhYYoLUalEiSySYAS4I= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 h1:rgMkmiGfix9vFJDcDi1PK8WEQP4FLQwLDfhp5ZLpFeE= diff --git a/internal/clients/bolt/bolt.go b/internal/clients/bolt/bolt.go deleted file mode 100644 index 8e2afebf97..0000000000 --- a/internal/clients/bolt/bolt.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package bolt - -import ( - "io/fs" - "strconv" - "time" - - "github.com/absmach/supermq/pkg/errors" - "github.com/caarlos0/env/v11" - bolt "go.etcd.io/bbolt" -) - -var ( - errConfig = errors.New("failed to load BoltDB configuration") - errConnect = errors.New("failed to connect to BoltDB database") - errInit = errors.New("failed to initialize to BoltDB database") -) - -type FileMode fs.FileMode - -func (fm *FileMode) UnmarshalText(text []byte) error { - temp, err := strconv.ParseUint(string(text), 8, 32) - if err != nil { - return err - } - *fm = FileMode(temp) - return nil -} - -// Config contains BoltDB specific parameters. -type Config struct { - FileDirPath string `env:"FILE_DIR_PATH" envDefault:"./supermq-data"` - FileName string `env:"FILE_NAME" envDefault:"supermq-pat.db"` - FileMode FileMode `env:"FILE_MODE" envDefault:"0600"` - Bucket string `env:"BUCKET" envDefault:"supermq"` - Timeout time.Duration `env:"TIMEOUT" envDefault:"0"` -} - -// Setup load configuration from environment and creates new BoltDB. -func Setup(envPrefix string, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) { - return SetupDB(envPrefix, initFn) -} - -// SetupDB load configuration from environment,. -func SetupDB(envPrefix string, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) { - cfg := Config{} - if err := env.ParseWithOptions(&cfg, env.Options{Prefix: envPrefix}); err != nil { - return nil, errors.Wrap(errConfig, err) - } - bdb, err := Connect(cfg, initFn) - if err != nil { - return nil, err - } - - return bdb, nil -} - -// Connect establishes connection to the BoltDB. -func Connect(cfg Config, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) { - filePath := cfg.FileDirPath + "/" + cfg.FileName - db, err := bolt.Open(filePath, fs.FileMode(cfg.FileMode), nil) - if err != nil { - return nil, errors.Wrap(errConnect, err) - } - if initFn != nil { - if err := Init(db, cfg, initFn); err != nil { - return nil, err - } - } - return db, nil -} - -func Init(db *bolt.DB, cfg Config, initFn func(*bolt.Tx, string) error) error { - if err := db.Update(func(tx *bolt.Tx) error { - return initFn(tx, cfg.Bucket) - }); err != nil { - return errors.Wrap(errInit, err) - } - return nil -} diff --git a/internal/clients/bolt/doc.go b/internal/clients/bolt/doc.go deleted file mode 100644 index 24fc0f92a5..0000000000 --- a/internal/clients/bolt/doc.go +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package BoltDB contains the domain concept definitions needed to support -// Supermq BoltDB database functionality. -// -// It provides the abstraction of the BoltDB database service, which is used -// to configure, setup and connect to the BoltDB database. -package bolt diff --git a/pkg/groups/events/consumer/streams.go b/pkg/groups/events/consumer/streams.go index 6e1b49b68c..fdf2e49a31 100644 --- a/pkg/groups/events/consumer/streams.go +++ b/pkg/groups/events/consumer/streams.go @@ -180,7 +180,6 @@ func (es *eventHandler) removeParentGroupHandler(ctx context.Context, data map[s if err != nil { return errors.Wrap(errRemoveParentGroupEvent, err) } - if err := es.repo.UnassignParentGroup(ctx, g.Parent, id); err != nil { return errors.Wrap(errRemoveParentGroupEvent, err) }