Skip to content

Commit

Permalink
channels: return role provisioned on create
Browse files Browse the repository at this point in the history
Signed-off-by: Arvindh <[email protected]>
  • Loading branch information
arvindh123 committed Dec 26, 2024
1 parent c759a10 commit a48a983
Show file tree
Hide file tree
Showing 12 changed files with 56 additions and 36 deletions.
5 changes: 3 additions & 2 deletions channels/api/http/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/roles"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -176,7 +177,7 @@ func TestCreateChannelEndpoint(t *testing.T) {
tc.session = smqauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID}
}
authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
svcCall := svc.On("CreateChannels", mock.Anything, tc.session, tc.req).Return(tc.svcResp, tc.svcErr)
svcCall := svc.On("CreateChannels", mock.Anything, tc.session, tc.req).Return(tc.svcResp, []roles.RoleProvision{}, tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
var errRes respBody
Expand Down Expand Up @@ -312,7 +313,7 @@ func TestCreateChannelsEndpoint(t *testing.T) {
tc.session = smqauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID}
}
authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
svcCall := svc.On("CreateChannels", mock.Anything, tc.session, tc.req[0]).Return(tc.svcResp, tc.svcErr)
svcCall := svc.On("CreateChannels", mock.Anything, tc.session, tc.req[0]).Return(tc.svcResp, []roles.RoleProvision{}, tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
var errRes respBody
Expand Down
4 changes: 2 additions & 2 deletions channels/api/http/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func createChannelEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, svcerr.ErrAuthentication
}

channels, err := svc.CreateChannels(ctx, session, req.Channel)
channels, _, err := svc.CreateChannels(ctx, session, req.Channel)
if err != nil {
return nil, err
}
Expand All @@ -51,7 +51,7 @@ func createChannelsEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, svcerr.ErrAuthentication
}

channels, err := svc.CreateChannels(ctx, session, req.Channels...)
channels, _, err := svc.CreateChannels(ctx, session, req.Channels...)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion channels/channels.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ type AuthzReq struct {
//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines"
type Service interface {
// CreateChannels adds channels to the user identified by the provided key.
CreateChannels(ctx context.Context, session authn.Session, channels ...Channel) ([]Channel, error)
CreateChannels(ctx context.Context, session authn.Session, channels ...Channel) ([]Channel, []roles.RoleProvision, error)

// ViewChannel retrieves data about the channel identified by the provided
// ID, that belongs to the user identified by the provided key.
Expand Down
11 changes: 7 additions & 4 deletions channels/events/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/absmach/supermq/channels"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/events"
"github.com/absmach/supermq/pkg/roles"
)

const (
Expand Down Expand Up @@ -38,14 +39,16 @@ var (

type createChannelEvent struct {
channels.Channel
rolesProvisioned []roles.RoleProvision
}

func (cce createChannelEvent) Encode() (map[string]interface{}, error) {
val := map[string]interface{}{
"operation": channelCreate,
"id": cce.ID,
"status": cce.Status.String(),
"created_at": cce.CreatedAt,
"operation": channelCreate,
"id": cce.ID,
"roles_provisioned": cce.rolesProvisioned,
"status": cce.Status.String(),
"created_at": cce.CreatedAt,
}

if cce.Name != "" {
Expand Down
14 changes: 8 additions & 6 deletions channels/events/streams.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/events"
"github.com/absmach/supermq/pkg/events/store"
"github.com/absmach/supermq/pkg/roles"
rmEvents "github.com/absmach/supermq/pkg/roles/rolemanager/events"
)

Expand Down Expand Up @@ -40,22 +41,23 @@ func NewEventStoreMiddleware(ctx context.Context, svc channels.Service, url stri
}, nil
}

func (es *eventStore) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, error) {
chs, err := es.svc.CreateChannels(ctx, session, chs...)
func (es *eventStore) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
chs, rps, err := es.svc.CreateChannels(ctx, session, chs...)
if err != nil {
return chs, err
return chs, rps, err
}

for _, ch := range chs {
event := createChannelEvent{
ch,
Channel: ch,
rolesProvisioned: rps,
}
if err := es.Publish(ctx, event); err != nil {
return chs, err
return chs, rps, err
}
}

return chs, nil
return chs, rps, nil
}

func (es *eventStore) UpdateChannel(ctx context.Context, session authn.Session, ch channels.Channel) (channels.Channel, error) {
Expand Down
7 changes: 4 additions & 3 deletions channels/middleware/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
"github.com/absmach/supermq/pkg/roles"
rmMW "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
"github.com/absmach/supermq/pkg/svcutil"
)
Expand Down Expand Up @@ -80,15 +81,15 @@ func AuthorizationMiddleware(svc channels.Service, repo channels.Repository, aut
}, nil
}

func (am *authorizationMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, error) {
func (am *authorizationMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
if err := am.extAuthorize(ctx, channels.DomainOpCreateChannel, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.DomainType,
Object: session.DomainID,
}); err != nil {
return []channels.Channel{}, errors.Wrap(err, errDomainCreateChannels)
return []channels.Channel{}, []roles.RoleProvision{}, errors.Wrap(err, errDomainCreateChannels)
}

for _, ch := range chs {
Expand All @@ -100,7 +101,7 @@ func (am *authorizationMiddleware) CreateChannels(ctx context.Context, session a
ObjectType: policies.GroupType,
Object: ch.ParentGroup,
}); err != nil {
return []channels.Channel{}, errors.Wrap(err, errors.Wrap(errGroupSetChildChannels, fmt.Errorf("channel name %s parent group id %s", ch.Name, ch.ParentGroup)))
return []channels.Channel{}, []roles.RoleProvision{}, errors.Wrap(err, errors.Wrap(errGroupSetChildChannels, fmt.Errorf("channel name %s parent group id %s", ch.Name, ch.ParentGroup)))
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion channels/middleware/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/absmach/supermq/channels"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/roles"
rmMW "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
)

Expand All @@ -27,7 +28,7 @@ func LoggingMiddleware(svc channels.Service, logger *slog.Logger) channels.Servi
return &loggingMiddleware{logger, svc, rmMW.NewRoleManagerLoggingMiddleware("channels", svc, logger)}
}

func (lm *loggingMiddleware) CreateChannels(ctx context.Context, session authn.Session, clients ...channels.Channel) (cs []channels.Channel, err error) {
func (lm *loggingMiddleware) CreateChannels(ctx context.Context, session authn.Session, clients ...channels.Channel) (cs []channels.Channel, rps []roles.RoleProvision, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
Expand Down
3 changes: 2 additions & 1 deletion channels/middleware/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/absmach/supermq/channels"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/roles"
rmMW "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
"github.com/go-kit/kit/metrics"
)
Expand All @@ -33,7 +34,7 @@ func MetricsMiddleware(svc channels.Service, counter metrics.Counter, latency me
}
}

func (ms *metricsMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, error) {
func (ms *metricsMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
defer func(begin time.Time) {
ms.counter.With("method", "register_channels").Add(1)
ms.latency.With("method", "register_channels").Observe(time.Since(begin).Seconds())
Expand Down
21 changes: 15 additions & 6 deletions channels/mocks/service.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 8 additions & 7 deletions channels/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,19 @@ func New(repo Repository, policy policies.Service, idProvider supermq.IDProvider
}, nil
}

func (svc service) CreateChannels(ctx context.Context, session authn.Session, chs ...Channel) (retChs []Channel, retErr error) {
func (svc service) CreateChannels(ctx context.Context, session authn.Session, chs ...Channel) (retChs []Channel, retRps []roles.RoleProvision, retErr error) {
var reChs []Channel
for _, c := range chs {
if c.ID == "" {
clientID, err := svc.idProvider.ID()
if err != nil {
return []Channel{}, err
return []Channel{}, []roles.RoleProvision{}, err
}
c.ID = clientID
}

if c.Status != smqclients.DisabledStatus && c.Status != smqclients.EnabledStatus {
return []Channel{}, svcerr.ErrInvalidStatus
return []Channel{}, []roles.RoleProvision{}, svcerr.ErrInvalidStatus
}
c.Domain = session.DomainID
c.CreatedAt = time.Now()
Expand All @@ -78,7 +78,7 @@ func (svc service) CreateChannels(ctx context.Context, session authn.Session, ch

savedChs, err := svc.repo.Save(ctx, reChs...)
if err != nil {
return nil, errors.Wrap(svcerr.ErrCreateEntity, err)
return []Channel{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrCreateEntity, err)
}
chIDs := []string{}
for _, c := range savedChs {
Expand Down Expand Up @@ -110,10 +110,11 @@ func (svc service) CreateChannels(ctx context.Context, session authn.Session, ch
},
)
}
if _, err := svc.AddNewEntitiesRoles(ctx, session.DomainID, session.UserID, chIDs, optionalPolicies, newBuiltInRoleMembers); err != nil {
return []Channel{}, errors.Wrap(svcerr.ErrAddPolicies, err)
nrps, err := svc.AddNewEntitiesRoles(ctx, session.DomainID, session.UserID, chIDs, optionalPolicies, newBuiltInRoleMembers)
if err != nil {
return []Channel{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrAddPolicies, err)
}
return savedChs, nil
return savedChs, nrps, nil
}

func (svc service) UpdateChannel(ctx context.Context, session authn.Session, ch Channel) (Channel, error) {
Expand Down
4 changes: 2 additions & 2 deletions channels/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,9 @@ func TestCreateChannel(t *testing.T) {
repoCall := repo.On("Save", context.Background(), mock.Anything).Return(tc.saveResp, tc.saveErr)
policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesErr)
policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesErr)
repoCall1 := repo.On("AddRoles", context.Background(), mock.Anything).Return([]roles.Role{}, tc.addRoleErr)
repoCall1 := repo.On("AddRoles", context.Background(), mock.Anything).Return([]roles.RoleProvision{}, tc.addRoleErr)
repoCall2 := repo.On("Remove", context.Background(), mock.Anything).Return(tc.deleteErr)
_, err := svc.CreateChannels(context.Background(), validSession, tc.channel)
_, _, err := svc.CreateChannels(context.Background(), validSession, tc.channel)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v but got %v", tc.err, err))
if err == nil {
ok := repoCall.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
Expand Down
3 changes: 2 additions & 1 deletion channels/tracing/tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/absmach/supermq/channels"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/roles"
rmTrace "github.com/absmach/supermq/pkg/roles/rolemanager/tracing"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
Expand All @@ -28,7 +29,7 @@ func New(svc channels.Service, tracer trace.Tracer) channels.Service {
}

// CreateChannels traces the "CreateChannels" operation of the wrapped policies.Service.
func (tm *tracingMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, error) {
func (tm *tracingMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
ctx, span := tm.tracer.Start(ctx, "svc_create_channel")
defer span.End()

Expand Down

0 comments on commit a48a983

Please sign in to comment.