Skip to content

Commit

Permalink
refactor: add superadmin check on retrieve domain
Browse files Browse the repository at this point in the history
Signed-off-by: Felix Gateru <[email protected]>
  • Loading branch information
felixgateru authored and dborovcanin committed Feb 26, 2025
1 parent 6026c3b commit 3c8457b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 18 deletions.
6 changes: 3 additions & 3 deletions domains/api/http/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,7 @@ func TestSendInvitation(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = authn.Session{UserID: userID, DomainID: domainID, DomainUserID: domainID + "_" + userID}
tc.session = authn.Session{UserID: userID, DomainID: domainID}
}
authnCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
repoCall := svc.On("SendInvitation", mock.Anything, tc.session, mock.Anything).Return(tc.svcErr)
Expand Down Expand Up @@ -1382,7 +1382,7 @@ func TestViewInvitation(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = authn.Session{UserID: userID, DomainID: domainID, DomainUserID: domainID + "_" + userID}
tc.session = authn.Session{UserID: userID, DomainID: domainID}
}
authnCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
repoCall := svc.On("ViewInvitation", mock.Anything, tc.session, tc.userID, tc.domainID).Return(domains.Invitation{}, tc.svcErr)
Expand Down Expand Up @@ -1476,7 +1476,7 @@ func TestDeleteInvitation(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = authn.Session{UserID: userID, DomainID: domainID, DomainUserID: domainID + "_" + userID}
tc.session = authn.Session{UserID: userID, DomainID: domainID}
}
authnCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
repoCall := svc.On("DeleteInvitation", mock.Anything, tc.session, tc.userID, tc.domainID).Return(tc.svcErr)
Expand Down
21 changes: 13 additions & 8 deletions domains/middleware/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package middleware
import (
"context"

"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/domains"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/authz"
Expand Down Expand Up @@ -58,6 +57,16 @@ func (am *authorizationMiddleware) CreateDomain(ctx context.Context, session aut
}

func (am *authorizationMiddleware) RetrieveDomain(ctx context.Context, session authn.Session, id string) (domains.Domain, error) {
if err := am.authz.Authorize(ctx, authz.PolicyReq{
Subject: session.UserID,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
ObjectType: policies.PlatformType,
Object: policies.SuperMQObject,
}); err == nil {
session.SuperAdmin = true
return am.svc.RetrieveDomain(ctx, session, id)
}

Check warning on line 69 in domains/middleware/authorization.go

View check run for this annotation

Codecov / codecov/patch

domains/middleware/authorization.go#L60-L69

Added lines #L60 - L69 were not covered by tests
if err := am.authorize(ctx, domains.OpRetrieveDomain, authz.PolicyReq{
Subject: session.UserID,

Check warning on line 71 in domains/middleware/authorization.go

View check run for this annotation

Codecov / codecov/patch

domains/middleware/authorization.go#L71

Added line #L71 was not covered by tests
SubjectType: policies.UserType,
Expand Down Expand Up @@ -141,8 +150,7 @@ func (am *authorizationMiddleware) ListDomains(ctx context.Context, session auth
}

func (am *authorizationMiddleware) SendInvitation(ctx context.Context, session authn.Session, invitation domains.Invitation) (err error) {
domainUserId := auth.EncodeDomainUserID(invitation.DomainID, invitation.InviteeUserID)
if err := am.extAuthorize(ctx, domainUserId, policies.MembershipPermission, policies.DomainType, invitation.DomainID); err == nil {
if err := am.extAuthorize(ctx, session.UserID, policies.MembershipPermission, policies.DomainType, invitation.DomainID); err == nil {

Check warning on line 153 in domains/middleware/authorization.go

View check run for this annotation

Codecov / codecov/patch

domains/middleware/authorization.go#L153

Added line #L153 was not covered by tests
// return error if the user is already a member of the domain
return errors.Wrap(svcerr.ErrConflict, ErrMemberExist)
}
Expand All @@ -155,7 +163,6 @@ func (am *authorizationMiddleware) SendInvitation(ctx context.Context, session a
}

func (am *authorizationMiddleware) ViewInvitation(ctx context.Context, session authn.Session, inviteeUserID, domain string) (invitation domains.Invitation, err error) {
session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID)
if session.UserID != inviteeUserID {
if err := am.checkAdmin(ctx, session); err != nil {
return domains.Invitation{}, err
Expand All @@ -166,7 +173,6 @@ func (am *authorizationMiddleware) ViewInvitation(ctx context.Context, session a
}

func (am *authorizationMiddleware) ListInvitations(ctx context.Context, session authn.Session, page domains.InvitationPageMeta) (invs domains.InvitationPage, err error) {
session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID)
if err := am.extAuthorize(ctx, session.UserID, policies.AdminPermission, policies.PlatformType, policies.SuperMQObject); err == nil {
session.SuperAdmin = true
page.DomainID = ""
Expand All @@ -175,7 +181,7 @@ func (am *authorizationMiddleware) ListInvitations(ctx context.Context, session
if !session.SuperAdmin {
switch {
case page.DomainID != "":
if err := am.extAuthorize(ctx, session.DomainUserID, policies.AdminPermission, policies.DomainType, page.DomainID); err != nil {
if err := am.extAuthorize(ctx, session.UserID, policies.AdminPermission, policies.DomainType, page.DomainID); err != nil {

Check warning on line 184 in domains/middleware/authorization.go

View check run for this annotation

Codecov / codecov/patch

domains/middleware/authorization.go#L184

Added line #L184 was not covered by tests
return domains.InvitationPage{}, err
}
default:
Expand All @@ -195,7 +201,6 @@ func (am *authorizationMiddleware) RejectInvitation(ctx context.Context, session
}

func (am *authorizationMiddleware) DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) (err error) {
session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID)
if err := am.checkAdmin(ctx, session); err != nil {
return err
}
Expand All @@ -222,7 +227,7 @@ func (am *authorizationMiddleware) checkAdmin(ctx context.Context, session authn
req := smqauthz.PolicyReq{
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Subject: session.DomainUserID,
Subject: session.UserID,

Check warning on line 230 in domains/middleware/authorization.go

View check run for this annotation

Codecov / codecov/patch

domains/middleware/authorization.go#L230

Added line #L230 was not covered by tests
Permission: policies.AdminPermission,
ObjectType: policies.DomainType,
Object: session.DomainID,
Expand Down
10 changes: 3 additions & 7 deletions pkg/sdk/invitations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,7 @@ func TestSendInvitation(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == valid {
tc.session = smqauthn.Session{
UserID: tc.sendInvitationReq.InviteeUserID,
DomainID: tc.sendInvitationReq.DomainID,
DomainUserID: tc.sendInvitationReq.DomainID + "_" + tc.sendInvitationReq.InviteeUserID,
}
tc.session = smqauthn.Session{UserID: tc.sendInvitationReq.InviteeUserID, DomainID: tc.sendInvitationReq.DomainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("SendInvitation", mock.Anything, tc.session, tc.svcReq).Return(tc.svcErr)
Expand Down Expand Up @@ -213,7 +209,7 @@ func TestViewInvitation(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == valid {
tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: tc.domainID + "_" + tc.userID}
tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("ViewInvitation", mock.Anything, tc.session, tc.userID, tc.domainID).Return(tc.svcRes, tc.svcErr)
Expand Down Expand Up @@ -527,7 +523,7 @@ func TestDeleteInvitation(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == valid {
tc.session = smqauthn.Session{UserID: tc.inviteeUserID, DomainID: tc.domainID, DomainUserID: tc.domainID + "_" + tc.inviteeUserID}
tc.session = smqauthn.Session{UserID: tc.inviteeUserID, DomainID: tc.domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("DeleteInvitation", mock.Anything, tc.session, tc.inviteeUserID, tc.domainID).Return(tc.svcErr)
Expand Down

0 comments on commit 3c8457b

Please sign in to comment.