diff --git a/api/grpc/auth/v1/auth.pb.go b/api/grpc/auth/v1/auth.pb.go index 1ee360a95f..e921967828 100644 --- a/api/grpc/auth/v1/auth.pb.go +++ b/api/grpc/auth/v1/auth.pb.go @@ -236,16 +236,15 @@ func (x *AuthZReq) GetObjectType() string { } type AuthZPatReq struct { - state protoimpl.MessageState `protogen:"open.v1"` - UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` // User id - PatId string `protobuf:"bytes,2,opt,name=pat_id,json=patId,proto3" json:"pat_id,omitempty"` // Pat id - PlatformEntityType uint32 `protobuf:"varint,3,opt,name=platform_entity_type,json=platformEntityType,proto3" json:"platform_entity_type,omitempty"` // Platform entity type - OptionalDomainId string `protobuf:"bytes,4,opt,name=optional_domain_id,json=optionalDomainId,proto3" json:"optional_domain_id,omitempty"` // Optional domain id - OptionalDomainEntityType uint32 `protobuf:"varint,5,opt,name=optional_domain_entity_type,json=optionalDomainEntityType,proto3" json:"optional_domain_entity_type,omitempty"` // Optional domain entity type - Operation uint32 `protobuf:"varint,6,opt,name=operation,proto3" json:"operation,omitempty"` // Operation - EntityIds []string `protobuf:"bytes,7,rep,name=entity_ids,json=entityIds,proto3" json:"entity_ids,omitempty"` // EntityIDs - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` // User id + PatId string `protobuf:"bytes,2,opt,name=pat_id,json=patId,proto3" json:"pat_id,omitempty"` // Pat id + EntityType uint32 `protobuf:"varint,3,opt,name=entity_type,json=entityType,proto3" json:"entity_type,omitempty"` // Entity type + OptionalDomainId string `protobuf:"bytes,4,opt,name=optional_domain_id,json=optionalDomainId,proto3" json:"optional_domain_id,omitempty"` // Optional domain id + Operation uint32 `protobuf:"varint,6,opt,name=operation,proto3" json:"operation,omitempty"` // Operation + EntityId string `protobuf:"bytes,7,opt,name=entity_id,json=entityId,proto3" json:"entity_id,omitempty"` // EntityID + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *AuthZPatReq) Reset() { @@ -292,9 +291,9 @@ func (x *AuthZPatReq) GetPatId() string { return "" } -func (x *AuthZPatReq) GetPlatformEntityType() uint32 { +func (x *AuthZPatReq) GetEntityType() uint32 { if x != nil { - return x.PlatformEntityType + return x.EntityType } return 0 } @@ -306,13 +305,6 @@ func (x *AuthZPatReq) GetOptionalDomainId() string { return "" } -func (x *AuthZPatReq) GetOptionalDomainEntityType() uint32 { - if x != nil { - return x.OptionalDomainEntityType - } - return 0 -} - func (x *AuthZPatReq) GetOperation() uint32 { if x != nil { return x.Operation @@ -320,11 +312,11 @@ func (x *AuthZPatReq) GetOperation() uint32 { return 0 } -func (x *AuthZPatReq) GetEntityIds() []string { +func (x *AuthZPatReq) GetEntityId() string { if x != nil { - return x.EntityIds + return x.EntityId } - return nil + return "" } type AuthZRes struct { @@ -409,47 +401,42 @@ var file_auth_v1_auth_proto_rawDesc = []byte{ 0x06, 0x6f, 0x62, 0x6a, 0x65, 0x63, 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6f, 0x62, 0x6a, 0x65, 0x63, 0x74, 0x12, 0x1f, 0x0a, 0x0b, 0x6f, 0x62, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6f, 0x62, 0x6a, 0x65, - 0x63, 0x74, 0x54, 0x79, 0x70, 0x65, 0x22, 0x99, 0x02, 0x0a, 0x0b, 0x41, 0x75, 0x74, 0x68, 0x5a, + 0x63, 0x74, 0x54, 0x79, 0x70, 0x65, 0x22, 0xc7, 0x01, 0x0a, 0x0b, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x50, 0x61, 0x74, 0x52, 0x65, 0x71, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x15, 0x0a, 0x06, 0x70, 0x61, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x70, 0x61, 0x74, 0x49, 0x64, 0x12, 0x30, 0x0a, 0x14, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, - 0x72, 0x6d, 0x5f, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x0d, 0x52, 0x12, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x45, 0x6e, - 0x74, 0x69, 0x74, 0x79, 0x54, 0x79, 0x70, 0x65, 0x12, 0x2c, 0x0a, 0x12, 0x6f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x44, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x49, 0x64, 0x12, 0x3d, 0x0a, 0x1b, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, - 0x61, 0x6c, 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, - 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x18, 0x6f, 0x70, 0x74, - 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x45, 0x6e, 0x74, 0x69, 0x74, - 0x79, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, + 0x05, 0x70, 0x61, 0x74, 0x49, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, + 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x65, 0x6e, 0x74, + 0x69, 0x74, 0x79, 0x54, 0x79, 0x70, 0x65, 0x12, 0x2c, 0x0a, 0x12, 0x6f, 0x70, 0x74, 0x69, 0x6f, + 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x10, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x49, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x5f, 0x69, 0x64, - 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x49, - 0x64, 0x73, 0x22, 0x3a, 0x0a, 0x08, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x73, 0x12, 0x1e, - 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x12, 0x0e, - 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x32, 0xf0, - 0x01, 0x0a, 0x0b, 0x41, 0x75, 0x74, 0x68, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, - 0x0a, 0x09, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x12, 0x11, 0x2e, 0x61, 0x75, - 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x71, 0x1a, 0x11, - 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, - 0x73, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, - 0x50, 0x41, 0x54, 0x12, 0x14, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, - 0x74, 0x68, 0x5a, 0x50, 0x61, 0x74, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, - 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x73, 0x22, 0x00, 0x12, 0x36, - 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x11, - 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, - 0x71, 0x1a, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, - 0x4e, 0x52, 0x65, 0x73, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x0f, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, - 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x50, 0x41, 0x54, 0x12, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, - 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61, - 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x73, 0x22, - 0x00, 0x42, 0x2d, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, - 0x61, 0x62, 0x73, 0x6d, 0x61, 0x63, 0x68, 0x2f, 0x73, 0x75, 0x70, 0x65, 0x72, 0x6d, 0x71, 0x2f, - 0x61, 0x70, 0x69, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x2f, 0x76, 0x31, - 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x69, 0x6f, 0x6e, 0x12, 0x1b, 0x0a, 0x09, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x5f, 0x69, 0x64, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x49, 0x64, + 0x22, 0x3a, 0x0a, 0x08, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x73, 0x12, 0x1e, 0x0a, 0x0a, + 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x12, 0x0e, 0x0a, 0x02, + 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x32, 0xf0, 0x01, 0x0a, + 0x0b, 0x41, 0x75, 0x74, 0x68, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x09, + 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x12, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, + 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61, + 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x73, 0x22, + 0x00, 0x12, 0x39, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x50, 0x41, + 0x54, 0x12, 0x14, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, + 0x5a, 0x50, 0x61, 0x74, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, + 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x73, 0x22, 0x00, 0x12, 0x36, 0x0a, 0x0c, + 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x11, 0x2e, 0x61, + 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x71, 0x1a, + 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, + 0x65, 0x73, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x0f, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, + 0x63, 0x61, 0x74, 0x65, 0x50, 0x41, 0x54, 0x12, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, + 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61, 0x75, 0x74, + 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x73, 0x22, 0x00, 0x42, + 0x2d, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x62, + 0x73, 0x6d, 0x61, 0x63, 0x68, 0x2f, 0x73, 0x75, 0x70, 0x65, 0x72, 0x6d, 0x71, 0x2f, 0x61, 0x70, + 0x69, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x2f, 0x76, 0x31, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/auth/api/grpc/auth/client.go b/auth/api/grpc/auth/client.go index ba9ad7e740..b469522b2b 100644 --- a/auth/api/grpc/auth/client.go +++ b/auth/api/grpc/auth/client.go @@ -151,13 +151,12 @@ func (client authGrpcClient) AuthorizePAT(ctx context.Context, req *grpcAuthV1.A defer cancel() res, err := client.authorizePAT(ctx, authPATReq{ - userID: req.GetUserId(), - patID: req.GetPatId(), - platformEntityType: auth.PlatformEntityType(req.GetPlatformEntityType()), - optionalDomainID: req.GetOptionalDomainId(), - optionalDomainEntityType: auth.DomainEntityType(req.GetOptionalDomainEntityType()), - operation: auth.OperationType(req.GetOperation()), - entityIDs: req.GetEntityIds(), + userID: req.GetUserId(), + patID: req.GetPatId(), + entityType: auth.EntityType(req.GetEntityType()), + optionalDomainID: req.GetOptionalDomainId(), + operation: auth.Operation(req.GetOperation()), + entityID: req.GetEntityId(), }) if err != nil { return &grpcAuthV1.AuthZRes{}, grpcapi.DecodeError(err) @@ -170,12 +169,11 @@ func (client authGrpcClient) AuthorizePAT(ctx context.Context, req *grpcAuthV1.A func encodeAuthorizePATRequest(_ context.Context, grpcReq interface{}) (interface{}, error) { req := grpcReq.(authPATReq) return &grpcAuthV1.AuthZPatReq{ - UserId: req.userID, - PatId: req.patID, - PlatformEntityType: uint32(req.platformEntityType), - OptionalDomainId: req.optionalDomainID, - OptionalDomainEntityType: uint32(req.optionalDomainEntityType), - Operation: uint32(req.operation), - EntityIds: req.entityIDs, + UserId: req.userID, + PatId: req.patID, + EntityType: uint32(req.entityType), + OptionalDomainId: req.optionalDomainID, + Operation: uint32(req.operation), + EntityId: req.entityID, }, nil } diff --git a/auth/api/grpc/auth/endpoint.go b/auth/api/grpc/auth/endpoint.go index 05516b64e6..7af6c83c9d 100644 --- a/auth/api/grpc/auth/endpoint.go +++ b/auth/api/grpc/auth/endpoint.go @@ -74,7 +74,7 @@ func authorizePATEndpoint(svc auth.Service) endpoint.Endpoint { if err := req.validate(); err != nil { return authorizeRes{}, err } - err := svc.AuthorizePAT(ctx, req.userID, req.patID, req.platformEntityType, req.optionalDomainID, req.optionalDomainEntityType, req.operation, req.entityIDs...) + err := svc.AuthorizePAT(ctx, req.userID, req.patID, req.entityType, req.optionalDomainID, req.operation, req.entityID) if err != nil { return authorizeRes{authorized: false}, err } diff --git a/auth/api/grpc/auth/endpoint_test.go b/auth/api/grpc/auth/endpoint_test.go index e506f6c82c..159a25449e 100644 --- a/auth/api/grpc/auth/endpoint_test.go +++ b/auth/api/grpc/auth/endpoint_test.go @@ -301,13 +301,12 @@ func TestAuthorizePAT(t *testing.T) { desc: "authorize user with authorized token", token: validPATToken, authRequest: &grpcAuthV1.AuthZPatReq{ - UserId: id, - PatId: id, - PlatformEntityType: uint32(auth.PlatformDomainsScope), - OptionalDomainId: domainID, - OptionalDomainEntityType: uint32(auth.DomainClientsScope), - Operation: uint32(auth.CreateOp), - EntityIds: []string{clientID}, + UserId: id, + PatId: id, + EntityType: uint32(auth.ClientsType), + OptionalDomainId: domainID, + Operation: uint32(auth.CreateOp), + EntityId: clientID, }, authResponse: &grpcAuthV1.AuthZRes{Authorized: true}, err: nil, @@ -316,13 +315,12 @@ func TestAuthorizePAT(t *testing.T) { desc: "authorize user with unauthorized token", token: inValidPATToken, authRequest: &grpcAuthV1.AuthZPatReq{ - UserId: id, - PatId: id, - PlatformEntityType: uint32(auth.PlatformDomainsScope), - OptionalDomainId: domainID, - OptionalDomainEntityType: uint32(auth.DomainClientsScope), - Operation: uint32(auth.CreateOp), - EntityIds: []string{clientID}, + UserId: id, + PatId: id, + EntityType: uint32(auth.ClientsType), + OptionalDomainId: domainID, + Operation: uint32(auth.CreateOp), + EntityId: clientID, }, authResponse: &grpcAuthV1.AuthZRes{Authorized: false}, err: svcerr.ErrAuthorization, @@ -331,12 +329,11 @@ func TestAuthorizePAT(t *testing.T) { desc: "authorize user with missing user id", token: validPATToken, authRequest: &grpcAuthV1.AuthZPatReq{ - PatId: id, - PlatformEntityType: uint32(auth.PlatformDomainsScope), - OptionalDomainId: domainID, - OptionalDomainEntityType: uint32(auth.DomainClientsScope), - Operation: uint32(auth.CreateOp), - EntityIds: []string{clientID}, + PatId: id, + EntityType: uint32(auth.ClientsType), + OptionalDomainId: domainID, + Operation: uint32(auth.CreateOp), + EntityId: clientID, }, authResponse: &grpcAuthV1.AuthZRes{Authorized: false}, err: apiutil.ErrMissingUserID, @@ -345,12 +342,11 @@ func TestAuthorizePAT(t *testing.T) { desc: "authorize user with missing pat id", token: validPATToken, authRequest: &grpcAuthV1.AuthZPatReq{ - UserId: id, - PlatformEntityType: uint32(auth.PlatformDomainsScope), - OptionalDomainId: domainID, - OptionalDomainEntityType: uint32(auth.DomainClientsScope), - Operation: uint32(auth.CreateOp), - EntityIds: []string{clientID}, + UserId: id, + EntityType: uint32(auth.ClientsType), + OptionalDomainId: domainID, + Operation: uint32(auth.CreateOp), + EntityId: clientID, }, authResponse: &grpcAuthV1.AuthZRes{Authorized: false}, err: apiutil.ErrMissingPATID, diff --git a/auth/api/grpc/auth/requests.go b/auth/api/grpc/auth/requests.go index 11db7338bb..b9efcc4f67 100644 --- a/auth/api/grpc/auth/requests.go +++ b/auth/api/grpc/auth/requests.go @@ -52,13 +52,12 @@ func (req authReq) validate() error { } type authPATReq struct { - userID string - patID string - platformEntityType auth.PlatformEntityType - optionalDomainID string - optionalDomainEntityType auth.DomainEntityType - operation auth.OperationType - entityIDs []string + userID string + patID string + entityType auth.EntityType + optionalDomainID string + operation auth.Operation + entityID string } func (req authPATReq) validate() error { diff --git a/auth/api/grpc/auth/server.go b/auth/api/grpc/auth/server.go index 72e37d2e8d..4188732f2a 100644 --- a/auth/api/grpc/auth/server.go +++ b/auth/api/grpc/auth/server.go @@ -112,13 +112,12 @@ func encodeAuthorizeResponse(_ context.Context, grpcRes interface{}) (interface{ func decodeAuthorizePATRequest(_ context.Context, grpcReq interface{}) (interface{}, error) { req := grpcReq.(*grpcAuthV1.AuthZPatReq) return authPATReq{ - userID: req.GetUserId(), - patID: req.GetPatId(), - platformEntityType: auth.PlatformEntityType(req.GetPlatformEntityType()), - optionalDomainID: req.GetOptionalDomainId(), - optionalDomainEntityType: auth.DomainEntityType(req.GetOptionalDomainEntityType()), - operation: auth.OperationType(req.GetOperation()), - entityIDs: req.GetEntityIds(), + userID: req.GetUserId(), + patID: req.GetPatId(), + entityType: auth.EntityType(req.GetEntityType()), + optionalDomainID: req.GetOptionalDomainId(), + operation: auth.Operation(req.GetOperation()), + entityID: req.GetEntityId(), }, nil } diff --git a/auth/api/http/keys/endpoint_test.go b/auth/api/http/keys/endpoint_test.go index 5b2693f751..52801c7d40 100644 --- a/auth/api/http/keys/endpoint_test.go +++ b/auth/api/http/keys/endpoint_test.go @@ -70,13 +70,14 @@ func (tr testRequest) make() (*http.Response, error) { func newService() (auth.Service, *mocks.KeyRepository) { krepo := new(mocks.KeyRepository) pRepo := new(mocks.PATSRepository) + cache := new(mocks.Cache) hash := new(mocks.Hasher) idProvider := uuid.NewMock() pService := new(policymocks.Service) pEvaluator := new(policymocks.Evaluator) t := jwt.New([]byte(secret)) - return auth.New(krepo, pRepo, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), krepo + return auth.New(krepo, pRepo, cache, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), krepo } func newServer(svc auth.Service) *httptest.Server { diff --git a/auth/api/http/pats/endpoint.go b/auth/api/http/pats/endpoint.go index 45e6b3c607..7ba58270a9 100644 --- a/auth/api/http/pats/endpoint.go +++ b/auth/api/http/pats/endpoint.go @@ -17,7 +17,7 @@ func createPATEndpoint(svc auth.Service) endpoint.Endpoint { return nil, err } - pat, err := svc.CreatePAT(ctx, req.token, req.Name, req.Description, req.Duration, req.Scope) + pat, err := svc.CreatePAT(ctx, req.token, req.Name, req.Description, req.Duration) if err != nil { return nil, err } @@ -140,48 +140,83 @@ func revokePATSecretEndpoint(svc auth.Service) endpoint.Endpoint { } } -func addPATScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint { +func clearAllPATEndpoint(svc auth.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(addPatScopeEntryReq) + req := request.(clearAllPATReq) if err := req.validate(); err != nil { return nil, err } - scope, err := svc.AddPATScopeEntry(ctx, req.token, req.id, req.PlatformEntityType, req.OptionalDomainID, req.OptionalDomainEntityType, req.Operation, req.EntityIDs...) + if err := svc.RemoveAllPAT(ctx, req.token); err != nil { + return nil, err + } + + return clearAllRes{}, nil + } +} + +func addScopeEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(addScopeReq) + if err := req.validate(); err != nil { + return nil, err + } + err := svc.AddScope(ctx, req.token, req.id, req.Scopes) if err != nil { return nil, err } - return addPatScopeEntryRes{scope}, nil + return scopeRes{}, nil } } -func removePATScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint { +func removeScopeEndpoint(svc auth.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(removePatScopeEntryReq) + req := request.(removeScopeReq) if err := req.validate(); err != nil { return nil, err } - scope, err := svc.RemovePATScopeEntry(ctx, req.token, req.id, req.PlatformEntityType, req.OptionalDomainID, req.OptionalDomainEntityType, req.Operation, req.EntityIDs...) + err := svc.RemoveScope(ctx, req.token, req.id, req.ScopesID...) if err != nil { return nil, err } - return removePatScopeEntryRes{scope}, nil + return scopeRes{}, nil } } -func clearPATAllScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint { +func clearAllScopeEndpoint(svc auth.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(clearAllScopeEntryReq) + req := request.(clearAllScopeReq) if err := req.validate(); err != nil { return nil, err } - if err := svc.ClearPATAllScopeEntry(ctx, req.token, req.id); err != nil { + if err := svc.RemovePATAllScope(ctx, req.token, req.id); err != nil { + return nil, err + } + + return clearAllRes{}, nil + } +} + +func listScopesEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(listScopesReq) + if err := req.validate(); err != nil { + return nil, err + } + + pm := auth.ScopesPageMeta{ + Limit: req.limit, + Offset: req.offset, + PatID: req.patID, + } + scopesPage, err := svc.ListScopes(ctx, req.token, pm) + if err != nil { return nil, err } - return clearAllScopeEntryRes{}, nil + return listScopeRes{scopesPage}, nil } } diff --git a/auth/api/http/pats/requests.go b/auth/api/http/pats/requests.go index 0a7f0a3f3d..e90a31381e 100644 --- a/auth/api/http/pats/requests.go +++ b/auth/api/http/pats/requests.go @@ -10,6 +10,7 @@ import ( apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/errors" ) type createPatReq struct { @@ -17,15 +18,13 @@ type createPatReq struct { Name string `json:"name,omitempty"` Description string `json:"description,omitempty"` Duration time.Duration `json:"duration,omitempty"` - Scope auth.Scope `json:"scope,omitempty"` } func (cpr *createPatReq) UnmarshalJSON(data []byte) error { var temp struct { - Name string `json:"name,omitempty"` - Description string `json:"description,omitempty"` - Duration string `json:"duration,omitempty"` - Scope auth.Scope `json:"scope,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Duration string `json:"duration,omitempty"` } if err := json.Unmarshal(data, &temp); err != nil { return err @@ -37,7 +36,6 @@ func (cpr *createPatReq) UnmarshalJSON(data []byte) error { cpr.Name = temp.Name cpr.Description = temp.Description cpr.Duration = duration - cpr.Scope = temp.Scope return nil } @@ -63,7 +61,7 @@ func (req retrievePatReq) validate() (err error) { return apiutil.ErrBearerToken } if req.id == "" { - return apiutil.ErrMissingID + return apiutil.ErrMissingPATID } return nil } @@ -79,7 +77,7 @@ func (req updatePatNameReq) validate() (err error) { return apiutil.ErrBearerToken } if req.id == "" { - return apiutil.ErrMissingID + return apiutil.ErrMissingPATID } if strings.TrimSpace(req.Name) == "" { return apiutil.ErrMissingName @@ -98,7 +96,7 @@ func (req updatePatDescriptionReq) validate() (err error) { return apiutil.ErrBearerToken } if req.id == "" { - return apiutil.ErrMissingID + return apiutil.ErrMissingPATID } if strings.TrimSpace(req.Description) == "" { return apiutil.ErrMissingDescription @@ -129,7 +127,7 @@ func (req deletePatReq) validate() (err error) { return apiutil.ErrBearerToken } if req.id == "" { - return apiutil.ErrMissingID + return apiutil.ErrMissingPATID } return nil } @@ -161,7 +159,7 @@ func (req resetPatSecretReq) validate() (err error) { return apiutil.ErrBearerToken } if req.id == "" { - return apiutil.ErrMissingID + return apiutil.ErrMissingPATID } return nil } @@ -176,128 +174,111 @@ func (req revokePatSecretReq) validate() (err error) { return apiutil.ErrBearerToken } if req.id == "" { - return apiutil.ErrMissingID + return apiutil.ErrMissingPATID } return nil } -type addPatScopeEntryReq struct { - token string - id string - PlatformEntityType auth.PlatformEntityType `json:"platform_entity_type,omitempty"` - OptionalDomainID string `json:"optional_domain_id,omitempty"` - OptionalDomainEntityType auth.DomainEntityType `json:"optional_domain_entity_type,omitempty"` - Operation auth.OperationType `json:"operation,omitempty"` - EntityIDs []string `json:"entity_ids,omitempty"` +type clearAllPATReq struct { + token string } -func (apser *addPatScopeEntryReq) UnmarshalJSON(data []byte) error { - var temp struct { - PlatformEntityType string `json:"platform_entity_type,omitempty"` - OptionalDomainID string `json:"optional_domain_id,omitempty"` - OptionalDomainEntityType string `json:"optional_domain_entity_type,omitempty"` - Operation string `json:"operation,omitempty"` - EntityIDs []string `json:"entity_ids,omitempty"` +func (req clearAllPATReq) validate() error { + if req.token == "" { + return apiutil.ErrBearerToken } + return nil +} - if err := json.Unmarshal(data, &temp); err != nil { - return err - } +type addScopeReq struct { + token string + id string + Scopes []auth.Scope `json:"scopes,omitempty"` +} - pet, err := auth.ParsePlatformEntityType(temp.PlatformEntityType) - if err != nil { - return err +func (aser *addScopeReq) UnmarshalJSON(data []byte) error { + type Alias addScopeReq + aux := &struct { + *Alias + }{ + Alias: (*Alias)(aser), } - odt, err := auth.ParseDomainEntityType(temp.OptionalDomainEntityType) - if err != nil { - return err - } - op, err := auth.ParseOperationType(temp.Operation) - if err != nil { + + if err := json.Unmarshal(data, aux); err != nil { return err } - apser.PlatformEntityType = pet - apser.OptionalDomainID = temp.OptionalDomainID - apser.OptionalDomainEntityType = odt - apser.Operation = op - apser.EntityIDs = temp.EntityIDs + return nil } -func (req addPatScopeEntryReq) validate() (err error) { +func (req addScopeReq) validate() (err error) { if req.token == "" { return apiutil.ErrBearerToken } if req.id == "" { - return apiutil.ErrMissingID + return apiutil.ErrMissingPATID } - return nil -} - -type removePatScopeEntryReq struct { - token string - id string - PlatformEntityType auth.PlatformEntityType `json:"platform_entity_type,omitempty"` - OptionalDomainID string `json:"optional_domain_id,omitempty"` - OptionalDomainEntityType auth.DomainEntityType `json:"optional_domain_entity_type,omitempty"` - Operation auth.OperationType `json:"operation,omitempty"` - EntityIDs []string `json:"entity_ids,omitempty"` -} -func (rpser *removePatScopeEntryReq) UnmarshalJSON(data []byte) error { - var temp struct { - PlatformEntityType string `json:"platform_entity_type,omitempty"` - OptionalDomainID string `json:"optional_domain_id,omitempty"` - OptionalDomainEntityType string `json:"optional_domain_entity_type,omitempty"` - Operation string `json:"operation,omitempty"` - EntityIDs []string `json:"entity_ids,omitempty"` + if len(req.Scopes) == 0 { + return apiutil.ErrValidation } - if err := json.Unmarshal(data, &temp); err != nil { - return err + for _, scope := range req.Scopes { + if err := scope.Validate(); err != nil { + return errors.Wrap(apiutil.ErrValidation, err) + } } - pet, err := auth.ParsePlatformEntityType(temp.PlatformEntityType) - if err != nil { - return err - } - odt, err := auth.ParseDomainEntityType(temp.OptionalDomainEntityType) - if err != nil { - return err - } - op, err := auth.ParseOperationType(temp.Operation) - if err != nil { - return err - } - rpser.PlatformEntityType = pet - rpser.OptionalDomainID = temp.OptionalDomainID - rpser.OptionalDomainEntityType = odt - rpser.Operation = op - rpser.EntityIDs = temp.EntityIDs return nil } -func (req removePatScopeEntryReq) validate() (err error) { +type removeScopeReq struct { + token string + id string + ScopesID []string `json:"scopes_id,omitempty"` +} + +func (req removeScopeReq) validate() (err error) { if req.token == "" { return apiutil.ErrBearerToken } if req.id == "" { - return apiutil.ErrMissingID + return apiutil.ErrMissingPATID + } + if len(req.ScopesID) == 0 { + return apiutil.ErrValidation } return nil } -type clearAllScopeEntryReq struct { +type clearAllScopeReq struct { token string id string } -func (req clearAllScopeEntryReq) validate() (err error) { +func (req clearAllScopeReq) validate() (err error) { if req.token == "" { return apiutil.ErrBearerToken } if req.id == "" { - return apiutil.ErrMissingID + return apiutil.ErrMissingPATID + } + return nil +} + +type listScopesReq struct { + token string + offset uint64 + limit uint64 + patID string +} + +func (req listScopesReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.patID == "" { + return apiutil.ErrMissingPATID } return nil } diff --git a/auth/api/http/pats/responses.go b/auth/api/http/pats/responses.go index fe47f63c71..c801ee9215 100644 --- a/auth/api/http/pats/responses.go +++ b/auth/api/http/pats/responses.go @@ -18,13 +18,12 @@ var ( _ supermq.Response = (*deletePatRes)(nil) _ supermq.Response = (*resetPatSecretRes)(nil) _ supermq.Response = (*revokePatSecretRes)(nil) - _ supermq.Response = (*addPatScopeEntryRes)(nil) - _ supermq.Response = (*removePatScopeEntryRes)(nil) - _ supermq.Response = (*clearAllScopeEntryRes)(nil) + _ supermq.Response = (*scopeRes)(nil) + _ supermq.Response = (*clearAllRes)(nil) ) type createPatRes struct { - auth.PAT + auth.PAT `json:",inline"` } func (res createPatRes) Code() int { @@ -40,7 +39,7 @@ func (res createPatRes) Empty() bool { } type retrievePatRes struct { - auth.PAT + auth.PAT `json:",inline"` } func (res retrievePatRes) Code() int { @@ -56,7 +55,7 @@ func (res retrievePatRes) Empty() bool { } type updatePatNameRes struct { - auth.PAT + auth.PAT `json:",inline"` } func (res updatePatNameRes) Code() int { @@ -72,7 +71,7 @@ func (res updatePatNameRes) Empty() bool { } type updatePatDescriptionRes struct { - auth.PAT + auth.PAT `json:",inline"` } func (res updatePatDescriptionRes) Code() int { @@ -88,7 +87,7 @@ func (res updatePatDescriptionRes) Empty() bool { } type listPatsRes struct { - auth.PATSPage + auth.PATSPage `json:",inline"` } func (res listPatsRes) Code() int { @@ -118,7 +117,7 @@ func (res deletePatRes) Empty() bool { } type resetPatSecretRes struct { - auth.PAT + auth.PAT `json:",inline"` } func (res resetPatSecretRes) Code() int { @@ -147,48 +146,46 @@ func (res revokePatSecretRes) Empty() bool { return true } -type addPatScopeEntryRes struct { - auth.Scope -} +type scopeRes struct{} -func (res addPatScopeEntryRes) Code() int { +func (res scopeRes) Code() int { return http.StatusOK } -func (res addPatScopeEntryRes) Headers() map[string]string { +func (res scopeRes) Headers() map[string]string { return map[string]string{} } -func (res addPatScopeEntryRes) Empty() bool { - return false +func (res scopeRes) Empty() bool { + return true } -type removePatScopeEntryRes struct { - auth.Scope -} +type clearAllRes struct{} -func (res removePatScopeEntryRes) Code() int { +func (res clearAllRes) Code() int { return http.StatusOK } -func (res removePatScopeEntryRes) Headers() map[string]string { +func (res clearAllRes) Headers() map[string]string { return map[string]string{} } -func (res removePatScopeEntryRes) Empty() bool { - return false +func (res clearAllRes) Empty() bool { + return true } -type clearAllScopeEntryRes struct{} +type listScopeRes struct { + auth.ScopesPage `json:",inline"` +} -func (res clearAllScopeEntryRes) Code() int { +func (res listScopeRes) Code() int { return http.StatusOK } -func (res clearAllScopeEntryRes) Headers() map[string]string { +func (res listScopeRes) Headers() map[string]string { return map[string]string{} } -func (res clearAllScopeEntryRes) Empty() bool { - return true +func (res listScopeRes) Empty() bool { + return false } diff --git a/auth/api/http/pats/transport.go b/auth/api/http/pats/transport.go index 910bf323be..a7bf41d222 100644 --- a/auth/api/http/pats/transport.go +++ b/auth/api/http/pats/transport.go @@ -44,6 +44,13 @@ func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux { opts..., ).ServeHTTP) + r.Delete("/", kithttp.NewServer( + clearAllPATEndpoint(svc), + decodeClearAllPATRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + r.Route("/{id}", func(r chi.Router) { r.Get("/", kithttp.NewServer( retrievePATEndpoint(svc), @@ -91,22 +98,29 @@ func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux { r.Route("/scope", func(r chi.Router) { r.Patch("/add", kithttp.NewServer( - addPATScopeEntryEndpoint(svc), - decodeAddPATScopeEntryRequest, + addScopeEndpoint(svc), + decodeAddScopeRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Get("/", kithttp.NewServer( + listScopesEndpoint(svc), + decodeListScopeRequest, api.EncodeResponse, opts..., ).ServeHTTP) r.Patch("/remove", kithttp.NewServer( - removePATScopeEntryEndpoint(svc), - decodeRemovePATScopeEntryRequest, + removeScopeEndpoint(svc), + decodeRemoveScopeRequest, api.EncodeResponse, opts..., ).ServeHTTP) r.Delete("/", kithttp.NewServer( - clearPATAllScopeEntryEndpoint(svc), - decodeClearPATAllScopeEntryRequest, + clearAllScopeEndpoint(svc), + decodeClearAllScopeRequest, api.EncodeResponse, opts..., ).ServeHTTP) @@ -243,7 +257,18 @@ func decodeRevokePATSecretRequest(_ context.Context, r *http.Request) (interface }, nil } -func decodeAddPATScopeEntryRequest(_ context.Context, r *http.Request) (interface{}, error) { +func decodeClearAllPATRequest(_ context.Context, r *http.Request) (interface{}, error) { + token := apiutil.ExtractBearerToken(r) + if strings.HasPrefix(token, patPrefix) { + return nil, apiutil.ErrUnsupportedTokenType + } + + return clearAllPATReq{ + token: token, + }, nil +} + +func decodeAddScopeRequest(_ context.Context, r *http.Request) (interface{}, error) { if !strings.Contains(r.Header.Get("Content-Type"), contentType) { return nil, apiutil.ErrUnsupportedContentType } @@ -253,17 +278,41 @@ func decodeAddPATScopeEntryRequest(_ context.Context, r *http.Request) (interfac return nil, apiutil.ErrUnsupportedTokenType } - req := addPatScopeEntryReq{ + req := addScopeReq{ token: token, id: chi.URLParam(r, "id"), } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { return nil, errors.Wrap(errors.ErrMalformedEntity, err) } + + return req, nil +} + +func decodeListScopeRequest(_ context.Context, r *http.Request) (interface{}, error) { + l, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + o, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + token := apiutil.ExtractBearerToken(r) + if strings.HasPrefix(token, patPrefix) { + return nil, apiutil.ErrUnsupportedTokenType + } + req := listScopesReq{ + token: token, + limit: l, + offset: o, + patID: chi.URLParam(r, "id"), + } return req, nil } -func decodeRemovePATScopeEntryRequest(_ context.Context, r *http.Request) (interface{}, error) { +func decodeRemoveScopeRequest(_ context.Context, r *http.Request) (interface{}, error) { if !strings.Contains(r.Header.Get("Content-Type"), contentType) { return nil, apiutil.ErrUnsupportedContentType } @@ -273,7 +322,7 @@ func decodeRemovePATScopeEntryRequest(_ context.Context, r *http.Request) (inter return nil, apiutil.ErrUnsupportedTokenType } - req := removePatScopeEntryReq{ + req := removeScopeReq{ token: token, id: chi.URLParam(r, "id"), } @@ -283,17 +332,13 @@ func decodeRemovePATScopeEntryRequest(_ context.Context, r *http.Request) (inter return req, nil } -func decodeClearPATAllScopeEntryRequest(_ context.Context, r *http.Request) (interface{}, error) { - if !strings.Contains(r.Header.Get("Content-Type"), contentType) { - return nil, apiutil.ErrUnsupportedContentType - } - +func decodeClearAllScopeRequest(_ context.Context, r *http.Request) (interface{}, error) { token := apiutil.ExtractBearerToken(r) if strings.HasPrefix(token, patPrefix) { return nil, apiutil.ErrUnsupportedTokenType } - return clearAllScopeEntryReq{ + return clearAllScopeReq{ token: token, id: chi.URLParam(r, "id"), }, nil diff --git a/auth/api/logging.go b/auth/api/logging.go index 94bacfd3d7..a92450bee6 100644 --- a/auth/api/logging.go +++ b/auth/api/logging.go @@ -125,14 +125,13 @@ func (lm *loggingMiddleware) Authorize(ctx context.Context, pr policies.Policy) return lm.svc.Authorize(ctx, pr) } -func (lm *loggingMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (pa auth.PAT, err error) { +func (lm *loggingMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration) (pa auth.PAT, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), slog.String("name", name), slog.String("description", description), slog.String("pat_duration", duration.String()), - slog.String("scope", scope.String()), } if err != nil { args = append(args, slog.Any("error", err)) @@ -141,7 +140,7 @@ func (lm *loggingMiddleware) CreatePAT(ctx context.Context, token, name, descrip } lm.logger.Info("Create PAT completed successfully", args...) }(time.Now()) - return lm.svc.CreatePAT(ctx, token, name, description, duration, scope) + return lm.svc.CreatePAT(ctx, token, name, description, duration) } func (lm *loggingMiddleware) UpdatePATName(ctx context.Context, token, patID, name string) (pa auth.PAT, err error) { @@ -211,6 +210,24 @@ func (lm *loggingMiddleware) ListPATS(ctx context.Context, token string, pm auth return lm.svc.ListPATS(ctx, token, pm) } +func (lm *loggingMiddleware) ListScopes(ctx context.Context, token string, pm auth.ScopesPageMeta) (pp auth.ScopesPage, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.Uint64("limit", pm.Limit), + slog.Uint64("offset", pm.Offset), + slog.String("pat_id", pm.PatID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("List Scopes failed", args...) + return + } + lm.logger.Info("List Scopes completed successfully", args...) + }(time.Now()) + return lm.svc.ListScopes(ctx, token, pm) +} + func (lm *loggingMiddleware) DeletePAT(ctx context.Context, token, patID string) (err error) { defer func(begin time.Time) { args := []any{ @@ -260,117 +277,113 @@ func (lm *loggingMiddleware) RevokePATSecret(ctx context.Context, token, patID s return lm.svc.RevokePATSecret(ctx, token, patID) } -func (lm *loggingMiddleware) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (sc auth.Scope, err error) { +func (lm *loggingMiddleware) RemoveAllPAT(ctx context.Context, token string) (err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), - slog.String("pat_id", patID), - slog.String("platform_entity_type", platformEntityType.String()), - slog.String("optional_domain_id", optionalDomainID), - slog.String("optional_domain_entity_type", optionalDomainEntityType.String()), - slog.String("operation", operation.String()), - slog.Any("entities", entityIDs), } if err != nil { args = append(args, slog.Any("error", err)) - lm.logger.Warn("Add entry to PAT scope failed", args...) + lm.logger.Warn("Remove all PAT failed", args...) return } - lm.logger.Info("Add entry to PAT scope completed successfully", args...) + lm.logger.Info("Remove all of PAT completed successfully", args...) }(time.Now()) - return lm.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + return lm.svc.RemoveAllPAT(ctx, token) } -func (lm *loggingMiddleware) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (sc auth.Scope, err error) { +func (lm *loggingMiddleware) AddScope(ctx context.Context, token, patID string, scopes []auth.Scope) (err error) { defer func(begin time.Time) { + var groupArgs []any + for _, s := range scopes { + groupArgs = append(groupArgs, slog.String("entity_type", s.EntityType.String())) + groupArgs = append(groupArgs, slog.String("optional_domain_id", s.OptionalDomainID)) + groupArgs = append(groupArgs, slog.String("operation", s.Operation.String())) + groupArgs = append(groupArgs, slog.String("entity_id", s.EntityID)) + } + args := []any{ slog.String("duration", time.Since(begin).String()), slog.String("pat_id", patID), - slog.String("platform_entity_type", platformEntityType.String()), - slog.String("optional_domain_id", optionalDomainID), - slog.String("optional_domain_entity_type", optionalDomainEntityType.String()), - slog.String("operation", operation.String()), - slog.Any("entities", entityIDs), + slog.Group("scope", groupArgs...), } if err != nil { args = append(args, slog.Any("error", err)) - lm.logger.Warn("Remove entry from PAT scope failed", args...) + lm.logger.Warn("Add PAT scope failed", args...) return } - lm.logger.Info("Remove entry from PAT scope completed successfully", args...) + lm.logger.Info("Add PAT scope completed successfully", args...) }(time.Now()) - return lm.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + return lm.svc.AddScope(ctx, token, patID, scopes) } -func (lm *loggingMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, patID string) (err error) { +func (lm *loggingMiddleware) RemoveScope(ctx context.Context, token, patID string, scopesID ...string) (err error) { defer func(begin time.Time) { + var groupArgs []any + for _, s := range scopesID { + groupArgs = append(groupArgs, slog.String("scope_id", s)) + } args := []any{ slog.String("duration", time.Since(begin).String()), slog.String("pat_id", patID), + slog.Group("scope", groupArgs...), } if err != nil { args = append(args, slog.Any("error", err)) - lm.logger.Warn("Clear all entry from PAT scope failed", args...) + lm.logger.Warn("Remove entry from PAT scope failed", args...) return } - lm.logger.Info("Clear all entry from PAT scope completed successfully", args...) + lm.logger.Info("Remove entry from PAT scope completed successfully", args...) }(time.Now()) - return lm.svc.ClearPATAllScopeEntry(ctx, token, patID) + return lm.svc.RemoveScope(ctx, token, patID, scopesID...) } -func (lm *loggingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (pa auth.PAT, err error) { +func (lm *loggingMiddleware) RemovePATAllScope(ctx context.Context, token, patID string) (err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), } if err != nil { args = append(args, slog.Any("error", err)) - lm.logger.Warn("Identify PAT failed", args...) + lm.logger.Warn("Remove all scopes from PAT failed", args...) return } - lm.logger.Info("Identify PAT completed successfully", args...) + lm.logger.Info("Remove all scopes from PAT completed successfully", args...) }(time.Now()) - return lm.svc.IdentifyPAT(ctx, paToken) + return lm.svc.RemovePATAllScope(ctx, token, patID) } -func (lm *loggingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (err error) { +func (lm *loggingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (pa auth.PAT, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), - slog.String("platform_entity_type", platformEntityType.String()), - slog.String("optional_domain_id", optionalDomainID), - slog.String("optional_domain_entity_type", optionalDomainEntityType.String()), - slog.String("operation", operation.String()), - slog.Any("entities", entityIDs), } if err != nil { args = append(args, slog.Any("error", err)) - lm.logger.Warn("Authorize PAT failed complete successfully", args...) + lm.logger.Warn("Identify PAT failed", args...) return } - lm.logger.Info("Authorize PAT completed successfully", args...) + lm.logger.Info("Identify PAT completed successfully", args...) }(time.Now()) - return lm.svc.AuthorizePAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + return lm.svc.IdentifyPAT(ctx, paToken) } -func (lm *loggingMiddleware) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (err error) { +func (lm *loggingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) (err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), - slog.String("user_id", userID), - slog.String("pat_id", patID), - slog.String("platform_entity_type", platformEntityType.String()), + slog.String("entity_type", entityType.String()), slog.String("optional_domain_id", optionalDomainID), - slog.String("optional_domain_entity_type", optionalDomainEntityType.String()), slog.String("operation", operation.String()), - slog.Any("entities", entityIDs), + slog.String("entities", entityID), } if err != nil { args = append(args, slog.Any("error", err)) - lm.logger.Warn("Check PAT failed complete successfully", args...) + lm.logger.Warn("Authorize PAT failed complete successfully", args...) return } - lm.logger.Info("Check PAT completed successfully", args...) + lm.logger.Info("Authorize PAT completed successfully", args...) }(time.Now()) - return lm.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + return lm.svc.AuthorizePAT(ctx, userID, patID, entityType, optionalDomainID, operation, entityID) } diff --git a/auth/api/metrics.go b/auth/api/metrics.go index 081165b3ff..a264fe15b0 100644 --- a/auth/api/metrics.go +++ b/auth/api/metrics.go @@ -75,12 +75,12 @@ func (ms *metricsMiddleware) Authorize(ctx context.Context, pr policies.Policy) return ms.svc.Authorize(ctx, pr) } -func (ms *metricsMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) { +func (ms *metricsMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration) (auth.PAT, error) { defer func(begin time.Time) { ms.counter.With("method", "create_pat").Add(1) ms.latency.With("method", "create_pat").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.CreatePAT(ctx, token, name, description, duration, scope) + return ms.svc.CreatePAT(ctx, token, name, description, duration) } func (ms *metricsMiddleware) UpdatePATName(ctx context.Context, token, patID, name string) (auth.PAT, error) { @@ -115,6 +115,14 @@ func (ms *metricsMiddleware) ListPATS(ctx context.Context, token string, pm auth return ms.svc.ListPATS(ctx, token, pm) } +func (ms *metricsMiddleware) ListScopes(ctx context.Context, token string, pm auth.ScopesPageMeta) (auth.ScopesPage, error) { + defer func(begin time.Time) { + ms.counter.With("method", "list_scopes").Add(1) + ms.latency.With("method", "list_scopes").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.ListScopes(ctx, token, pm) +} + func (ms *metricsMiddleware) DeletePAT(ctx context.Context, token, patID string) error { defer func(begin time.Time) { ms.counter.With("method", "delete_pat").Add(1) @@ -139,28 +147,36 @@ func (ms *metricsMiddleware) RevokePATSecret(ctx context.Context, token, patID s return ms.svc.RevokePATSecret(ctx, token, patID) } -func (ms *metricsMiddleware) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { +func (ms *metricsMiddleware) RemoveAllPAT(ctx context.Context, token string) error { + defer func(begin time.Time) { + ms.counter.With("method", "clear_all_pat").Add(1) + ms.latency.With("method", "clear_all_pat").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.RemoveAllPAT(ctx, token) +} + +func (ms *metricsMiddleware) AddScope(ctx context.Context, token, patID string, scopes []auth.Scope) error { defer func(begin time.Time) { - ms.counter.With("method", "add_pat_scope_entry").Add(1) - ms.latency.With("method", "add_pat_scope_entry").Observe(time.Since(begin).Seconds()) + ms.counter.With("method", "add_pat_scope").Add(1) + ms.latency.With("method", "add_pat_scope").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + return ms.svc.AddScope(ctx, token, patID, scopes) } -func (ms *metricsMiddleware) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { +func (ms *metricsMiddleware) RemoveScope(ctx context.Context, token, patID string, scopesID ...string) error { defer func(begin time.Time) { - ms.counter.With("method", "remove_pat_scope_entry").Add(1) - ms.latency.With("method", "remove_pat_scope_entry").Observe(time.Since(begin).Seconds()) + ms.counter.With("method", "remove_pat_scope").Add(1) + ms.latency.With("method", "remove_pat_scope").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + return ms.svc.RemoveScope(ctx, token, patID, scopesID...) } -func (ms *metricsMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error { +func (ms *metricsMiddleware) RemovePATAllScope(ctx context.Context, token, patID string) error { defer func(begin time.Time) { - ms.counter.With("method", "clear_pat_all_scope_entry").Add(1) - ms.latency.With("method", "clear_pat_all_scope_entry").Observe(time.Since(begin).Seconds()) + ms.counter.With("method", "clear_pat_all_scope").Add(1) + ms.latency.With("method", "clear_pat_all_scope").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.ClearPATAllScopeEntry(ctx, token, patID) + return ms.svc.RemovePATAllScope(ctx, token, patID) } func (ms *metricsMiddleware) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) { @@ -171,18 +187,10 @@ func (ms *metricsMiddleware) IdentifyPAT(ctx context.Context, paToken string) (a return ms.svc.IdentifyPAT(ctx, paToken) } -func (ms *metricsMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { +func (ms *metricsMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error { defer func(begin time.Time) { ms.counter.With("method", "authorize_pat").Add(1) ms.latency.With("method", "authorize_pat").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.AuthorizePAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) -} - -func (ms *metricsMiddleware) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { - defer func(begin time.Time) { - ms.counter.With("method", "check_pat").Add(1) - ms.latency.With("method", "check_pat").Observe(time.Since(begin).Seconds()) - }(time.Now()) - return ms.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + return ms.svc.AuthorizePAT(ctx, userID, patID, entityType, optionalDomainID, operation, entityID) } diff --git a/auth/bolt/doc.go b/auth/bolt/doc.go deleted file mode 100644 index dcd06ac566..0000000000 --- a/auth/bolt/doc.go +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package bolt contains PAT repository implementations using -// bolt as the underlying database. -package bolt diff --git a/auth/bolt/init.go b/auth/bolt/init.go deleted file mode 100644 index 2be5977dfe..0000000000 --- a/auth/bolt/init.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package bolt contains PAT repository implementations using -// bolt as the underlying database. -package bolt - -import ( - "github.com/absmach/supermq/pkg/errors" - bolt "go.etcd.io/bbolt" -) - -var errInit = errors.New("failed to initialize BoltDB") - -func Init(tx *bolt.Tx, bucket string) error { - _, err := tx.CreateBucketIfNotExists([]byte(bucket)) - if err != nil { - return errors.Wrap(errInit, err) - } - return nil -} diff --git a/auth/bolt/pat.go b/auth/bolt/pat.go deleted file mode 100644 index e16f005842..0000000000 --- a/auth/bolt/pat.go +++ /dev/null @@ -1,812 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package bolt - -import ( - "bytes" - "context" - "encoding/binary" - "fmt" - "strings" - "time" - - "github.com/absmach/supermq/auth" - "github.com/absmach/supermq/pkg/errors" - repoerr "github.com/absmach/supermq/pkg/errors/repository" - bolt "go.etcd.io/bbolt" -) - -const ( - idKey = "id" - userKey = "user" - nameKey = "name" - descriptionKey = "description" - secretKey = "secret_key" - scopeKey = "scope" - issuedAtKey = "issued_at" - expiresAtKey = "expires_at" - updatedAtKey = "updated_at" - lastUsedAtKey = "last_used_at" - revokedKey = "revoked" - revokedAtKey = "revoked_at" - platformEntitiesKey = "platform_entities" - patKey = "pat" - - keySeparator = ":" - anyID = "*" -) - -var ( - activateValue = []byte{0x00} - revokedValue = []byte{0x01} - entityValue = []byte{0x02} - anyIDValue = []byte{0x03} - selectedIDsValue = []byte{0x04} -) - -type patRepo struct { - db *bolt.DB - bucketName string -} - -// NewPATSRepository instantiates a bolt -// implementation of PAT repository. -func NewPATSRepository(db *bolt.DB, bucketName string) auth.PATSRepository { - return &patRepo{ - db: db, - bucketName: bucketName, - } -} - -func (pr *patRepo) Save(ctx context.Context, pat auth.PAT) error { - idxKey := []byte(pat.User + keySeparator + patKey + keySeparator + pat.ID) - kv, err := patToKeyValue(pat) - if err != nil { - return err - } - return pr.db.Update(func(tx *bolt.Tx) error { - rootBucket, err := pr.retrieveRootBucket(tx) - if err != nil { - return errors.Wrap(repoerr.ErrCreateEntity, err) - } - b, err := pr.createUserBucket(rootBucket, pat.User) - if err != nil { - return errors.Wrap(repoerr.ErrCreateEntity, err) - } - for key, value := range kv { - fullKey := []byte(pat.ID + keySeparator + key) - if err := b.Put(fullKey, value); err != nil { - return errors.Wrap(repoerr.ErrCreateEntity, err) - } - } - if err := rootBucket.Put(idxKey, []byte(pat.ID)); err != nil { - return errors.Wrap(repoerr.ErrCreateEntity, err) - } - return nil - }) -} - -func (pr *patRepo) Retrieve(ctx context.Context, userID, patID string) (auth.PAT, error) { - prefix := []byte(patID + keySeparator) - kv := map[string][]byte{} - if err := pr.db.View(func(tx *bolt.Tx) error { - b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity) - if err != nil { - return err - } - c := b.Cursor() - for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { - kv[string(k)] = v - } - return nil - }); err != nil { - return auth.PAT{}, err - } - - return keyValueToPAT(kv) -} - -func (pr *patRepo) RetrieveSecretAndRevokeStatus(ctx context.Context, userID, patID string) (string, bool, bool, error) { - revoked := true - expired := false - keySecret := patID + keySeparator + secretKey - keyRevoked := patID + keySeparator + revokedKey - keyExpiresAt := patID + keySeparator + expiresAtKey - var secretHash string - if err := pr.db.View(func(tx *bolt.Tx) error { - b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity) - if err != nil { - return err - } - secretHash = string(b.Get([]byte(keySecret))) - revoked = bytesToBoolean(b.Get([]byte(keyRevoked))) - expiresAt := bytesToTime(b.Get([]byte(keyExpiresAt))) - expired = time.Now().After(expiresAt) - return nil - }); err != nil { - return "", true, true, err - } - return secretHash, revoked, expired, nil -} - -func (pr *patRepo) UpdateName(ctx context.Context, userID, patID, name string) (auth.PAT, error) { - return pr.updatePATField(ctx, userID, patID, nameKey, []byte(name)) -} - -func (pr *patRepo) UpdateDescription(ctx context.Context, userID, patID, description string) (auth.PAT, error) { - return pr.updatePATField(ctx, userID, patID, descriptionKey, []byte(description)) -} - -func (pr *patRepo) UpdateTokenHash(ctx context.Context, userID, patID, tokenHash string, expiryAt time.Time) (auth.PAT, error) { - prefix := []byte(patID + keySeparator) - kv := map[string][]byte{} - if err := pr.db.Update(func(tx *bolt.Tx) error { - b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) - if err != nil { - return err - } - if err := b.Put([]byte(patID+keySeparator+secretKey), []byte(tokenHash)); err != nil { - return errors.Wrap(repoerr.ErrUpdateEntity, err) - } - if err := b.Put([]byte(patID+keySeparator+expiresAtKey), timeToBytes(expiryAt)); err != nil { - return errors.Wrap(repoerr.ErrUpdateEntity, err) - } - if err := b.Put([]byte(patID+keySeparator+updatedAtKey), timeToBytes(time.Now())); err != nil { - return errors.Wrap(repoerr.ErrUpdateEntity, err) - } - c := b.Cursor() - for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { - kv[string(k)] = v - } - return nil - }); err != nil { - return auth.PAT{}, err - } - return keyValueToPAT(kv) -} - -func (pr *patRepo) RetrieveAll(ctx context.Context, userID string, pm auth.PATSPageMeta) (auth.PATSPage, error) { - prefix := []byte(userID + keySeparator + patKey + keySeparator) - - patIDs := []string{} - if err := pr.db.View(func(tx *bolt.Tx) error { - b, err := pr.retrieveRootBucket(tx) - if err != nil { - return errors.Wrap(repoerr.ErrViewEntity, err) - } - c := b.Cursor() - for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { - if v != nil { - patIDs = append(patIDs, string(v)) - } - } - return nil - }); err != nil { - return auth.PATSPage{}, err - } - - total := len(patIDs) - - var pats []auth.PAT - - patsPage := auth.PATSPage{ - Total: uint64(total), - Limit: pm.Limit, - Offset: pm.Offset, - PATS: pats, - } - - if int(pm.Offset) >= total { - return patsPage, nil - } - - aLimit := pm.Limit - if rLimit := total - int(pm.Offset); int(pm.Limit) > rLimit { - aLimit = uint64(rLimit) - } - - for i := pm.Offset; i < pm.Offset+aLimit; i++ { - if int(i) < total { - pat, err := pr.Retrieve(ctx, userID, patIDs[i]) - if err != nil { - return patsPage, err - } - patsPage.PATS = append(patsPage.PATS, pat) - } - } - - return patsPage, nil -} - -func (pr *patRepo) Revoke(ctx context.Context, userID, patID string) error { - if err := pr.db.Update(func(tx *bolt.Tx) error { - b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) - if err != nil { - return err - } - if err := b.Put([]byte(patID+keySeparator+revokedKey), revokedValue); err != nil { - return errors.Wrap(repoerr.ErrUpdateEntity, err) - } - if err := b.Put([]byte(patID+keySeparator+revokedAtKey), timeToBytes(time.Now())); err != nil { - return errors.Wrap(repoerr.ErrUpdateEntity, err) - } - return nil - }); err != nil { - return err - } - return nil -} - -func (pr *patRepo) Reactivate(ctx context.Context, userID, patID string) error { - if err := pr.db.Update(func(tx *bolt.Tx) error { - b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) - if err != nil { - return err - } - if err := b.Put([]byte(patID+keySeparator+revokedKey), activateValue); err != nil { - return errors.Wrap(repoerr.ErrUpdateEntity, err) - } - if err := b.Put([]byte(patID+keySeparator+revokedAtKey), []byte{}); err != nil { - return errors.Wrap(repoerr.ErrUpdateEntity, err) - } - return nil - }); err != nil { - return err - } - return nil -} - -func (pr *patRepo) Remove(ctx context.Context, userID, patID string) error { - prefix := []byte(patID + keySeparator) - idxKey := []byte(userID + keySeparator + patKey + keySeparator + patID) - if err := pr.db.Update(func(tx *bolt.Tx) error { - b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrRemoveEntity) - if err != nil { - return err - } - c := b.Cursor() - for k, _ := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, _ = c.Next() { - if err := b.Delete(k); err != nil { - return errors.Wrap(repoerr.ErrRemoveEntity, err) - } - } - rb, err := pr.retrieveRootBucket(tx) - if err != nil { - return err - } - if err := rb.Delete(idxKey); err != nil { - return errors.Wrap(repoerr.ErrRemoveEntity, err) - } - return nil - }); err != nil { - return err - } - - return nil -} - -func (pr *patRepo) AddScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { - prefix := []byte(patID + keySeparator + scopeKey) - rKV := make(map[string][]byte) - if err := pr.db.Update(func(tx *bolt.Tx) error { - b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrCreateEntity) - if err != nil { - return err - } - kv, err := scopeEntryToKeyValue(platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) - if err != nil { - return err - } - for key, value := range kv { - fullKey := []byte(patID + keySeparator + key) - if err := b.Put(fullKey, value); err != nil { - return errors.Wrap(repoerr.ErrCreateEntity, err) - } - } - - c := b.Cursor() - for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { - rKV[string(k)] = v - } - return nil - }); err != nil { - return auth.Scope{}, err - } - - return parseKeyValueToScope(rKV) -} - -func (pr *patRepo) RemoveScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { - if len(entityIDs) == 0 { - return auth.Scope{}, repoerr.ErrMalformedEntity - } - prefix := []byte(patID + keySeparator + scopeKey) - rKV := make(map[string][]byte) - if err := pr.db.Update(func(tx *bolt.Tx) error { - b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrRemoveEntity) - if err != nil { - return err - } - kv, err := scopeEntryToKeyValue(platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) - if err != nil { - return err - } - for key := range kv { - fullKey := []byte(patID + keySeparator + key) - if err := b.Delete(fullKey); err != nil { - return errors.Wrap(repoerr.ErrRemoveEntity, err) - } - } - c := b.Cursor() - for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { - rKV[string(k)] = v - } - return nil - }); err != nil { - return auth.Scope{}, err - } - return parseKeyValueToScope(rKV) -} - -func (pr *patRepo) CheckScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { - return pr.db.Update(func(tx *bolt.Tx) error { - b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity) - if err != nil { - return errors.Wrap(repoerr.ErrViewEntity, err) - } - srootKey, err := scopeRootKey(platformEntityType, optionalDomainID, optionalDomainEntityType, operation) - if err != nil { - return errors.Wrap(repoerr.ErrViewEntity, err) - } - - rootKey := patID + keySeparator + srootKey - if value := b.Get([]byte(rootKey)); bytes.Equal(value, anyIDValue) { - return nil - } - for _, entity := range entityIDs { - value := b.Get([]byte(rootKey + keySeparator + entity)) - if !bytes.Equal(value, entityValue) { - return repoerr.ErrNotFound - } - } - return nil - }) -} - -func (pr *patRepo) RemoveAllScopeEntry(ctx context.Context, userID, patID string) error { - return nil -} - -func (pr *patRepo) updatePATField(_ context.Context, userID, patID, key string, value []byte) (auth.PAT, error) { - prefix := []byte(patID + keySeparator) - kv := map[string][]byte{} - if err := pr.db.Update(func(tx *bolt.Tx) error { - b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) - if err != nil { - return err - } - if err := b.Put([]byte(patID+keySeparator+key), value); err != nil { - return errors.Wrap(repoerr.ErrUpdateEntity, err) - } - if err := b.Put([]byte(patID+keySeparator+updatedAtKey), timeToBytes(time.Now())); err != nil { - return errors.Wrap(repoerr.ErrUpdateEntity, err) - } - c := b.Cursor() - for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { - kv[string(k)] = v - } - return nil - }); err != nil { - return auth.PAT{}, err - } - return keyValueToPAT(kv) -} - -func (pr *patRepo) createUserBucket(rootBucket *bolt.Bucket, userID string) (*bolt.Bucket, error) { - userBucket, err := rootBucket.CreateBucketIfNotExists([]byte(userID)) - if err != nil { - return nil, errors.Wrap(repoerr.ErrCreateEntity, fmt.Errorf("failed to retrieve or create bucket for user %s : %w", userID, err)) - } - - return userBucket, nil -} - -func (pr *patRepo) retrieveUserBucket(tx *bolt.Tx, userID, patID string, wrap error) (*bolt.Bucket, error) { - rootBucket, err := pr.retrieveRootBucket(tx) - if err != nil { - return nil, errors.Wrap(wrap, err) - } - - vPatID := rootBucket.Get([]byte(userID + keySeparator + patKey + keySeparator + patID)) - if vPatID == nil { - return nil, repoerr.ErrNotFound - } - - userBucket := rootBucket.Bucket([]byte(userID)) - if userBucket == nil { - return nil, errors.Wrap(wrap, fmt.Errorf("user %s not found", userID)) - } - return userBucket, nil -} - -func (pr *patRepo) retrieveRootBucket(tx *bolt.Tx) (*bolt.Bucket, error) { - rootBucket := tx.Bucket([]byte(pr.bucketName)) - if rootBucket == nil { - return nil, fmt.Errorf("bucket %s not found", pr.bucketName) - } - return rootBucket, nil -} - -func patToKeyValue(pat auth.PAT) (map[string][]byte, error) { - kv := map[string][]byte{ - idKey: []byte(pat.ID), - userKey: []byte(pat.User), - nameKey: []byte(pat.Name), - descriptionKey: []byte(pat.Description), - secretKey: []byte(pat.Secret), - issuedAtKey: timeToBytes(pat.IssuedAt), - expiresAtKey: timeToBytes(pat.ExpiresAt), - updatedAtKey: timeToBytes(pat.UpdatedAt), - lastUsedAtKey: timeToBytes(pat.LastUsedAt), - revokedKey: booleanToBytes(pat.Revoked), - revokedAtKey: timeToBytes(pat.RevokedAt), - } - scopeKV, err := scopeToKeyValue(pat.Scope) - if err != nil { - return nil, err - } - for k, v := range scopeKV { - kv[k] = v - } - return kv, nil -} - -func scopeToKeyValue(scope auth.Scope) (map[string][]byte, error) { - kv := map[string][]byte{} - for opType, scopeValue := range scope.Users { - tempKV, err := scopeEntryToKeyValue(auth.PlatformUsersScope, "", auth.DomainNullScope, opType, scopeValue.Values()...) - if err != nil { - return nil, err - } - for k, v := range tempKV { - kv[k] = v - } - } - for opType, scopeValue := range scope.Dashboard { - tempKV, err := scopeEntryToKeyValue(auth.PlatformDashBoardScope, "", auth.DomainNullScope, opType, scopeValue.Values()...) - if err != nil { - return nil, err - } - for k, v := range tempKV { - kv[k] = v - } - } - for opType, scopeValue := range scope.Messaging { - tempKV, err := scopeEntryToKeyValue(auth.PlatformMesagingScope, "", auth.DomainNullScope, opType, scopeValue.Values()...) - if err != nil { - return nil, err - } - for k, v := range tempKV { - kv[k] = v - } - } - for domainID, domainScope := range scope.Domains { - for opType, scopeValue := range domainScope.DomainManagement { - tempKV, err := scopeEntryToKeyValue(auth.PlatformDomainsScope, domainID, auth.DomainManagementScope, opType, scopeValue.Values()...) - if err != nil { - return nil, errors.Wrap(repoerr.ErrCreateEntity, err) - } - for k, v := range tempKV { - kv[k] = v - } - } - for entityType, scope := range domainScope.Entities { - for opType, scopeValue := range scope { - tempKV, err := scopeEntryToKeyValue(auth.PlatformDomainsScope, domainID, entityType, opType, scopeValue.Values()...) - if err != nil { - return nil, errors.Wrap(repoerr.ErrCreateEntity, err) - } - for k, v := range tempKV { - kv[k] = v - } - } - } - } - return kv, nil -} - -func scopeEntryToKeyValue(platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (map[string][]byte, error) { - if len(entityIDs) == 0 { - return nil, repoerr.ErrMalformedEntity - } - - rootKey, err := scopeRootKey(platformEntityType, optionalDomainID, optionalDomainEntityType, operation) - if err != nil { - return nil, err - } - if len(entityIDs) == 1 && entityIDs[0] == anyID { - return map[string][]byte{rootKey: anyIDValue}, nil - } - - kv := map[string][]byte{rootKey: selectedIDsValue} - - for _, entryID := range entityIDs { - if entryID == anyID { - return nil, repoerr.ErrMalformedEntity - } - kv[rootKey+keySeparator+entryID] = entityValue - } - - return kv, nil -} - -func scopeRootKey(platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType) (string, error) { - op, err := operation.ValidString() - if err != nil { - return "", errors.Wrap(repoerr.ErrMalformedEntity, err) - } - - var rootKey strings.Builder - - rootKey.WriteString(scopeKey) - rootKey.WriteString(keySeparator) - rootKey.WriteString(platformEntityType.String()) - rootKey.WriteString(keySeparator) - - switch platformEntityType { - case auth.PlatformUsersScope: - rootKey.WriteString(op) - case auth.PlatformDashBoardScope: - rootKey.WriteString(op) - case auth.PlatformMesagingScope: - rootKey.WriteString(op) - case auth.PlatformDomainsScope: - if optionalDomainID == "" { - return "", fmt.Errorf("failed to add platform %s scope: invalid domain id", platformEntityType.String()) - } - odet, err := optionalDomainEntityType.ValidString() - if err != nil { - return "", errors.Wrap(repoerr.ErrMalformedEntity, err) - } - rootKey.WriteString(optionalDomainID) - rootKey.WriteString(keySeparator) - rootKey.WriteString(odet) - rootKey.WriteString(keySeparator) - rootKey.WriteString(op) - default: - return "", errors.Wrap(repoerr.ErrMalformedEntity, fmt.Errorf("invalid platform entity type %s", platformEntityType.String())) - } - - return rootKey.String(), nil -} - -func keyValueToBasicPAT(kv map[string][]byte) auth.PAT { - var pat auth.PAT - for k, v := range kv { - switch { - case strings.HasSuffix(k, keySeparator+idKey): - pat.ID = string(v) - case strings.HasSuffix(k, keySeparator+userKey): - pat.User = string(v) - case strings.HasSuffix(k, keySeparator+nameKey): - pat.Name = string(v) - case strings.HasSuffix(k, keySeparator+descriptionKey): - pat.Description = string(v) - case strings.HasSuffix(k, keySeparator+issuedAtKey): - pat.IssuedAt = bytesToTime(v) - case strings.HasSuffix(k, keySeparator+expiresAtKey): - pat.ExpiresAt = bytesToTime(v) - case strings.HasSuffix(k, keySeparator+updatedAtKey): - pat.UpdatedAt = bytesToTime(v) - case strings.HasSuffix(k, keySeparator+lastUsedAtKey): - pat.LastUsedAt = bytesToTime(v) - case strings.HasSuffix(k, keySeparator+revokedKey): - pat.Revoked = bytesToBoolean(v) - case strings.HasSuffix(k, keySeparator+revokedAtKey): - pat.RevokedAt = bytesToTime(v) - } - } - return pat -} - -func keyValueToPAT(kv map[string][]byte) (auth.PAT, error) { - pat := keyValueToBasicPAT(kv) - scope, err := parseKeyValueToScope(kv) - if err != nil { - return auth.PAT{}, err - } - pat.Scope = scope - return pat, nil -} - -func parseKeyValueToScope(kv map[string][]byte) (auth.Scope, error) { - scope := auth.Scope{ - Domains: make(map[string]auth.DomainScope), - } - for key, value := range kv { - if strings.Index(key, keySeparator+scopeKey+keySeparator) > 0 { - keyParts := strings.Split(key, keySeparator) - - platformEntityType, err := auth.ParsePlatformEntityType(keyParts[2]) - if err != nil { - return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) - } - - switch platformEntityType { - case auth.PlatformUsersScope: - scope.Users, err = parseOperation(platformEntityType, scope.Users, key, keyParts, value) - if err != nil { - return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) - } - - case auth.PlatformDashBoardScope: - scope.Dashboard, err = parseOperation(platformEntityType, scope.Dashboard, key, keyParts, value) - if err != nil { - return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) - } - - case auth.PlatformMesagingScope: - scope.Messaging, err = parseOperation(platformEntityType, scope.Messaging, key, keyParts, value) - if err != nil { - return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) - } - - case auth.PlatformDomainsScope: - if len(keyParts) < 6 { - return auth.Scope{}, fmt.Errorf("invalid scope key format: %s", key) - } - domainID := keyParts[3] - if scope.Domains == nil { - scope.Domains = make(map[string]auth.DomainScope) - } - if _, ok := scope.Domains[domainID]; !ok { - scope.Domains[domainID] = auth.DomainScope{} - } - domainScope := scope.Domains[domainID] - - entityType := keyParts[4] - - switch entityType { - case auth.DomainManagementScope.String(): - domainScope.DomainManagement, err = parseOperation(platformEntityType, domainScope.DomainManagement, key, keyParts, value) - if err != nil { - return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) - } - default: - etype, err := auth.ParseDomainEntityType(entityType) - if err != nil { - return auth.Scope{}, fmt.Errorf("key %s invalid entity type %s : %w", key, entityType, err) - } - if domainScope.Entities == nil { - domainScope.Entities = make(map[auth.DomainEntityType]auth.OperationScope) - } - if _, ok := domainScope.Entities[etype]; !ok { - domainScope.Entities[etype] = auth.OperationScope{} - } - entityOperationScope := domainScope.Entities[etype] - entityOperationScope, err = parseOperation(platformEntityType, entityOperationScope, key, keyParts, value) - if err != nil { - return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) - } - domainScope.Entities[etype] = entityOperationScope - } - scope.Domains[domainID] = domainScope - default: - return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, fmt.Errorf("invalid platform entity type : %s", platformEntityType.String())) - } - } - } - return scope, nil -} - -func parseOperation(platformEntityType auth.PlatformEntityType, opScope auth.OperationScope, key string, keyParts []string, value []byte) (auth.OperationScope, error) { - if opScope == nil { - opScope = make(map[auth.OperationType]auth.ScopeValue) - } - - if err := validateOperation(platformEntityType, opScope, key, keyParts, value); err != nil { - return auth.OperationScope{}, err - } - - switch string(value) { - case string(entityValue): - opType, err := auth.ParseOperationType(keyParts[len(keyParts)-2]) - if err != nil { - return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err) - } - entityID := keyParts[len(keyParts)-1] - - if _, oValueExists := opScope[opType]; !oValueExists { - opScope[opType] = &auth.SelectedIDs{} - } - oValue := opScope[opType] - if err := oValue.AddValues(entityID); err != nil { - return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity value %v : %w", key, entityID, err) - } - opScope[opType] = oValue - case string(anyIDValue): - opType, err := auth.ParseOperationType(keyParts[len(keyParts)-1]) - if err != nil { - return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err) - } - if oValue, oValueExists := opScope[opType]; oValueExists && oValue != nil { - if _, ok := oValue.(*auth.AnyIDs); !ok { - return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity anyIDs scope value : key already initialized with different type", key) - } - } - opScope[opType] = &auth.AnyIDs{} - case string(selectedIDsValue): - opType, err := auth.ParseOperationType(keyParts[len(keyParts)-1]) - if err != nil { - return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err) - } - oValue, oValueExists := opScope[opType] - if oValueExists && oValue != nil { - if _, ok := oValue.(*auth.SelectedIDs); !ok { - return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity selectedIDs scope value : key already initialized with different type", key) - } - } - if !oValueExists { - opScope[opType] = &auth.SelectedIDs{} - } - default: - return auth.OperationScope{}, fmt.Errorf("key %s have invalid value %v", key, value) - } - return opScope, nil -} - -func validateOperation(platformEntityType auth.PlatformEntityType, opScope auth.OperationScope, key string, keyParts []string, value []byte) error { - expectedKeyPartsLength := 0 - switch string(value) { - case string(entityValue): - switch platformEntityType { - case auth.PlatformDomainsScope: - expectedKeyPartsLength = 7 - case auth.PlatformUsersScope, auth.PlatformDashBoardScope, auth.PlatformMesagingScope: - expectedKeyPartsLength = 5 - default: - return fmt.Errorf("invalid platform entity type : %s", platformEntityType.String()) - } - case string(selectedIDsValue), string(anyIDValue): - switch platformEntityType { - case auth.PlatformDomainsScope: - expectedKeyPartsLength = 6 - case auth.PlatformUsersScope, auth.PlatformDashBoardScope, auth.PlatformMesagingScope: - expectedKeyPartsLength = 4 - default: - return fmt.Errorf("invalid platform entity type : %s", platformEntityType.String()) - } - default: - return fmt.Errorf("key %s have invalid value %v", key, value) - } - if len(keyParts) != expectedKeyPartsLength { - return fmt.Errorf("invalid scope key format: %s", key) - } - return nil -} - -func timeToBytes(t time.Time) []byte { - timeBytes := make([]byte, 8) - binary.BigEndian.PutUint64(timeBytes, uint64(t.Unix())) - return timeBytes -} - -func bytesToTime(b []byte) time.Time { - timeAtSeconds := binary.BigEndian.Uint64(b) - return time.Unix(int64(timeAtSeconds), 0) -} - -func booleanToBytes(b bool) []byte { - if b { - return []byte{1} - } - return []byte{0} -} - -func bytesToBoolean(b []byte) bool { - if len(b) > 1 || b[0] != activateValue[0] { - return true - } - return false -} diff --git a/auth/cache/doc.go b/auth/cache/doc.go new file mode 100644 index 0000000000..42396c9830 --- /dev/null +++ b/auth/cache/doc.go @@ -0,0 +1,4 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache diff --git a/auth/cache/pat.go b/auth/cache/pat.go new file mode 100644 index 0000000000..53002e7237 --- /dev/null +++ b/auth/cache/pat.go @@ -0,0 +1,120 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache + +import ( + "context" + "fmt" + "time" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/redis/go-redis/v9" +) + +type patCache struct { + client *redis.Client + duration time.Duration +} + +func NewPatsCache(client *redis.Client, duration time.Duration) auth.Cache { + return &patCache{ + client: client, + duration: duration, + } +} + +func (pc *patCache) Save(ctx context.Context, userID string, scopes []auth.Scope) error { + for _, sc := range scopes { + key := generateKey(userID, sc.PatID, sc.OptionalDomainID, sc.EntityType, sc.Operation, sc.EntityID) + if err := pc.client.Set(ctx, key, sc.ID, pc.duration).Err(); err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + } + + return nil +} + +func (pc *patCache) CheckScope(ctx context.Context, userID, patID, optionalDomainID string, entityType auth.EntityType, operation auth.Operation, entityID string) bool { + exactKey := fmt.Sprintf("pat:%s:%s:%s:%s:%s:%s", userID, patID, entityType, optionalDomainID, operation, entityID) + wildcardKey := fmt.Sprintf("pat:%s:%s:%s:%s:%s:*", userID, patID, entityType, operation, operation) + + res, err := pc.client.Exists(ctx, exactKey, wildcardKey).Result() + if err != nil { + return false + } + + return res > 0 +} + +func (pc *patCache) Remove(ctx context.Context, userID string, scopeIDs []string) error { + if len(scopeIDs) == 0 { + return repoerr.ErrRemoveEntity + } + + pattern := fmt.Sprintf("pat:%s:*", userID) + iter := pc.client.Scan(ctx, 0, pattern, 0).Iterator() + + for iter.Next(ctx) { + key := iter.Val() + val, err := pc.client.Get(ctx, key).Result() + if err != nil { + if err == redis.Nil { + continue + } + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + for _, scopeID := range scopeIDs { + if val == scopeID { + if err := pc.client.Del(ctx, key).Err(); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + break + } + } + } + + if err := iter.Err(); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + return nil +} + +func (pc *patCache) RemoveUserAllScope(ctx context.Context, userID string) error { + pattern := fmt.Sprintf("pat:%s:*", userID) + iter := pc.client.Scan(ctx, 0, pattern, 0).Iterator() + for iter.Next(ctx) { + if err := pc.client.Del(ctx, iter.Val()).Err(); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + } + if err := iter.Err(); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + return nil +} + +func (pc *patCache) RemoveAllScope(ctx context.Context, userID, patID string) error { + pattern := fmt.Sprintf("pat:%s:%s", userID, patID) + + iter := pc.client.Scan(ctx, 0, pattern, 0).Iterator() + for iter.Next(ctx) { + if err := pc.client.Del(ctx, iter.Val()).Err(); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + } + + if err := iter.Err(); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + return nil +} + +func generateKey(userID, patID, optionalDomainId string, entityType auth.EntityType, operation auth.Operation, entityID string) string { + return fmt.Sprintf("pat:%s:%s:%s:%s:%s:%s", userID, patID, entityType, optionalDomainId, operation, entityID) +} diff --git a/auth/mocks/cache.go b/auth/mocks/cache.go new file mode 100644 index 0000000000..0ca7db7bb8 --- /dev/null +++ b/auth/mocks/cache.go @@ -0,0 +1,122 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import ( + context "context" + + auth "github.com/absmach/supermq/auth" + + mock "github.com/stretchr/testify/mock" +) + +// Cache is an autogenerated mock type for the Cache type +type Cache struct { + mock.Mock +} + +// CheckScope provides a mock function with given fields: ctx, userID, patID, optionalDomainID, entityType, operation, entityID +func (_m *Cache) CheckScope(ctx context.Context, userID string, patID string, optionalDomainID string, entityType auth.EntityType, operation auth.Operation, entityID string) bool { + ret := _m.Called(ctx, userID, patID, optionalDomainID, entityType, operation, entityID) + + if len(ret) == 0 { + panic("no return value specified for CheckScope") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, auth.EntityType, auth.Operation, string) bool); ok { + r0 = rf(ctx, userID, patID, optionalDomainID, entityType, operation, entityID) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Remove provides a mock function with given fields: ctx, userID, scopesID +func (_m *Cache) Remove(ctx context.Context, userID string, scopesID []string) error { + ret := _m.Called(ctx, userID, scopesID) + + if len(ret) == 0 { + panic("no return value specified for Remove") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, []string) error); ok { + r0 = rf(ctx, userID, scopesID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RemoveAllScope provides a mock function with given fields: ctx, userID, patID +func (_m *Cache) RemoveAllScope(ctx context.Context, userID string, patID string) error { + ret := _m.Called(ctx, userID, patID) + + if len(ret) == 0 { + panic("no return value specified for RemoveAllScope") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, userID, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RemoveUserAllScope provides a mock function with given fields: ctx, userID +func (_m *Cache) RemoveUserAllScope(ctx context.Context, userID string) error { + ret := _m.Called(ctx, userID) + + if len(ret) == 0 { + panic("no return value specified for RemoveUserAllScope") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, userID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Save provides a mock function with given fields: ctx, userID, scopes +func (_m *Cache) Save(ctx context.Context, userID string, scopes []auth.Scope) error { + ret := _m.Called(ctx, userID, scopes) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, []auth.Scope) error); ok { + r0 = rf(ctx, userID, scopes) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewCache creates a new instance of Cache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewCache(t interface { + mock.TestingT + Cleanup(func()) +}) *Cache { + mock := &Cache{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/auth/mocks/pats.go b/auth/mocks/pats.go index 4b920bbd10..ffd4e30b35 100644 --- a/auth/mocks/pats.go +++ b/auth/mocks/pats.go @@ -19,84 +19,17 @@ type PATS struct { mock.Mock } -// AddPATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs -func (_m *PATS) AddPATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { - _va := make([]interface{}, len(entityIDs)) - for _i := range entityIDs { - _va[_i] = entityIDs[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - if len(ret) == 0 { - panic("no return value specified for AddPATScopeEntry") - } - - var r0 auth.Scope - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { - return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { - r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) - } else { - r0 = ret.Get(0).(auth.Scope) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { - r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// AuthorizePAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs -func (_m *PATS) AuthorizePAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { - _va := make([]interface{}, len(entityIDs)) - for _i := range entityIDs { - _va[_i] = entityIDs[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// AddScope provides a mock function with given fields: ctx, token, patID, scopes +func (_m *PATS) AddScope(ctx context.Context, token string, patID string, scopes []auth.Scope) error { + ret := _m.Called(ctx, token, patID, scopes) if len(ret) == 0 { - panic("no return value specified for AuthorizePAT") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { - r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// CheckPAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs -func (_m *PATS) CheckPAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { - _va := make([]interface{}, len(entityIDs)) - for _i := range entityIDs { - _va[_i] = entityIDs[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - if len(ret) == 0 { - panic("no return value specified for CheckPAT") + panic("no return value specified for AddScope") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { - r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if rf, ok := ret.Get(0).(func(context.Context, string, string, []auth.Scope) error); ok { + r0 = rf(ctx, token, patID, scopes) } else { r0 = ret.Error(0) } @@ -104,17 +37,17 @@ func (_m *PATS) CheckPAT(ctx context.Context, userID string, patID string, platf return r0 } -// ClearPATAllScopeEntry provides a mock function with given fields: ctx, token, patID -func (_m *PATS) ClearPATAllScopeEntry(ctx context.Context, token string, patID string) error { - ret := _m.Called(ctx, token, patID) +// AuthorizePAT provides a mock function with given fields: ctx, userID, patID, entityType, optionalDomainID, operation, entityID +func (_m *PATS) AuthorizePAT(ctx context.Context, userID string, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error { + ret := _m.Called(ctx, userID, patID, entityType, optionalDomainID, operation, entityID) if len(ret) == 0 { - panic("no return value specified for ClearPATAllScopeEntry") + panic("no return value specified for AuthorizePAT") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, token, patID) + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.EntityType, string, auth.Operation, string) error); ok { + r0 = rf(ctx, userID, patID, entityType, optionalDomainID, operation, entityID) } else { r0 = ret.Error(0) } @@ -122,9 +55,9 @@ func (_m *PATS) ClearPATAllScopeEntry(ctx context.Context, token string, patID s return r0 } -// CreatePAT provides a mock function with given fields: ctx, token, name, description, duration, scope -func (_m *PATS) CreatePAT(ctx context.Context, token string, name string, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) { - ret := _m.Called(ctx, token, name, description, duration, scope) +// CreatePAT provides a mock function with given fields: ctx, token, name, description, duration +func (_m *PATS) CreatePAT(ctx context.Context, token string, name string, description string, duration time.Duration) (auth.PAT, error) { + ret := _m.Called(ctx, token, name, description, duration) if len(ret) == 0 { panic("no return value specified for CreatePAT") @@ -132,17 +65,17 @@ func (_m *PATS) CreatePAT(ctx context.Context, token string, name string, descri var r0 auth.PAT var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) (auth.PAT, error)); ok { - return rf(ctx, token, name, description, duration, scope) + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration) (auth.PAT, error)); ok { + return rf(ctx, token, name, description, duration) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) auth.PAT); ok { - r0 = rf(ctx, token, name, description, duration, scope) + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration) auth.PAT); ok { + r0 = rf(ctx, token, name, description, duration) } else { r0 = ret.Get(0).(auth.PAT) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Duration, auth.Scope) error); ok { - r1 = rf(ctx, token, name, description, duration, scope) + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Duration) error); ok { + r1 = rf(ctx, token, name, description, duration) } else { r1 = ret.Error(1) } @@ -224,34 +157,27 @@ func (_m *PATS) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta return r0, r1 } -// RemovePATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs -func (_m *PATS) RemovePATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { - _va := make([]interface{}, len(entityIDs)) - for _i := range entityIDs { - _va[_i] = entityIDs[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// ListScopes provides a mock function with given fields: ctx, token, pm +func (_m *PATS) ListScopes(ctx context.Context, token string, pm auth.ScopesPageMeta) (auth.ScopesPage, error) { + ret := _m.Called(ctx, token, pm) if len(ret) == 0 { - panic("no return value specified for RemovePATScopeEntry") + panic("no return value specified for ListScopes") } - var r0 auth.Scope + var r0 auth.ScopesPage var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { - return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if rf, ok := ret.Get(0).(func(context.Context, string, auth.ScopesPageMeta) (auth.ScopesPage, error)); ok { + return rf(ctx, token, pm) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { - r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if rf, ok := ret.Get(0).(func(context.Context, string, auth.ScopesPageMeta) auth.ScopesPage); ok { + r0 = rf(ctx, token, pm) } else { - r0 = ret.Get(0).(auth.Scope) + r0 = ret.Get(0).(auth.ScopesPage) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { - r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if rf, ok := ret.Get(1).(func(context.Context, string, auth.ScopesPageMeta) error); ok { + r1 = rf(ctx, token, pm) } else { r1 = ret.Error(1) } @@ -259,6 +185,67 @@ func (_m *PATS) RemovePATScopeEntry(ctx context.Context, token string, patID str return r0, r1 } +// RemoveAllPAT provides a mock function with given fields: ctx, token +func (_m *PATS) RemoveAllPAT(ctx context.Context, token string) error { + ret := _m.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for RemoveAllPAT") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, token) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RemovePATAllScope provides a mock function with given fields: ctx, token, patID +func (_m *PATS) RemovePATAllScope(ctx context.Context, token string, patID string) error { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for RemovePATAllScope") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RemoveScope provides a mock function with given fields: ctx, token, patID, scopeIDs +func (_m *PATS) RemoveScope(ctx context.Context, token string, patID string, scopeIDs ...string) error { + _va := make([]interface{}, len(scopeIDs)) + for _i := range scopeIDs { + _va[_i] = scopeIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, token, patID) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for RemoveScope") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, ...string) error); ok { + r0 = rf(ctx, token, patID, scopeIDs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // ResetPATSecret provides a mock function with given fields: ctx, token, patID, duration func (_m *PATS) ResetPATSecret(ctx context.Context, token string, patID string, duration time.Duration) (auth.PAT, error) { ret := _m.Called(ctx, token, patID, duration) diff --git a/auth/mocks/patsrepo.go b/auth/mocks/patsrepo.go index c15a8752dc..7ed3ad1048 100644 --- a/auth/mocks/patsrepo.go +++ b/auth/mocks/patsrepo.go @@ -19,59 +19,35 @@ type PATSRepository struct { mock.Mock } -// AddScopeEntry provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs -func (_m *PATSRepository) AddScopeEntry(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { - _va := make([]interface{}, len(entityIDs)) - for _i := range entityIDs { - _va[_i] = entityIDs[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// AddScope provides a mock function with given fields: ctx, userID, scopes +func (_m *PATSRepository) AddScope(ctx context.Context, userID string, scopes []auth.Scope) error { + ret := _m.Called(ctx, userID, scopes) if len(ret) == 0 { - panic("no return value specified for AddScopeEntry") + panic("no return value specified for AddScope") } - var r0 auth.Scope - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { - return rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { - r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) - } else { - r0 = ret.Get(0).(auth.Scope) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { - r1 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, []auth.Scope) error); ok { + r0 = rf(ctx, userID, scopes) } else { - r1 = ret.Error(1) + r0 = ret.Error(0) } - return r0, r1 + return r0 } -// CheckScopeEntry provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs -func (_m *PATSRepository) CheckScopeEntry(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { - _va := make([]interface{}, len(entityIDs)) - for _i := range entityIDs { - _va[_i] = entityIDs[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CheckScope provides a mock function with given fields: ctx, userID, patID, entityType, optionalDomainID, operation, entityID +func (_m *PATSRepository) CheckScope(ctx context.Context, userID string, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error { + ret := _m.Called(ctx, userID, patID, entityType, optionalDomainID, operation, entityID) if len(ret) == 0 { - panic("no return value specified for CheckScopeEntry") + panic("no return value specified for CheckScope") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { - r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.EntityType, string, auth.Operation, string) error); ok { + r0 = rf(ctx, userID, patID, entityType, optionalDomainID, operation, entityID) } else { r0 = ret.Error(0) } @@ -115,17 +91,17 @@ func (_m *PATSRepository) Remove(ctx context.Context, userID string, patID strin return r0 } -// RemoveAllScopeEntry provides a mock function with given fields: ctx, userID, patID -func (_m *PATSRepository) RemoveAllScopeEntry(ctx context.Context, userID string, patID string) error { - ret := _m.Called(ctx, userID, patID) +// RemoveAllPAT provides a mock function with given fields: ctx, userID +func (_m *PATSRepository) RemoveAllPAT(ctx context.Context, userID string) error { + ret := _m.Called(ctx, userID) if len(ret) == 0 { - panic("no return value specified for RemoveAllScopeEntry") + panic("no return value specified for RemoveAllPAT") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, userID, patID) + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, userID) } else { r0 = ret.Error(0) } @@ -133,39 +109,47 @@ func (_m *PATSRepository) RemoveAllScopeEntry(ctx context.Context, userID string return r0 } -// RemoveScopeEntry provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs -func (_m *PATSRepository) RemoveScopeEntry(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { - _va := make([]interface{}, len(entityIDs)) - for _i := range entityIDs { - _va[_i] = entityIDs[_i] +// RemoveAllScope provides a mock function with given fields: ctx, patID +func (_m *PATSRepository) RemoveAllScope(ctx context.Context, patID string) error { + ret := _m.Called(ctx, patID) + + if len(ret) == 0 { + panic("no return value specified for RemoveAllScope") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RemoveScope provides a mock function with given fields: ctx, userID, scopesIDs +func (_m *PATSRepository) RemoveScope(ctx context.Context, userID string, scopesIDs ...string) error { + _va := make([]interface{}, len(scopesIDs)) + for _i := range scopesIDs { + _va[_i] = scopesIDs[_i] } var _ca []interface{} - _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, ctx, userID) _ca = append(_ca, _va...) ret := _m.Called(_ca...) if len(ret) == 0 { - panic("no return value specified for RemoveScopeEntry") - } - - var r0 auth.Scope - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { - return rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { - r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) - } else { - r0 = ret.Get(0).(auth.Scope) + panic("no return value specified for RemoveScope") } - if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { - r1 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, ...string) error); ok { + r0 = rf(ctx, userID, scopesIDs...) } else { - r1 = ret.Error(1) + r0 = ret.Error(0) } - return r0, r1 + return r0 } // Retrieve provides a mock function with given fields: ctx, userID, patID @@ -224,6 +208,34 @@ func (_m *PATSRepository) RetrieveAll(ctx context.Context, userID string, pm aut return r0, r1 } +// RetrieveScope provides a mock function with given fields: ctx, pm +func (_m *PATSRepository) RetrieveScope(ctx context.Context, pm auth.ScopesPageMeta) (auth.ScopesPage, error) { + ret := _m.Called(ctx, pm) + + if len(ret) == 0 { + panic("no return value specified for RetrieveScope") + } + + var r0 auth.ScopesPage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, auth.ScopesPageMeta) (auth.ScopesPage, error)); ok { + return rf(ctx, pm) + } + if rf, ok := ret.Get(0).(func(context.Context, auth.ScopesPageMeta) auth.ScopesPage); ok { + r0 = rf(ctx, pm) + } else { + r0 = ret.Get(0).(auth.ScopesPage) + } + + if rf, ok := ret.Get(1).(func(context.Context, auth.ScopesPageMeta) error); ok { + r1 = rf(ctx, pm) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // RetrieveSecretAndRevokeStatus provides a mock function with given fields: ctx, userID, patID func (_m *PATSRepository) RetrieveSecretAndRevokeStatus(ctx context.Context, userID string, patID string) (string, bool, bool, error) { ret := _m.Called(ctx, userID, patID) diff --git a/auth/mocks/service.go b/auth/mocks/service.go index 0a01591776..acf2515c6d 100644 --- a/auth/mocks/service.go +++ b/auth/mocks/service.go @@ -21,39 +21,22 @@ type Service struct { mock.Mock } -// AddPATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs -func (_m *Service) AddPATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { - _va := make([]interface{}, len(entityIDs)) - for _i := range entityIDs { - _va[_i] = entityIDs[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// AddScope provides a mock function with given fields: ctx, token, patID, scopes +func (_m *Service) AddScope(ctx context.Context, token string, patID string, scopes []auth.Scope) error { + ret := _m.Called(ctx, token, patID, scopes) if len(ret) == 0 { - panic("no return value specified for AddPATScopeEntry") + panic("no return value specified for AddScope") } - var r0 auth.Scope - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { - return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { - r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) - } else { - r0 = ret.Get(0).(auth.Scope) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { - r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, []auth.Scope) error); ok { + r0 = rf(ctx, token, patID, scopes) } else { - r1 = ret.Error(1) + r0 = ret.Error(0) } - return r0, r1 + return r0 } // Authorize provides a mock function with given fields: ctx, pr @@ -74,24 +57,17 @@ func (_m *Service) Authorize(ctx context.Context, pr policies.Policy) error { return r0 } -// AuthorizePAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs -func (_m *Service) AuthorizePAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { - _va := make([]interface{}, len(entityIDs)) - for _i := range entityIDs { - _va[_i] = entityIDs[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// AuthorizePAT provides a mock function with given fields: ctx, userID, patID, entityType, optionalDomainID, operation, entityID +func (_m *Service) AuthorizePAT(ctx context.Context, userID string, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error { + ret := _m.Called(ctx, userID, patID, entityType, optionalDomainID, operation, entityID) if len(ret) == 0 { panic("no return value specified for AuthorizePAT") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { - r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.EntityType, string, auth.Operation, string) error); ok { + r0 = rf(ctx, userID, patID, entityType, optionalDomainID, operation, entityID) } else { r0 = ret.Error(0) } @@ -99,52 +75,9 @@ func (_m *Service) AuthorizePAT(ctx context.Context, userID string, patID string return r0 } -// CheckPAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs -func (_m *Service) CheckPAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { - _va := make([]interface{}, len(entityIDs)) - for _i := range entityIDs { - _va[_i] = entityIDs[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - if len(ret) == 0 { - panic("no return value specified for CheckPAT") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { - r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// ClearPATAllScopeEntry provides a mock function with given fields: ctx, token, patID -func (_m *Service) ClearPATAllScopeEntry(ctx context.Context, token string, patID string) error { - ret := _m.Called(ctx, token, patID) - - if len(ret) == 0 { - panic("no return value specified for ClearPATAllScopeEntry") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, token, patID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// CreatePAT provides a mock function with given fields: ctx, token, name, description, duration, scope -func (_m *Service) CreatePAT(ctx context.Context, token string, name string, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) { - ret := _m.Called(ctx, token, name, description, duration, scope) +// CreatePAT provides a mock function with given fields: ctx, token, name, description, duration +func (_m *Service) CreatePAT(ctx context.Context, token string, name string, description string, duration time.Duration) (auth.PAT, error) { + ret := _m.Called(ctx, token, name, description, duration) if len(ret) == 0 { panic("no return value specified for CreatePAT") @@ -152,17 +85,17 @@ func (_m *Service) CreatePAT(ctx context.Context, token string, name string, des var r0 auth.PAT var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) (auth.PAT, error)); ok { - return rf(ctx, token, name, description, duration, scope) + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration) (auth.PAT, error)); ok { + return rf(ctx, token, name, description, duration) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) auth.PAT); ok { - r0 = rf(ctx, token, name, description, duration, scope) + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration) auth.PAT); ok { + r0 = rf(ctx, token, name, description, duration) } else { r0 = ret.Get(0).(auth.PAT) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Duration, auth.Scope) error); ok { - r1 = rf(ctx, token, name, description, duration, scope) + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Duration) error); ok { + r1 = rf(ctx, token, name, description, duration) } else { r1 = ret.Error(1) } @@ -300,34 +233,27 @@ func (_m *Service) ListPATS(ctx context.Context, token string, pm auth.PATSPageM return r0, r1 } -// RemovePATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs -func (_m *Service) RemovePATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { - _va := make([]interface{}, len(entityIDs)) - for _i := range entityIDs { - _va[_i] = entityIDs[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// ListScopes provides a mock function with given fields: ctx, token, pm +func (_m *Service) ListScopes(ctx context.Context, token string, pm auth.ScopesPageMeta) (auth.ScopesPage, error) { + ret := _m.Called(ctx, token, pm) if len(ret) == 0 { - panic("no return value specified for RemovePATScopeEntry") + panic("no return value specified for ListScopes") } - var r0 auth.Scope + var r0 auth.ScopesPage var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { - return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if rf, ok := ret.Get(0).(func(context.Context, string, auth.ScopesPageMeta) (auth.ScopesPage, error)); ok { + return rf(ctx, token, pm) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { - r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if rf, ok := ret.Get(0).(func(context.Context, string, auth.ScopesPageMeta) auth.ScopesPage); ok { + r0 = rf(ctx, token, pm) } else { - r0 = ret.Get(0).(auth.Scope) + r0 = ret.Get(0).(auth.ScopesPage) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { - r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if rf, ok := ret.Get(1).(func(context.Context, string, auth.ScopesPageMeta) error); ok { + r1 = rf(ctx, token, pm) } else { r1 = ret.Error(1) } @@ -335,6 +261,67 @@ func (_m *Service) RemovePATScopeEntry(ctx context.Context, token string, patID return r0, r1 } +// RemoveAllPAT provides a mock function with given fields: ctx, token +func (_m *Service) RemoveAllPAT(ctx context.Context, token string) error { + ret := _m.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for RemoveAllPAT") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, token) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RemovePATAllScope provides a mock function with given fields: ctx, token, patID +func (_m *Service) RemovePATAllScope(ctx context.Context, token string, patID string) error { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for RemovePATAllScope") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RemoveScope provides a mock function with given fields: ctx, token, patID, scopeIDs +func (_m *Service) RemoveScope(ctx context.Context, token string, patID string, scopeIDs ...string) error { + _va := make([]interface{}, len(scopeIDs)) + for _i := range scopeIDs { + _va[_i] = scopeIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, token, patID) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for RemoveScope") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, ...string) error); ok { + r0 = rf(ctx, token, patID, scopeIDs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // ResetPATSecret provides a mock function with given fields: ctx, token, patID, duration func (_m *Service) ResetPATSecret(ctx context.Context, token string, patID string, duration time.Duration) (auth.PAT, error) { ret := _m.Called(ctx, token, patID, duration) diff --git a/auth/pat.go b/auth/pat.go index 4a168fa0ec..8fe084437c 100644 --- a/auth/pat.go +++ b/auth/pat.go @@ -7,18 +7,18 @@ import ( "context" "encoding/json" "fmt" + "strings" "time" "github.com/absmach/supermq/pkg/errors" ) -var errAddEntityToAnyIDs = errors.New("could not add entity id to any ID scope value") +const AnyIDs = "*" -// Define OperationType. -type OperationType uint32 +type Operation uint32 const ( - CreateOp OperationType = iota + CreateOp Operation = iota ReadOp ListOp UpdateOp @@ -41,8 +41,8 @@ const ( SubscribeOpStr = "subscribe" ) -func (ot OperationType) String() string { - switch ot { +func (op Operation) String() string { + switch op { case CreateOp: return createOpStr case ReadOp: @@ -62,20 +62,20 @@ func (ot OperationType) String() string { case SubscribeOp: return SubscribeOpStr default: - return fmt.Sprintf("unknown operation type %d", ot) + return fmt.Sprintf("unknown operation type %d", op) } } -func (ot OperationType) ValidString() (string, error) { - str := ot.String() - if str == fmt.Sprintf("unknown operation type %d", ot) { +func (op Operation) ValidString() (string, error) { + str := op.String() + if str == fmt.Sprintf("unknown operation type %d", op) { return "", errors.New(str) } return str, nil } -func ParseOperationType(ot string) (OperationType, error) { - switch ot { +func ParseOperation(op string) (Operation, error) { + switch op { case createOpStr: return CreateOp, nil case readOpStr: @@ -95,592 +95,205 @@ func ParseOperationType(ot string) (OperationType, error) { case SubscribeOpStr: return SubscribeOp, nil default: - return 0, fmt.Errorf("unknown operation type %s", ot) + return 0, fmt.Errorf("unknown operation type %s", op) } } -func (ot OperationType) MarshalJSON() ([]byte, error) { - return []byte(ot.String()), nil +func (op Operation) MarshalJSON() ([]byte, error) { + return json.Marshal(op.String()) } -func (ot OperationType) MarshalText() (text []byte, err error) { - return []byte(ot.String()), nil -} - -func (ot *OperationType) UnmarshalText(data []byte) (err error) { - *ot, err = ParseOperationType(string(data)) +func (op *Operation) UnmarshalJSON(data []byte) error { + str := strings.Trim(string(data), "\"") + val, err := ParseOperation(str) + *op = val return err } -// Define DomainEntityType. -type DomainEntityType uint32 - -const ( - DomainManagementScope DomainEntityType = iota - DomainGroupsScope - DomainChannelsScope - DomainClientsScope - DomainNullScope -) - -const ( - domainManagementScopeStr = "domain_management" - domainGroupsScopeStr = "groups" - domainChannelsScopeStr = "channels" - domainClientsScopeStr = "clients" -) - -func (det DomainEntityType) String() string { - switch det { - case DomainManagementScope: - return domainManagementScopeStr - case DomainGroupsScope: - return domainGroupsScopeStr - case DomainChannelsScope: - return domainChannelsScopeStr - case DomainClientsScope: - return domainClientsScopeStr - default: - return fmt.Sprintf("unknown domain entity type %d", det) - } -} - -func (det DomainEntityType) ValidString() (string, error) { - str := det.String() - if str == fmt.Sprintf("unknown operation type %d", det) { - return "", errors.New(str) - } - return str, nil -} - -func ParseDomainEntityType(det string) (DomainEntityType, error) { - switch det { - case domainManagementScopeStr: - return DomainManagementScope, nil - case domainGroupsScopeStr: - return DomainGroupsScope, nil - case domainChannelsScopeStr: - return DomainChannelsScope, nil - case domainClientsScopeStr: - return DomainClientsScope, nil - default: - return 0, fmt.Errorf("unknown domain entity type %s", det) - } -} - -func (det DomainEntityType) MarshalJSON() ([]byte, error) { - return []byte(det.String()), nil -} - -func (det DomainEntityType) MarshalText() ([]byte, error) { - return []byte(det.String()), nil +func (op Operation) MarshalText() (text []byte, err error) { + return []byte(op.String()), nil } -func (det *DomainEntityType) UnmarshalText(data []byte) (err error) { - *det, err = ParseDomainEntityType(string(data)) +func (op *Operation) UnmarshalText(data []byte) (err error) { + str := strings.Trim(string(data), "\"") + *op, err = ParseOperation(str) return err } -// Define DomainEntityType. -type PlatformEntityType uint32 +type EntityType uint32 const ( - PlatformUsersScope PlatformEntityType = iota - PlatformDomainsScope - PlatformDashBoardScope - PlatformMesagingScope + GroupsType EntityType = iota + ChannelsType + ClientsType + DomainsType + UsersType + DashboardType + MessagesType ) const ( - platformUsersScopeStr = "users" - platformDomainsScopeStr = "domains" - PlatformDashBoardScopeStr = "dashboard" - PlatformMesagingScopeStr = "messaging" + GroupsScopeStr = "groups" + ChannelsScopeStr = "channels" + ClientsScopeStr = "clients" + DomainsStr = "domains" + UsersStr = "users" + DashboardsStr = "dashboards" + MessagesStr = "messages" ) -func (pet PlatformEntityType) String() string { - switch pet { - case PlatformUsersScope: - return platformUsersScopeStr - case PlatformDomainsScope: - return platformDomainsScopeStr - case PlatformDashBoardScope: - return PlatformDashBoardScopeStr - case PlatformMesagingScope: - return PlatformMesagingScopeStr +func (et EntityType) String() string { + switch et { + case GroupsType: + return GroupsScopeStr + case ChannelsType: + return ChannelsScopeStr + case ClientsType: + return ClientsScopeStr + case DomainsType: + return DomainsStr + case UsersType: + return UsersStr + case DashboardType: + return DashboardsStr + case MessagesType: + return MessagesStr default: - return fmt.Sprintf("unknown platform entity type %d", pet) + return fmt.Sprintf("unknown domain entity type %d", et) } } -func (pet PlatformEntityType) ValidString() (string, error) { - str := pet.String() - if str == fmt.Sprintf("unknown platform entity type %d", pet) { +func (et EntityType) ValidString() (string, error) { + str := et.String() + if str == fmt.Sprintf("unknown operation type %d", et) { return "", errors.New(str) } return str, nil } -func ParsePlatformEntityType(pet string) (PlatformEntityType, error) { - switch pet { - case platformUsersScopeStr: - return PlatformUsersScope, nil - case platformDomainsScopeStr: - return PlatformDomainsScope, nil +func ParseEntityType(et string) (EntityType, error) { + switch et { + case GroupsScopeStr: + return GroupsType, nil + case ChannelsScopeStr: + return ChannelsType, nil + case ClientsScopeStr: + return ClientsType, nil + case DomainsStr: + return DomainsType, nil + case UsersStr: + return UsersType, nil + case DashboardsStr: + return DashboardType, nil default: - return 0, fmt.Errorf("unknown platform entity type %s", pet) + return 0, fmt.Errorf("unknown domain entity type %s", et) } } -func (pet PlatformEntityType) MarshalJSON() ([]byte, error) { - return []byte(pet.String()), nil +func (et EntityType) MarshalJSON() ([]byte, error) { + return json.Marshal(et.String()) } -func (pet PlatformEntityType) MarshalText() (text []byte, err error) { - return []byte(pet.String()), nil -} - -func (pet *PlatformEntityType) UnmarshalText(data []byte) (err error) { - *pet, err = ParsePlatformEntityType(string(data)) +func (et *EntityType) UnmarshalJSON(data []byte) error { + str := strings.Trim(string(data), "\"") + val, err := ParseEntityType(str) + *et = val return err } -// ScopeValue interface for Any entity ids or for sets of entity ids. -type ScopeValue interface { - Contains(id string) bool - Values() []string - AddValues(ids ...string) error - RemoveValues(ids ...string) error -} - -// AnyIDs implements ScopeValue for any entity id value. -type AnyIDs struct{} - -func (s AnyIDs) Contains(id string) bool { return true } -func (s AnyIDs) Values() []string { return []string{"*"} } -func (s *AnyIDs) AddValues(ids ...string) error { return errAddEntityToAnyIDs } -func (s *AnyIDs) RemoveValues(ids ...string) error { return errAddEntityToAnyIDs } - -// SelectedIDs implements ScopeValue for sets of entity ids. -type SelectedIDs map[string]struct{} - -func (s SelectedIDs) Contains(id string) bool { _, ok := s[id]; return ok } -func (s SelectedIDs) Values() []string { - values := []string{} - for value := range s { - values = append(values, value) - } - return values -} - -func (s *SelectedIDs) AddValues(ids ...string) error { - if *s == nil { - *s = make(SelectedIDs) - } - for _, id := range ids { - (*s)[id] = struct{}{} - } - return nil -} - -func (s *SelectedIDs) RemoveValues(ids ...string) error { - if *s == nil { - return nil - } - for _, id := range ids { - delete(*s, id) - } - return nil -} - -// OperationScope contains map of OperationType with value of AnyIDs or SelectedIDs. -type OperationScope map[OperationType]ScopeValue - -func (os *OperationScope) UnmarshalJSON(data []byte) error { - type tempOperationScope map[OperationType]json.RawMessage - - var tempScope tempOperationScope - if err := json.Unmarshal(data, &tempScope); err != nil { - return err - } - // Initialize the Operations map - *os = OperationScope{} - - for opType, rawMessage := range tempScope { - var stringValue string - var stringArrayValue []string - - // Try to unmarshal as string - if err := json.Unmarshal(rawMessage, &stringValue); err == nil { - if err := os.Add(opType, stringValue); err != nil { - return err - } - continue - } - - // Try to unmarshal as []string - if err := json.Unmarshal(rawMessage, &stringArrayValue); err == nil { - if err := os.Add(opType, stringArrayValue...); err != nil { - return err - } - continue - } - - // If neither unmarshalling succeeded, return an error - return fmt.Errorf("invalid ScopeValue for OperationType %v", opType) - } - - return nil +func (et EntityType) MarshalText() ([]byte, error) { + return []byte(et.String()), nil } -func (os OperationScope) MarshalJSON() ([]byte, error) { - tempOperationScope := make(map[OperationType]interface{}) - for oType, scope := range os { - value := scope.Values() - if len(value) == 1 && value[0] == "*" { - tempOperationScope[oType] = "*" - continue - } - tempOperationScope[oType] = value - } - - b, err := json.Marshal(tempOperationScope) - if err != nil { - return nil, err - } - return b, nil +func (et *EntityType) UnmarshalText(data []byte) (err error) { + str := strings.Trim(string(data), "\"") + *et, err = ParseEntityType(str) + return err } -func (os *OperationScope) Add(operation OperationType, entityIDs ...string) error { - var value ScopeValue - - if os == nil { - os = &OperationScope{} - } +// Example Scope as JSON +// +// [ +// { +// "optional_domain_id": "domain_1", +// "entity_type": "groups", +// "operation": "create", +// "entity_id": "*" +// }, +// { +// "optional_domain_id": "domain_1", +// "entity_type": "channels", +// "operation": "delete", +// "entity_id": "channel1" +// }, +// { +// "optional_domain_id": "domain_1", +// "entity_type": "things", +// "operation": "update", +// "entity_id": "*" +// } +// ] - if len(entityIDs) == 0 { - return fmt.Errorf("entity ID is missing") - } - switch { - case len(entityIDs) == 1 && entityIDs[0] == "*": - value = &AnyIDs{} - default: - var sids SelectedIDs - for _, entityID := range entityIDs { - if entityID == "*" { - return fmt.Errorf("list contains wildcard") - } - if sids == nil { - sids = make(SelectedIDs) - } - sids[entityID] = struct{}{} - } - value = &sids - } - (*os)[operation] = value - return nil +type Scope struct { + ID string `json:"id,omitempty"` + PatID string `json:"pat_id,omitempty"` + OptionalDomainID string `json:"optional_domain_id,omitempty"` + EntityType EntityType `json:"entity_type,omitempty"` + EntityID string `json:"entity_id,omitempty"` + Operation Operation `json:"operation,omitempty"` } -func (os *OperationScope) Delete(operation OperationType, entityIDs ...string) error { - if os == nil { - return nil - } - - opEntityIDs, exists := (*os)[operation] - if !exists { - return nil +func (s *Scope) Authorized(entityType EntityType, optionalDomainID string, operation Operation, entityID string) bool { + if s == nil { + return false } - if len(entityIDs) == 0 { - return fmt.Errorf("failed to delete operation %s: entity ID is missing", operation.String()) + if s.EntityType != entityType { + return false } - switch eIDs := opEntityIDs.(type) { - case *AnyIDs: - if !(len(entityIDs) == 1 && entityIDs[0] == "*") { - return fmt.Errorf("failed to delete operation %s: invalid list", operation.String()) - } - delete((*os), operation) - return nil - case *SelectedIDs: - for _, entityID := range entityIDs { - if !eIDs.Contains(entityID) { - return fmt.Errorf("failed to delete operation %s: invalid entity ID in list", operation.String()) - } - } - for _, entityID := range entityIDs { - delete(*eIDs, entityID) - if len(*eIDs) == 0 { - delete((*os), operation) - } - } - return nil - default: - return fmt.Errorf("failed to delete operation: invalid entity id type %d", operation) + if optionalDomainID != "" && s.OptionalDomainID != optionalDomainID { + return false } -} -func (os *OperationScope) Check(operation OperationType, entityIDs ...string) bool { - if os == nil { + if s.Operation != operation { return false } - if scopeValue, ok := (*os)[operation]; ok { - if len(entityIDs) == 0 { - _, ok := scopeValue.(*AnyIDs) - return ok - } - for _, entityID := range entityIDs { - if !scopeValue.Contains(entityID) { - return false - } - } + if s.EntityID == "*" { return true } - return false -} - -type DomainScope struct { - DomainManagement OperationScope `json:"domain_management,omitempty"` - Entities map[DomainEntityType]OperationScope `json:"entities,omitempty"` -} - -// Add entry in Domain scope. -func (ds *DomainScope) Add(domainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { - if ds == nil { - return fmt.Errorf("failed to add domain %s scope: domain_scope is nil and not initialized", domainEntityType) - } - - if domainEntityType < DomainManagementScope || domainEntityType > DomainClientsScope { - return fmt.Errorf("failed to add domain %d scope: invalid domain entity type", domainEntityType) - } - if domainEntityType == DomainManagementScope { - if err := ds.DomainManagement.Add(operation, entityIDs...); err != nil { - return fmt.Errorf("failed to delete domain management scope: %w", err) - } - } - - if ds.Entities == nil { - ds.Entities = make(map[DomainEntityType]OperationScope) - } - - opReg, ok := ds.Entities[domainEntityType] - if !ok { - opReg = OperationScope{} - } - - if err := opReg.Add(operation, entityIDs...); err != nil { - return fmt.Errorf("failed to add domain %s scope: %w ", domainEntityType.String(), err) - } - ds.Entities[domainEntityType] = opReg - return nil -} - -// Delete entry in Domain scope. -func (ds *DomainScope) Delete(domainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { - if ds == nil { - return nil - } - - if domainEntityType < DomainManagementScope || domainEntityType > DomainClientsScope { - return fmt.Errorf("failed to delete domain %d scope: invalid domain entity type", domainEntityType) - } - if ds.Entities == nil { - return nil - } - - if domainEntityType == DomainManagementScope { - if err := ds.DomainManagement.Delete(operation, entityIDs...); err != nil { - return fmt.Errorf("failed to delete domain management scope: %w", err) - } - } - - os, exists := ds.Entities[domainEntityType] - if !exists { - return nil - } - - if err := os.Delete(operation, entityIDs...); err != nil { - return fmt.Errorf("failed to delete domain %s scope: %w", domainEntityType.String(), err) - } - - if len(os) == 0 { - delete(ds.Entities, domainEntityType) - } - return nil -} - -// Check entry in Domain scope. -func (ds *DomainScope) Check(domainEntityType DomainEntityType, operation OperationType, ids ...string) bool { - if ds.Entities == nil { - return false - } - if domainEntityType < DomainManagementScope || domainEntityType > DomainClientsScope { - return false - } - if domainEntityType == DomainManagementScope { - return ds.DomainManagement.Check(operation, ids...) - } - os, exists := ds.Entities[domainEntityType] - if !exists { - return false + if s.EntityID == entityID { + return true } - - return os.Check(operation, ids...) -} - -// Example Scope as JSON -// -// { -// "users": { -// "create": ["*"], -// "read": ["*"], -// "list": ["*"], -// "update": ["*"], -// "delete": ["*"] -// }, -// "domains": { -// "domain_1": { -// "entities": { -// "groups": { -// "create": ["*"] // this for all groups in domain -// }, -// "channels": { -// // for particular channel in domain -// "delete": [ -// "channel1", -// "channel2" -// ] -// }, -// "things": { -// "update": ["*"] // this for all things in domain -// } -// } -// } -// } -// } -type Scope struct { - Users OperationScope `json:"users,omitempty"` - Domains map[string]DomainScope `json:"domains,omitempty"` - Dashboard OperationScope `json:"dashboard,omitempty"` - Messaging OperationScope `json:"messaging,omitempty"` + return false } -// Add entry in Domain scope. -func (s *Scope) Add(platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { +func (s *Scope) Validate() error { if s == nil { - return fmt.Errorf("failed to add platform %s scope: scope is nil and not initialized", platformEntityType.String()) + return errInvalidScope } - switch platformEntityType { - case PlatformUsersScope: - if err := s.Users.Add(operation, entityIDs...); err != nil { - return fmt.Errorf("failed to add platform %s scope: %w", platformEntityType.String(), err) - } - case PlatformDashBoardScope: - if err := s.Dashboard.Add(operation, entityIDs...); err != nil { - return fmt.Errorf("failed to add platform %s scope: %w", platformEntityType.String(), err) - } - case PlatformMesagingScope: - if err := s.Messaging.Add(operation, entityIDs...); err != nil { - return fmt.Errorf("failed to add platform %s scope: %w", platformEntityType.String(), err) - } - case PlatformDomainsScope: - if optionalDomainID == "" { - return fmt.Errorf("failed to add platform %s scope: invalid domain id", platformEntityType.String()) - } - if len(s.Domains) == 0 { - s.Domains = make(map[string]DomainScope) - } - - ds, ok := s.Domains[optionalDomainID] - if !ok { - ds = DomainScope{} - } - if err := ds.Add(optionalDomainEntityType, operation, entityIDs...); err != nil { - return fmt.Errorf("failed to add platform %s id %s scope : %w", platformEntityType.String(), optionalDomainID, err) - } - s.Domains[optionalDomainID] = ds - default: - return fmt.Errorf("failed to add platform %d scope: invalid platform entity type ", platformEntityType) + if s.EntityID == "" { + return errors.New("missing entityID") } - return nil -} -// Delete entry in Domain scope. -func (s *Scope) Delete(platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { - if s == nil { - return nil - } - switch platformEntityType { - case PlatformUsersScope: - if err := s.Users.Delete(operation, entityIDs...); err != nil { - return fmt.Errorf("failed to delete platform %s scope: %w", platformEntityType.String(), err) - } - case PlatformDashBoardScope: - if err := s.Dashboard.Delete(operation, entityIDs...); err != nil { - return fmt.Errorf("failed to delete platform %s scope: %w", platformEntityType.String(), err) - } - case PlatformMesagingScope: - if err := s.Messaging.Delete(operation, entityIDs...); err != nil { - return fmt.Errorf("failed to delete platform %s scope: %w", platformEntityType.String(), err) + switch s.EntityType { + case ChannelsType, GroupsType, ClientsType: + if s.OptionalDomainID == "" { + return errors.New("missing domainID") } - case PlatformDomainsScope: - if optionalDomainID == "" { - return fmt.Errorf("failed to delete platform %s scope: invalid domain id", platformEntityType.String()) - } - ds, ok := s.Domains[optionalDomainID] - if !ok { - return nil - } - if err := ds.Delete(optionalDomainEntityType, operation, entityIDs...); err != nil { - return fmt.Errorf("failed to delete platform %s id %s scope : %w", platformEntityType.String(), optionalDomainID, err) - } - default: - return fmt.Errorf("failed to add platform %d scope: invalid platform entity type ", platformEntityType) } return nil } -// Check entry in Domain scope. -func (s *Scope) Check(platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) bool { - if s == nil { - return false - } - switch platformEntityType { - case PlatformUsersScope: - return s.Users.Check(operation, entityIDs...) - case PlatformDashBoardScope: - return s.Dashboard.Check(operation, entityIDs...) - case PlatformMesagingScope: - return s.Messaging.Check(operation, entityIDs...) - case PlatformDomainsScope: - ds, ok := s.Domains[optionalDomainID] - if !ok { - return false - } - return ds.Check(optionalDomainEntityType, operation, entityIDs...) - default: - return false - } -} - -func (s *Scope) String() string { - str, err := json.Marshal(s) // , "", " ") - if err != nil { - return fmt.Sprintf("failed to convert scope to string: json marshal error :%s", err.Error()) - } - return string(str) -} - // PAT represents Personal Access Token. type PAT struct { ID string `json:"id,omitempty"` - User string `json:"user,omitempty"` + User string `json:"user_id,omitempty"` Name string `json:"name,omitempty"` Description string `json:"description,omitempty"` Secret string `json:"secret,omitempty"` - Scope Scope `json:"scope,omitempty"` IssuedAt time.Time `json:"issued_at,omitempty"` ExpiresAt time.Time `json:"expires_at,omitempty"` UpdatedAt time.Time `json:"updated_at,omitempty"` @@ -697,7 +310,29 @@ type PATSPage struct { Total uint64 `json:"total"` Offset uint64 `json:"offset"` Limit uint64 `json:"limit"` - PATS []PAT `json:"pats"` + PATS []PAT `json:"pats,omitempty"` +} + +type ScopesPageMeta struct { + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + PatID string `json:"pat_id"` + ID string `json:"id"` +} + +type ScopesPage struct { + Total uint64 `json:"total"` + Offset uint64 `json:"offset,omitempty"` + Limit uint64 `json:"limit,omitempy"` + Scopes []Scope `json:"scopes,omitempty"` +} + +func (pat PAT) MarshalBinary() ([]byte, error) { + return json.Marshal(pat) +} + +func (pat *PAT) UnmarshalBinary(data []byte) error { + return json.Unmarshal(data, pat) } func (pat *PAT) String() string { @@ -708,17 +343,12 @@ func (pat *PAT) String() string { return string(str) } -// Expired verifies if the key is expired. -func (pat PAT) Expired() bool { - return pat.ExpiresAt.UTC().Before(time.Now().UTC()) -} - // PATS specifies function which are required for Personal access Token implementation. //go:generate mockery --name PATS --output=./mocks --filename pats.go --quiet --note "Copyright (c) Abstract Machines" type PATS interface { // Create function creates new PAT for given valid inputs. - CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope Scope) (PAT, error) + CreatePAT(ctx context.Context, token, name, description string, duration time.Duration) (PAT, error) // UpdateName function updates the name for the given PAT ID. UpdatePATName(ctx context.Context, token, patID, name string) (PAT, error) @@ -729,7 +359,10 @@ type PATS interface { // Retrieve function retrieves the PAT for given ID. RetrievePAT(ctx context.Context, userID string, patID string) (PAT, error) - // List function lists all the PATs for the user. + // RemoveAllPAT function removes all PATs of user. + RemoveAllPAT(ctx context.Context, token string) error + + // ListPATS function lists all the PATs for the user. ListPATS(ctx context.Context, token string, pm PATSPageMeta) (PATSPage, error) // Delete function deletes the PAT for given ID. @@ -741,23 +374,23 @@ type PATS interface { // RevokeSecret function revokes the secret for the given ID. RevokePATSecret(ctx context.Context, token, patID string) error - // AddScope function adds a new scope entry. - AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) + // AddScope function adds a new scope. + AddScope(ctx context.Context, token, patID string, scopes []Scope) error + + // RemoveScope function removes a scope. + RemoveScope(ctx context.Context, token string, patID string, scopeIDs ...string) error - // RemoveScope function removes a scope entry. - RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) + // RemovePATAllScope function removes all scope. + RemovePATAllScope(ctx context.Context, token, patID string) error - // ClearAllScope function removes all scope entry. - ClearPATAllScopeEntry(ctx context.Context, token, patID string) error + // List function lists all the Scopes for the patID. + ListScopes(ctx context.Context, token string, pm ScopesPageMeta) (ScopesPage, error) // IdentifyPAT function will valid the secret. IdentifyPAT(ctx context.Context, paToken string) (PAT, error) // AuthorizePAT function will valid the secret and check the given scope exists. - AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error - - // CheckPAT function will check the given scope exists. - CheckPAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error + AuthorizePAT(ctx context.Context, userID, patID string, entityType EntityType, optionalDomainID string, operation Operation, entityID string) error } // PATSRepository specifies PATS persistence API. @@ -770,6 +403,9 @@ type PATSRepository interface { // Retrieve retrieves users PAT by its unique identifier. Retrieve(ctx context.Context, userID, patID string) (pat PAT, err error) + // RetrieveScope retrieves PAT scopes by its unique identifier. + RetrieveScope(ctx context.Context, pm ScopesPageMeta) (scopes ScopesPage, err error) + // RetrieveSecretAndRevokeStatus retrieves secret and revoke status of PAT by its unique identifier. RetrieveSecretAndRevokeStatus(ctx context.Context, userID, patID string) (string, bool, bool, error) @@ -794,11 +430,27 @@ type PATSRepository interface { // Remove removes Key with provided ID. Remove(ctx context.Context, userID, patID string) error - AddScopeEntry(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) + // RemoveAllPAT removes all PAT for a given user. + RemoveAllPAT(ctx context.Context, userID string) error + + AddScope(ctx context.Context, userID string, scopes []Scope) error + + RemoveScope(ctx context.Context, userID string, scopesIDs ...string) error + + CheckScope(ctx context.Context, userID, patID string, entityType EntityType, optionalDomainID string, operation Operation, entityID string) error + + RemoveAllScope(ctx context.Context, patID string) error +} + +//go:generate mockery --name Cache --output=./mocks --filename cache.go --quiet --note "Copyright (c) Abstract Machines" +type Cache interface { + Save(ctx context.Context, userID string, scopes []Scope) error + + CheckScope(ctx context.Context, userID, patID, optionalDomainID string, entityType EntityType, operation Operation, entityID string) bool - RemoveScopeEntry(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) + Remove(ctx context.Context, userID string, scopesID []string) error - CheckScopeEntry(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error + RemoveUserAllScope(ctx context.Context, userID string) error - RemoveAllScopeEntry(ctx context.Context, userID, patID string) error + RemoveAllScope(ctx context.Context, userID, patID string) error } diff --git a/auth/postgres/init.go b/auth/postgres/init.go index 32e0ab002d..c5dd77b2e6 100644 --- a/auth/postgres/init.go +++ b/auth/postgres/init.go @@ -65,6 +65,45 @@ func Migration() *migrate.MemoryMigrationSource { `, }, }, + { + Id: "auth_4", + Up: []string{ + `CREATE TABLE IF NOT EXISTS pats ( + id VARCHAR(36) PRIMARY KEY, + name VARCHAR(254) NOT NULL, + user_id VARCHAR(36), + description TEXT, + secret TEXT, + issued_at TIMESTAMP, + expires_at TIMESTAMP, + updated_at TIMESTAMP, + revoked BOOLEAN, + revoked_at TIMESTAMP, + last_used_at TIMESTAMP, + UNIQUE (id, name, secret) + )`, + }, + Down: []string{ + `DROP TABLE IF EXISTS pats`, + }, + }, + { + Id: "auth_5", + Up: []string{ + `CREATE TABLE IF NOT EXISTS pat_scopes ( + id VARCHAR(36) PRIMARY KEY, + pat_id VARCHAR(36) REFERENCES pats(id) ON DELETE CASCADE, + optional_domain_id VARCHAR(36), + entity_type VARCHAR(50) NOT NULL, + operation VARCHAR(50) NOT NULL, + entity_id VARCHAR(50) NOT NULL, + UNIQUE (pat_id, optional_domain_id, entity_type, operation, entity_id) + );`, + }, + Down: []string{ + `DROP TABLE IF EXISTS pat_scopes;`, + }, + }, }, } } diff --git a/auth/postgres/pat.go b/auth/postgres/pat.go new file mode 100644 index 0000000000..f344cbee24 --- /dev/null +++ b/auth/postgres/pat.go @@ -0,0 +1,167 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "database/sql" + "time" + + "github.com/absmach/supermq/auth" + repoerr "github.com/absmach/supermq/pkg/errors/repository" +) + +type dbPat struct { + ID string `db:"id,omitempty"` + User string `db:"user_id,omitempty"` + Name string `db:"name,omitempty"` + Description string `db:"description,omitempty"` + Secret string `db:"secret,omitempty"` + IssuedAt time.Time `db:"issued_at,omitempty"` + ExpiresAt time.Time `db:"expires_at,omitempty"` + UpdatedAt sql.NullTime `db:"updated_at,omitempty"` + LastUsedAt sql.NullTime `db:"last_used_at,omitempty"` + Revoked bool `db:"revoked,omitempty"` + RevokedAt sql.NullTime `db:"revoked_at,omitempty"` +} + +type dbScope struct { + ID string `db:"id,omitempty"` + PatID string `db:"pat_id,omitempty"` + OptionalDomainID string `db:"optional_domain_id,omitempty"` + EntityType string `db:"entity_type,omitempty"` + EntityID string `db:"entity_id,omitempty"` + Operation string `db:"operation,omitempty"` +} + +type dbPagemeta struct { + Limit uint64 `db:"limit"` + Offset uint64 `db:"offset"` + User string `db:"user_id"` + PatID string `db:"pat_id"` + ScopesID []string `db:"scopes_id"` + ID string `db:"id"` + Name string `db:"name"` + UpdatedAt sql.NullTime `db:"updated_at"` + ExpiresAt time.Time `db:"expires_at"` + RevokedAt sql.NullTime `db:"revoked_at"` + Description string `db:"description"` + Secret string `db:"secret"` +} + +func toAuthPat(db dbPat) (auth.PAT, error) { + if db.ID == "" { + return auth.PAT{}, repoerr.ErrNotFound + } + + updatedAt := time.Time{} + lastUsedAt := time.Time{} + revokedAt := time.Time{} + + if db.UpdatedAt.Valid { + updatedAt = db.UpdatedAt.Time + } + + if db.LastUsedAt.Valid { + lastUsedAt = db.LastUsedAt.Time + } + + if db.RevokedAt.Valid { + revokedAt = db.RevokedAt.Time + } + + pat := auth.PAT{ + ID: db.ID, + User: db.User, + Name: db.Name, + Description: db.Description, + Secret: db.Secret, + IssuedAt: db.IssuedAt, + ExpiresAt: db.ExpiresAt, + UpdatedAt: updatedAt, + LastUsedAt: lastUsedAt, + Revoked: db.Revoked, + RevokedAt: revokedAt, + } + + return pat, nil +} + +func toAuthScope(dsc []dbScope) ([]auth.Scope, error) { + scope := []auth.Scope{} + + for _, s := range dsc { + entityType, err := auth.ParseEntityType(s.EntityType) + if err != nil { + return []auth.Scope{}, err + } + operation, err := auth.ParseOperation(s.Operation) + if err != nil { + return []auth.Scope{}, err + } + scope = append(scope, auth.Scope{ + ID: s.ID, + PatID: s.PatID, + OptionalDomainID: s.OptionalDomainID, + EntityType: entityType, + EntityID: s.EntityID, + Operation: operation, + }) + } + + return scope, nil +} + +func toDBPats(pat auth.PAT) (dbPat, error) { + var updatedAt, lastUsedAt, revokedAt sql.NullTime + + if !pat.UpdatedAt.IsZero() { + updatedAt = sql.NullTime{ + Time: pat.UpdatedAt, + Valid: true, + } + } + + if !pat.LastUsedAt.IsZero() { + lastUsedAt = sql.NullTime{ + Time: pat.LastUsedAt, + Valid: true, + } + } + + if !pat.RevokedAt.IsZero() { + revokedAt = sql.NullTime{ + Time: pat.RevokedAt, + Valid: true, + } + } + + return dbPat{ + ID: pat.ID, + User: pat.User, + Name: pat.Name, + Description: pat.Description, + Secret: pat.Secret, + IssuedAt: pat.IssuedAt, + ExpiresAt: pat.ExpiresAt, + UpdatedAt: updatedAt, + LastUsedAt: lastUsedAt, + Revoked: pat.Revoked, + RevokedAt: revokedAt, + }, nil +} + +func toDBScope(sc []auth.Scope) []dbScope { + var scopes []dbScope + for _, s := range sc { + scopes = append(scopes, dbScope{ + ID: s.ID, + PatID: s.PatID, + OptionalDomainID: s.OptionalDomainID, + EntityType: s.EntityType.String(), + EntityID: s.EntityID, + Operation: s.Operation.String(), + }) + } + return scopes +} diff --git a/auth/postgres/repo.go b/auth/postgres/repo.go new file mode 100644 index 0000000000..5efa702d6e --- /dev/null +++ b/auth/postgres/repo.go @@ -0,0 +1,657 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/postgres" +) + +var _ auth.PATSRepository = (*patRepo)(nil) + +type patRepo struct { + db postgres.Database + cache auth.Cache +} + +func NewPatRepo(db postgres.Database, cache auth.Cache) auth.PATSRepository { + return &patRepo{ + db: db, + cache: cache, + } +} + +func (pr *patRepo) Save(ctx context.Context, pat auth.PAT) error { + q := ` + INSERT INTO pats ( + id, user_id, name, description, secret, issued_at, expires_at, + updated_at, last_used_at, revoked, revoked_at + ) VALUES ( + :id, :user_id, :name, :description, :secret, :issued_at, :expires_at, + :updated_at, :last_used_at, :revoked, :revoked_at + )` + + dbPat, err := toDBPats(pat) + if err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + + _, err = pr.db.NamedQueryContext(ctx, q, dbPat) + if err != nil { + return postgres.HandleError(repoerr.ErrCreateEntity, err) + } + + return nil +} + +func (pr *patRepo) Retrieve(ctx context.Context, userID, patID string) (auth.PAT, error) { + pat, err := pr.retrievePATFromDB(ctx, userID, patID) + if err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + return pat, nil +} + +func (pr *patRepo) RetrieveAll(ctx context.Context, userID string, pm auth.PATSPageMeta) (auth.PATSPage, error) { + q := ` + SELECT + p.id, p.user_id, p.name, p.description, p.issued_at, p.expires_at, + p.updated_at, p.revoked, p.revoked_at + FROM pats p WHERE user_id = :user_id + ORDER BY issued_at DESC + LIMIT :limit OFFSET :offset` + + dbPage := dbPagemeta{ + Limit: pm.Limit, + Offset: pm.Offset, + User: userID, + } + + rows, err := pr.db.NamedQueryContext(ctx, q, dbPage) + if err != nil { + return auth.PATSPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + defer rows.Close() + + var items []auth.PAT + for rows.Next() { + var pat dbPat + if err := rows.StructScan(&pat); err != nil { + return auth.PATSPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + var updatedAt, revokedAt time.Time + if pat.UpdatedAt.Valid { + updatedAt = pat.UpdatedAt.Time + } + if pat.RevokedAt.Valid { + revokedAt = pat.RevokedAt.Time + } + + items = append(items, auth.PAT{ + ID: pat.ID, + User: pat.User, + Name: pat.Name, + Description: pat.Description, + IssuedAt: pat.IssuedAt, + ExpiresAt: pat.ExpiresAt, + UpdatedAt: updatedAt, + Revoked: pat.Revoked, + RevokedAt: revokedAt, + }) + } + + cq := `SELECT COUNT(*) FROM pats p WHERE user_id = :user_id` + + total, err := postgres.Total(ctx, pr.db, cq, dbPage) + if err != nil { + return auth.PATSPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + page := auth.PATSPage{ + PATS: items, + Total: total, + Offset: pm.Offset, + Limit: pm.Limit, + } + return page, nil +} + +func (pr *patRepo) RetrieveSecretAndRevokeStatus(ctx context.Context, userID, patID string) (string, bool, bool, error) { + q := ` + SELECT p.secret, p.revoked, p.expires_at + FROM pats p + WHERE user_id = $1 AND id = $2` + + rows, err := pr.db.QueryContext(ctx, q, userID, patID) + if err != nil { + return "", true, true, postgres.HandleError(repoerr.ErrNotFound, err) + } + defer rows.Close() + + var secret string + var revoked bool + var expiresAt time.Time + + if !rows.Next() { + return "", true, true, repoerr.ErrNotFound + } + + if err := rows.Scan(&secret, &revoked, &expiresAt); err != nil { + return "", true, true, postgres.HandleError(repoerr.ErrNotFound, err) + } + + expired := time.Now().After(expiresAt) + return secret, revoked, expired, nil +} + +func (pr *patRepo) UpdateName(ctx context.Context, userID, patID, name string) (auth.PAT, error) { + q := ` + UPDATE pats p + SET name = :name, updated_at = :updated_at + WHERE user_id = :user_id AND id = :id + RETURNING id, user_id, name, description, secret, issued_at, updated_at, expires_at, revoked, revoked_at, last_used_at` + + upm := dbPagemeta{ + User: userID, + ID: patID, + Name: name, + UpdatedAt: sql.NullTime{ + Time: time.Now(), + Valid: true, + }, + } + + rows, err := pr.db.NamedQueryContext(ctx, q, upm) + if err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + if !rows.Next() { + return auth.PAT{}, repoerr.ErrNotFound + } + + var pat dbPat + if err := rows.StructScan(&pat); err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + res, err := toAuthPat(pat) + if err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + return res, nil +} + +func (pr *patRepo) UpdateDescription(ctx context.Context, userID, patID, description string) (auth.PAT, error) { + q := ` + UPDATE pats + SET description = :description, updated_at = :updated_at + WHERE user_id = :user_id AND id = :id + RETURNING id, user_id, name, description, secret, issued_at, updated_at, expires_at, revoked, revoked_at, last_used_at` + + upm := dbPagemeta{ + User: userID, + ID: patID, + UpdatedAt: sql.NullTime{ + Time: time.Now(), + Valid: true, + }, + Description: description, + } + + rows, err := pr.db.NamedQueryContext(ctx, q, upm) + if err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + if !rows.Next() { + return auth.PAT{}, repoerr.ErrNotFound + } + + var pat dbPat + if err := rows.StructScan(&pat); err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + res, err := toAuthPat(pat) + if err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + return res, nil +} + +func (pr *patRepo) UpdateTokenHash(ctx context.Context, userID, patID, tokenHash string, expiryAt time.Time) (auth.PAT, error) { + q := ` + UPDATE pats + SET secret = :secret, expires_at = :expires_at, updated_at = :updated_at + WHERE user_id = :user_id AND id = :id + RETURNING id, user_id, name, description, secret, issued_at, updated_at, expires_at, revoked, revoked_at, last_used_at` + + upm := dbPagemeta{ + User: userID, + ID: patID, + UpdatedAt: sql.NullTime{ + Time: time.Now(), + Valid: true, + }, + ExpiresAt: expiryAt, + Secret: tokenHash, + } + + rows, err := pr.db.NamedQueryContext(ctx, q, upm) + if err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + if !rows.Next() { + return auth.PAT{}, repoerr.ErrNotFound + } + + var pat dbPat + if err := rows.StructScan(&pat); err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + res, err := toAuthPat(pat) + if err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + return res, nil +} + +func (pr *patRepo) Revoke(ctx context.Context, userID, patID string) error { + q := ` + UPDATE pats + SET revoked = true, revoked_at = :revoked_at + WHERE user_id = :user_id AND id = :id` + + upm := dbPagemeta{ + User: userID, + ID: patID, + RevokedAt: sql.NullTime{ + Time: time.Now(), + Valid: true, + }, + } + + _, err := pr.db.NamedQueryContext(ctx, q, upm) + if err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + return nil +} + +func (pr *patRepo) Reactivate(ctx context.Context, userID, patID string) error { + q := ` + UPDATE pats + SET revoked = false, revoked_at = NULL + WHERE user_id = :user_id AND id = :id` + + upm := dbPagemeta{ + User: userID, + ID: patID, + } + + _, err := pr.db.NamedQueryContext(ctx, q, upm) + if err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + return nil +} + +func (pr *patRepo) Remove(ctx context.Context, userID, patID string) error { + q := `DELETE FROM pats WHERE user_id = :user_id AND id = :id` + upm := dbPagemeta{ + User: userID, + ID: patID, + } + + _, err := pr.db.NamedQueryContext(ctx, q, upm) + if err != nil { + return postgres.HandleError(repoerr.ErrRemoveEntity, err) + } + + return nil +} + +func (pr *patRepo) RemoveAllPAT(ctx context.Context, userID string) error { + q := `DELETE FROM pats WHERE user_id = :user_id` + + pm := dbPagemeta{ + User: userID, + } + + _, err := pr.db.NamedQueryContext(ctx, q, pm) + if err != nil { + return postgres.HandleError(repoerr.ErrRemoveEntity, err) + } + + if err := pr.cache.RemoveUserAllScope(ctx, userID); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + return nil +} + +func (pr *patRepo) AddScope(ctx context.Context, userID string, scopes []auth.Scope) error { + q := ` + INSERT INTO pat_scopes (id, pat_id, entity_type, optional_domain_id, operation, entity_id) + VALUES (:id, :pat_id, :entity_type, :optional_domain_id, :operation, :entity_id)` + + var newScopes []auth.Scope + + for _, sc := range scopes { + processedScope, err := pr.processScope(ctx, sc) + if err != nil { + return err + } + if processedScope.ID != "" { + newScopes = append(newScopes, processedScope) + } + } + + if len(newScopes) > 0 { + _, err := pr.db.NamedQueryContext(ctx, q, toDBScope(newScopes)) + if err != nil { + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + } + + if err := pr.cache.Save(ctx, userID, scopes); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + return nil +} + +func (pr *patRepo) processScope(ctx context.Context, sc auth.Scope) (auth.Scope, error) { + q := ` + SELECT COUNT(*) + FROM pat_scopes + WHERE pat_id = :pat_id + AND entity_type = :entity_type + AND optional_domain_id = :optional_domain_id + AND operation = :operation + AND entity_id = :entity_id + LIMIT 1` + + params := dbScope{ + PatID: sc.PatID, + OptionalDomainID: sc.OptionalDomainID, + EntityType: sc.EntityType.String(), + Operation: sc.Operation.String(), + EntityID: auth.AnyIDs, + } + + rows, err := pr.db.NamedQueryContext(ctx, q, params) + if err != nil { + return auth.Scope{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + defer rows.Close() + + var count int + if rows.Next() { + if err := rows.Scan(&count); err != nil { + return auth.Scope{}, postgres.HandleError(repoerr.ErrViewEntity, err) + } + } + + if count > 0 { + return auth.Scope{}, repoerr.ErrConflict + } + + if sc.EntityID == auth.AnyIDs { + newParams := dbScope{ + PatID: sc.PatID, + OptionalDomainID: sc.OptionalDomainID, + EntityType: sc.EntityType.String(), + Operation: sc.Operation.String(), + } + + checkEntityQuery := ` + SELECT COUNT(*) + FROM pat_scopes + WHERE pat_id = :pat_id + AND entity_type = :entity_type + AND optional_domain_id = :optional_domain_id + AND operation = :operation + LIMIT 1` + + rows, err := pr.db.NamedQueryContext(ctx, checkEntityQuery, newParams) + if err != nil { + return auth.Scope{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + defer rows.Close() + + var count int + if rows.Next() { + if err := rows.Scan(&count); err != nil { + return auth.Scope{}, postgres.HandleError(repoerr.ErrViewEntity, err) + } + } + + if count > 0 { + updateWithWildcardQuery := ` + UPDATE pat_scopes + SET entity_id = :entity_id + WHERE pat_id = :pat_id + AND entity_type = :entity_type + AND optional_domain_id = :optional_domain_id + AND operation = :operation` + + _, err = pr.db.NamedQueryContext(ctx, updateWithWildcardQuery, params) + if err != nil { + return auth.Scope{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + return auth.Scope{}, nil + } + } + + return sc, nil +} + +func (pr *patRepo) RemoveScope(ctx context.Context, userID string, scopesIDs ...string) error { + deleteScopesQuery := fmt.Sprintf(`DELETE FROM pat_scopes WHERE id IN ('%s')`, strings.Join(scopesIDs, ",")) + + res, err := pr.db.ExecContext(ctx, deleteScopesQuery) + if err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + if rows, _ := res.RowsAffected(); rows == 0 { + return repoerr.ErrNotFound + } + + if err := pr.cache.Remove(ctx, userID, scopesIDs); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + return nil +} + +func (pr *patRepo) CheckScope(ctx context.Context, userID, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error { + q := ` + SELECT id, pat_id, entity_type, optional_domain_id, operation, entity_id + FROM pat_scopes + WHERE pat_id = :pat_id + AND entity_type = :entity_type + AND optional_domain_id = :optional_domain_id + AND operation = :operation + AND (entity_id = :entity_id OR entity_id = '*') + LIMIT 1` + + authorized := pr.cache.CheckScope(ctx, userID, patID, optionalDomainID, entityType, operation, entityID) + if authorized { + return nil + } + + scope := dbScope{ + PatID: patID, + EntityType: entityType.String(), + OptionalDomainID: optionalDomainID, + Operation: operation.String(), + EntityID: entityID, + } + + rows, err := pr.db.NamedQueryContext(ctx, q, scope) + if err != nil { + return errors.Wrap(repoerr.ErrViewEntity, err) + } + defer rows.Close() + + if rows.Next() { + var sc dbScope + if err := rows.StructScan(&sc); err != nil { + return errors.Wrap(repoerr.ErrViewEntity, err) + } + + entityType, err := auth.ParseEntityType(sc.EntityType) + if err != nil { + return errors.Wrap(repoerr.ErrViewEntity, err) + } + operation, err := auth.ParseOperation(sc.Operation) + if err != nil { + return errors.Wrap(repoerr.ErrViewEntity, err) + } + authScope := auth.Scope{ + ID: sc.ID, + PatID: sc.PatID, + OptionalDomainID: sc.OptionalDomainID, + EntityType: entityType, + EntityID: sc.EntityID, + Operation: operation, + } + + if err := pr.cache.Save(ctx, userID, []auth.Scope{authScope}); err != nil { + return err + } + + if authScope.Authorized(entityType, optionalDomainID, operation, entityID) { + return nil + } + } + + return repoerr.ErrNotFound +} + +func (pr *patRepo) RemoveAllScope(ctx context.Context, patID string) error { + pm := dbPagemeta{ + PatID: patID, + } + + q := `DELETE FROM pat_scopes WHERE pat_id = :pat_id` + + _, err := pr.db.NamedQueryContext(ctx, q, pm) + if err != nil { + return postgres.HandleError(repoerr.ErrRemoveEntity, err) + } + + if err := pr.cache.RemoveAllScope(ctx, pm.User, patID); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + return nil +} + +func (pr *patRepo) RetrieveScope(ctx context.Context, pm auth.ScopesPageMeta) (auth.ScopesPage, error) { + dbs := dbPagemeta{ + PatID: pm.PatID, + Offset: pm.Offset, + Limit: pm.Limit, + } + + scopes, err := pr.retrieveScopeFromDB(ctx, dbs) + if err != nil { + return auth.ScopesPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + cq := `SELECT COUNT(*) FROM pat_scopes WHERE pat_id = :pat_id` + + total, err := postgres.Total(ctx, pr.db, cq, dbs) + if err != nil { + return auth.ScopesPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + return auth.ScopesPage{ + Total: total, + Scopes: scopes, + Offset: pm.Offset, + Limit: pm.Limit, + }, nil +} + +func (pr *patRepo) retrieveScopeFromDB(ctx context.Context, pm dbPagemeta) ([]auth.Scope, error) { + q := ` + SELECT id, pat_id, entity_type, optional_domain_id, operation, entity_id + FROM pat_scopes WHERE pat_id = :pat_id OFFSET :offset LIMIT :limit` + scopeRows, err := pr.db.NamedQueryContext(ctx, q, pm) + if err != nil { + return []auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + defer scopeRows.Close() + + var scopes []dbScope + for scopeRows.Next() { + var scope dbScope + if err := scopeRows.StructScan(&scope); err != nil { + return []auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + scopes = append(scopes, scope) + } + + sc, err := toAuthScope(scopes) + if err != nil { + return []auth.Scope{}, err + } + + return sc, nil +} + +func (pr *patRepo) retrievePATFromDB(ctx context.Context, userID, patID string) (auth.PAT, error) { + q := ` + SELECT + id, user_id, name, description, secret, issued_at, expires_at, + updated_at, last_used_at, revoked, revoked_at + FROM pats WHERE user_id = :user_id AND id = :id` + + dbp := dbPagemeta{ + ID: patID, + User: userID, + } + + rows, err := pr.db.NamedQueryContext(ctx, q, dbp) + if err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + defer rows.Close() + + var record dbPat + if rows.Next() { + if err := rows.StructScan(&record); err != nil { + return auth.PAT{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + } + + pat, err := toAuthPat(record) + if err != nil { + return auth.PAT{}, err + } + + return pat, nil +} diff --git a/auth/service.go b/auth/service.go index bf31a40850..a3c28fb183 100644 --- a/auth/service.go +++ b/auth/service.go @@ -44,8 +44,7 @@ var ( errUpdatePAT = errors.New("failed to update PAT") errRetrievePAT = errors.New("failed to retrieve PAT") errDeletePAT = errors.New("failed to delete PAT") - errRevokePAT = errors.New("failed to revoke PAT") - errClearAllScope = errors.New("failed to clear all entry in scope") + errInvalidScope = errors.New("invalid scope") ) // Authz represents a authorization service. It exposes @@ -100,6 +99,7 @@ var _ Service = (*service)(nil) type service struct { keys KeyRepository pats PATSRepository + cache Cache hasher Hasher idProvider supermq.IDProvider evaluator policies.Evaluator @@ -111,11 +111,12 @@ type service struct { } // New instantiates the auth service implementation. -func New(keys KeyRepository, pats PATSRepository, hasher Hasher, idp supermq.IDProvider, tokenizer Tokenizer, policyEvaluator policies.Evaluator, policyService policies.Service, loginDuration, refreshDuration, invitationDuration time.Duration) Service { +func New(keys KeyRepository, repo PATSRepository, cache Cache, hasher Hasher, idp supermq.IDProvider, tokenizer Tokenizer, policyEvaluator policies.Evaluator, policyService policies.Service, loginDuration, refreshDuration, invitationDuration time.Duration) Service { return &service{ tokenizer: tokenizer, keys: keys, - pats: pats, + pats: repo, + cache: cache, hasher: hasher, idProvider: idp, evaluator: policyEvaluator, @@ -457,7 +458,7 @@ func DecodeDomainUserID(domainUserID string) (string, string) { } } -func (svc service) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope Scope) (PAT, error) { +func (svc service) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration) (PAT, error) { key, err := svc.Identify(ctx, token) if err != nil { return PAT{}, err @@ -481,17 +482,18 @@ func (svc service) CreatePAT(ctx context.Context, token, name, description strin Secret: hash, IssuedAt: now, ExpiresAt: now.Add(duration), - Scope: scope, } + if err := svc.pats.Save(ctx, pat); err != nil { return PAT{}, errors.Wrap(errCreatePAT, err) } pat.Secret = secret + return pat, nil } func (svc service) UpdatePATName(ctx context.Context, token, patID, name string) (PAT, error) { - key, err := svc.Identify(ctx, token) + key, err := svc.authnAuthzUserPAT(ctx, token, patID) if err != nil { return PAT{}, err } @@ -503,7 +505,7 @@ func (svc service) UpdatePATName(ctx context.Context, token, patID, name string) } func (svc service) UpdatePATDescription(ctx context.Context, token, patID, description string) (PAT, error) { - key, err := svc.Identify(ctx, token) + key, err := svc.authnAuthzUserPAT(ctx, token, patID) if err != nil { return PAT{}, err } @@ -514,8 +516,12 @@ func (svc service) UpdatePATDescription(ctx context.Context, token, patID, descr return pat, nil } -func (svc service) RetrievePAT(ctx context.Context, userID, patID string) (PAT, error) { - pat, err := svc.pats.Retrieve(ctx, userID, patID) +func (svc service) RetrievePAT(ctx context.Context, token, patID string) (PAT, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return PAT{}, err + } + pat, err := svc.pats.Retrieve(ctx, key.User, patID) if err != nil { return PAT{}, errors.Wrap(errRetrievePAT, err) } @@ -535,7 +541,7 @@ func (svc service) ListPATS(ctx context.Context, token string, pm PATSPageMeta) } func (svc service) DeletePAT(ctx context.Context, token, patID string) error { - key, err := svc.Identify(ctx, token) + key, err := svc.authnAuthzUserPAT(ctx, token, patID) if err != nil { return err } @@ -546,7 +552,7 @@ func (svc service) DeletePAT(ctx context.Context, token, patID string) error { } func (svc service) ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (PAT, error) { - key, err := svc.Identify(ctx, token) + key, err := svc.authnAuthzUserPAT(ctx, token, patID) if err != nil { return PAT{}, err } @@ -572,48 +578,83 @@ func (svc service) ResetPATSecret(ctx context.Context, token, patID string, dura } func (svc service) RevokePATSecret(ctx context.Context, token, patID string) error { - key, err := svc.Identify(ctx, token) + key, err := svc.authnAuthzUserPAT(ctx, token, patID) if err != nil { return err } if err := svc.pats.Revoke(ctx, key.User, patID); err != nil { - return errors.Wrap(errRevokePAT, err) + return errors.Wrap(svcerr.ErrUpdateEntity, err) } return nil } -func (svc service) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) { +func (svc service) RemoveAllPAT(ctx context.Context, token string) error { key, err := svc.Identify(ctx, token) if err != nil { - return Scope{}, err + return err + } + if err := svc.pats.RemoveAllPAT(ctx, key.User); err != nil { + return errors.Wrap(svcerr.ErrRemoveEntity, err) + } + return nil +} + +func (svc service) AddScope(ctx context.Context, token, patID string, scopes []Scope) error { + key, err := svc.authnAuthzUserPAT(ctx, token, patID) + if err != nil { + return err + } + + for i := range len(scopes) { + scopes[i].ID, err = svc.idProvider.ID() + if err != nil { + return errors.Wrap(svcerr.ErrCreateEntity, err) + } + + scopes[i].PatID = patID } - scope, err := svc.pats.AddScopeEntry(ctx, key.User, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + + err = svc.pats.AddScope(ctx, key.User, scopes) if err != nil { - return Scope{}, errors.Wrap(errRevokePAT, err) + return errors.Wrap(svcerr.ErrCreateEntity, err) } - return scope, nil + return nil } -func (svc service) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) { - key, err := svc.Identify(ctx, token) +func (svc service) RemoveScope(ctx context.Context, token, patID string, scopesIDs ...string) error { + key, err := svc.authnAuthzUserPAT(ctx, token, patID) + if err != nil { + return err + } + + err = svc.pats.RemoveScope(ctx, key.User, scopesIDs...) + if err != nil { + return errors.Wrap(svcerr.ErrRemoveEntity, err) + } + return nil +} + +func (svc service) ListScopes(ctx context.Context, token string, pm ScopesPageMeta) (ScopesPage, error) { + _, err := svc.authnAuthzUserPAT(ctx, token, pm.PatID) if err != nil { - return Scope{}, err + return ScopesPage{}, err } - scope, err := svc.pats.RemoveScopeEntry(ctx, key.User, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + patsPage, err := svc.pats.RetrieveScope(ctx, pm) if err != nil { - return Scope{}, err + return ScopesPage{}, errors.Wrap(errRetrievePAT, err) } - return scope, nil + + return patsPage, nil } -func (svc service) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error { - key, err := svc.Identify(ctx, token) +func (svc service) RemovePATAllScope(ctx context.Context, token, patID string) error { + _, err := svc.authnAuthzUserPAT(ctx, token, patID) if err != nil { return err } - if err := svc.pats.RemoveAllScopeEntry(ctx, key.User, patID); err != nil { - return errors.Wrap(errClearAllScope, err) + if err := svc.pats.RemoveAllScope(ctx, patID); err != nil { + return errors.Wrap(svcerr.ErrRemoveEntity, err) } return nil } @@ -643,21 +684,11 @@ func (svc service) IdentifyPAT(ctx context.Context, secret string) (PAT, error) return PAT{ID: patID.String(), User: userID.String()}, nil } -func (svc service) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { - res, err := svc.RetrievePAT(ctx, userID, patID) - if err != nil { - return err - } - if err := svc.pats.CheckScopeEntry(ctx, res.User, res.ID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...); err != nil { +func (svc service) AuthorizePAT(ctx context.Context, userID, patID string, entityType EntityType, optionalDomainID string, operation Operation, entityID string) error { + if err := svc.pats.CheckScope(ctx, userID, patID, entityType, optionalDomainID, operation, entityID); err != nil { return errors.Wrap(svcerr.ErrAuthorization, err) } - return nil -} -func (svc service) CheckPAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { - if err := svc.pats.CheckScopeEntry(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...); err != nil { - return errors.Wrap(svcerr.ErrAuthorization, err) - } return nil } @@ -707,3 +738,17 @@ func generateRandomString(n int) string { } return string(b) } + +func (svc service) authnAuthzUserPAT(ctx context.Context, token, patID string) (Key, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return Key{}, err + } + + _, err = svc.pats.Retrieve(ctx, key.User, patID) + if err != nil { + return Key{}, errors.Wrap(svcerr.ErrAuthorization, err) + } + + return key, nil +} diff --git a/auth/service_test.go b/auth/service_test.go index 18115fd4c7..f3c8b0d97b 100644 --- a/auth/service_test.go +++ b/auth/service_test.go @@ -50,11 +50,13 @@ var ( pService *policymocks.Service pEvaluator *policymocks.Evaluator patsrepo *mocks.PATSRepository + cache *mocks.Cache hasher *mocks.Hasher ) func newService() (auth.Service, string) { krepo = new(mocks.KeyRepository) + cache = new(mocks.Cache) pService = new(policymocks.Service) pEvaluator = new(policymocks.Evaluator) patsrepo = new(mocks.PATSRepository) @@ -72,7 +74,7 @@ func newService() (auth.Service, string) { } token, _ := t.Issue(key) - return auth.New(krepo, patsrepo, hasher, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token + return auth.New(krepo, patsrepo, cache, hasher, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token } func TestIssue(t *testing.T) { diff --git a/auth/tracing/tracing.go b/auth/tracing/tracing.go index 945bd8f4e8..fc76df98e3 100644 --- a/auth/tracing/tracing.go +++ b/auth/tracing/tracing.go @@ -76,15 +76,14 @@ func (tm *tracingMiddleware) Authorize(ctx context.Context, pr policies.Policy) return tm.svc.Authorize(ctx, pr) } -func (tm *tracingMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) { +func (tm *tracingMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration) (auth.PAT, error) { ctx, span := tm.tracer.Start(ctx, "create_pat", trace.WithAttributes( attribute.String("name", name), attribute.String("description", description), attribute.String("duration", duration.String()), - attribute.String("scope", scope.String()), )) defer span.End() - return tm.svc.CreatePAT(ctx, token, name, description, duration, scope) + return tm.svc.CreatePAT(ctx, token, name, description, duration) } func (tm *tracingMiddleware) UpdatePATName(ctx context.Context, token, patID, name string) (auth.PAT, error) { @@ -122,6 +121,15 @@ func (tm *tracingMiddleware) ListPATS(ctx context.Context, token string, pm auth return tm.svc.ListPATS(ctx, token, pm) } +func (tm *tracingMiddleware) ListScopes(ctx context.Context, token string, pm auth.ScopesPageMeta) (auth.ScopesPage, error) { + ctx, span := tm.tracer.Start(ctx, "list_scopes", trace.WithAttributes( + attribute.Int64("limit", int64(pm.Limit)), + attribute.Int64("offset", int64(pm.Offset)), + )) + defer span.End() + return tm.svc.ListScopes(ctx, token, pm) +} + func (tm *tracingMiddleware) DeletePAT(ctx context.Context, token, patID string) error { ctx, span := tm.tracer.Start(ctx, "delete_pat", trace.WithAttributes( attribute.String("pat_id", patID), @@ -147,38 +155,47 @@ func (tm *tracingMiddleware) RevokePATSecret(ctx context.Context, token, patID s return tm.svc.RevokePATSecret(ctx, token, patID) } -func (tm *tracingMiddleware) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { - ctx, span := tm.tracer.Start(ctx, "add_pat_scope_entry", trace.WithAttributes( - attribute.String("pat_id", patID), - attribute.String("platform_entity", platformEntityType.String()), - attribute.String("optional_domain_id", optionalDomainID), - attribute.String("optional_domain_entity", optionalDomainEntityType.String()), - attribute.String("operation", operation.String()), - attribute.StringSlice("entities", entityIDs), - )) +func (tm *tracingMiddleware) RemoveAllPAT(ctx context.Context, token string) error { + ctx, span := tm.tracer.Start(ctx, "clear_all_pat") defer span.End() - return tm.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + return tm.svc.RemoveAllPAT(ctx, token) } -func (tm *tracingMiddleware) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { - ctx, span := tm.tracer.Start(ctx, "remove_pat_scope_entry", trace.WithAttributes( - attribute.String("pat_id", patID), - attribute.String("platform_entity", platformEntityType.String()), - attribute.String("optional_domain_id", optionalDomainID), - attribute.String("optional_domain_entity", optionalDomainEntityType.String()), - attribute.String("operation", operation.String()), - attribute.StringSlice("entities", entityIDs), - )) +func (tm *tracingMiddleware) AddScope(ctx context.Context, token, patID string, scopes []auth.Scope) error { + var attributes []attribute.KeyValue + for _, s := range scopes { + attributes = append(attributes, attribute.String("entity_type", s.EntityType.String())) + attributes = append(attributes, attribute.String("optional_domain_id", s.OptionalDomainID)) + attributes = append(attributes, attribute.String("operation", s.Operation.String())) + attributes = append(attributes, attribute.String("entity_id", s.EntityID)) + } + + attributes = append(attributes, attribute.String("pat_id", patID)) + + ctx, span := tm.tracer.Start(ctx, "add_pat_scope", trace.WithAttributes(attributes...)) + defer span.End() + return tm.svc.AddScope(ctx, token, patID, scopes) +} + +func (tm *tracingMiddleware) RemoveScope(ctx context.Context, token, patID string, scopesID ...string) error { + var attributes []attribute.KeyValue + for _, s := range scopesID { + attributes = append(attributes, attribute.String("scope_id", s)) + } + + attributes = append(attributes, attribute.String("pat_id", patID)) + + ctx, span := tm.tracer.Start(ctx, "remove_pat_scope", trace.WithAttributes(attributes...)) defer span.End() - return tm.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + return tm.svc.RemoveScope(ctx, token, patID, scopesID...) } -func (tm *tracingMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error { - ctx, span := tm.tracer.Start(ctx, "clear_pat_all_scope_entry", trace.WithAttributes( +func (tm *tracingMiddleware) RemovePATAllScope(ctx context.Context, token, patID string) error { + ctx, span := tm.tracer.Start(ctx, "clear_pat_all_scope", trace.WithAttributes( attribute.String("pat_id", patID), )) defer span.End() - return tm.svc.ClearPATAllScopeEntry(ctx, token, patID) + return tm.svc.RemovePATAllScope(ctx, token, patID) } func (tm *tracingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) { @@ -187,29 +204,14 @@ func (tm *tracingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (a return tm.svc.IdentifyPAT(ctx, paToken) } -func (tm *tracingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { +func (tm *tracingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error { ctx, span := tm.tracer.Start(ctx, "authorize_pat", trace.WithAttributes( attribute.String("pat_id", patID), - attribute.String("platform_entity", platformEntityType.String()), - attribute.String("optional_domain_id", optionalDomainID), - attribute.String("optional_domain_entity", optionalDomainEntityType.String()), - attribute.String("operation", operation.String()), - attribute.StringSlice("entities", entityIDs), - )) - defer span.End() - return tm.svc.AuthorizePAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) -} - -func (tm *tracingMiddleware) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { - ctx, span := tm.tracer.Start(ctx, "check_pat", trace.WithAttributes( - attribute.String("user_id", userID), - attribute.String("pat_id", patID), - attribute.String("platform_entity", platformEntityType.String()), + attribute.String("entity_type", entityType.String()), attribute.String("optional_domain_id", optionalDomainID), - attribute.String("optional_domain_entity", optionalDomainEntityType.String()), attribute.String("operation", operation.String()), - attribute.StringSlice("entities", entityIDs), + attribute.String("entities", entityID), )) defer span.End() - return tm.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + return tm.svc.AuthorizePAT(ctx, userID, patID, entityType, optionalDomainID, operation, entityID) } diff --git a/channels/middleware/authorization.go b/channels/middleware/authorization.go index ebf0ff19ec..a699e0fbd9 100644 --- a/channels/middleware/authorization.go +++ b/channels/middleware/authorization.go @@ -85,13 +85,12 @@ func AuthorizationMiddleware(svc channels.Service, repo channels.Repository, aut func (am *authorizationMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainChannelsScope, - Operation: auth.CreateOp, - EntityIDs: auth.AnyIDs{}.Values(), + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ChannelsType, + OptionalDomainID: session.DomainID, + Operation: auth.CreateOp, + EntityID: auth.AnyIDs, }); err != nil { return []channels.Channel{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -126,13 +125,12 @@ func (am *authorizationMiddleware) CreateChannels(ctx context.Context, session a func (am *authorizationMiddleware) ViewChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainChannelsScope, - Operation: auth.ReadOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ChannelsType, + OptionalDomainID: session.DomainID, + Operation: auth.ReadOp, + EntityID: id, }); err != nil { return channels.Channel{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -153,13 +151,12 @@ func (am *authorizationMiddleware) ViewChannel(ctx context.Context, session auth func (am *authorizationMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.PageMetadata) (channels.Page, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainChannelsScope, - Operation: auth.ListOp, - EntityIDs: auth.AnyIDs{}.Values(), + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ChannelsType, + OptionalDomainID: session.DomainID, + Operation: auth.ListOp, + EntityID: auth.AnyIDs, }); err != nil { return channels.Page{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -174,13 +171,12 @@ func (am *authorizationMiddleware) ListChannels(ctx context.Context, session aut func (am *authorizationMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainChannelsScope, - Operation: auth.ListOp, - EntityIDs: auth.AnyIDs{}.Values(), + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ChannelsType, + OptionalDomainID: session.DomainID, + Operation: auth.ListOp, + EntityID: auth.AnyIDs, }); err != nil { return channels.Page{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -194,13 +190,12 @@ func (am *authorizationMiddleware) ListUserChannels(ctx context.Context, session func (am *authorizationMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainChannelsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{channel.ID}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ChannelsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: channel.ID, }); err != nil { return channels.Channel{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -221,13 +216,12 @@ func (am *authorizationMiddleware) UpdateChannel(ctx context.Context, session au func (am *authorizationMiddleware) UpdateChannelTags(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainChannelsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{channel.ID}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ChannelsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: channel.ID, }); err != nil { return channels.Channel{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -248,13 +242,12 @@ func (am *authorizationMiddleware) UpdateChannelTags(ctx context.Context, sessio func (am *authorizationMiddleware) EnableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainChannelsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ChannelsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: id, }); err != nil { return channels.Channel{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -275,13 +268,12 @@ func (am *authorizationMiddleware) EnableChannel(ctx context.Context, session au func (am *authorizationMiddleware) DisableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainChannelsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ChannelsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: id, }); err != nil { return channels.Channel{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -302,13 +294,12 @@ func (am *authorizationMiddleware) DisableChannel(ctx context.Context, session a func (am *authorizationMiddleware) RemoveChannel(ctx context.Context, session authn.Session, id string) error { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainChannelsScope, - Operation: auth.DeleteOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ChannelsType, + OptionalDomainID: session.DomainID, + Operation: auth.DeleteOp, + EntityID: id, }); err != nil { return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -328,28 +319,29 @@ func (am *authorizationMiddleware) RemoveChannel(ctx context.Context, session au func (am *authorizationMiddleware) Connect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error { if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainChannelsScope, - Operation: auth.CreateOp, - EntityIDs: chIDs, - }); err != nil { - return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + for _, chID := range chIDs { + if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ChannelsType, + OptionalDomainID: session.DomainID, + Operation: auth.CreateOp, + EntityID: chID, + }); err != nil { + return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } } - - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.CreateOp, - EntityIDs: thIDs, - }); err != nil { - return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + for _, thID := range thIDs { + if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ClientsType, + OptionalDomainID: session.DomainID, + Operation: auth.CreateOp, + EntityID: thID, + }); err != nil { + return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } } } for _, chID := range chIDs { @@ -380,28 +372,29 @@ func (am *authorizationMiddleware) Connect(ctx context.Context, session authn.Se func (am *authorizationMiddleware) Disconnect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error { if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainChannelsScope, - Operation: auth.DeleteOp, - EntityIDs: chIDs, - }); err != nil { - return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + for _, chID := range chIDs { + if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ChannelsType, + OptionalDomainID: session.DomainID, + Operation: auth.DeleteOp, + EntityID: chID, + }); err != nil { + return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } } - - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.DeleteOp, - EntityIDs: thIDs, - }); err != nil { - return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + for _, thID := range thIDs { + if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ClientsType, + OptionalDomainID: session.DomainID, + Operation: auth.DeleteOp, + EntityID: thID, + }); err != nil { + return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } } } @@ -434,13 +427,12 @@ func (am *authorizationMiddleware) Disconnect(ctx context.Context, session authn func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainChannelsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ChannelsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: id, }); err != nil { return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -471,13 +463,12 @@ func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session a func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainChannelsScope, - Operation: auth.DeleteOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ChannelsType, + OptionalDomainID: session.DomainID, + Operation: auth.DeleteOp, + EntityID: id, }); err != nil { return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } diff --git a/clients/middleware/authorization.go b/clients/middleware/authorization.go index 8a3483b601..409efd08a1 100644 --- a/clients/middleware/authorization.go +++ b/clients/middleware/authorization.go @@ -77,13 +77,12 @@ func AuthorizationMiddleware(entityType string, svc clients.Service, authz smqau func (am *authorizationMiddleware) CreateClients(ctx context.Context, session authn.Session, client ...clients.Client) ([]clients.Client, []roles.RoleProvision, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.CreateOp, - EntityIDs: auth.AnyIDs{}.Values(), + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ClientsType, + OptionalDomainID: session.DomainID, + Operation: auth.CreateOp, + EntityID: auth.AnyIDs, }); err != nil { return []clients.Client{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -105,13 +104,12 @@ func (am *authorizationMiddleware) CreateClients(ctx context.Context, session au func (am *authorizationMiddleware) View(ctx context.Context, session authn.Session, id string) (clients.Client, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.ReadOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ClientsType, + OptionalDomainID: session.DomainID, + Operation: auth.ReadOp, + EntityID: id, }); err != nil { return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -132,13 +130,12 @@ func (am *authorizationMiddleware) View(ctx context.Context, session authn.Sessi func (am *authorizationMiddleware) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.ListOp, - EntityIDs: auth.AnyIDs{}.Values(), + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ClientsType, + OptionalDomainID: session.DomainID, + Operation: auth.ListOp, + EntityID: auth.AnyIDs, }); err != nil { return clients.ClientsPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -154,13 +151,12 @@ func (am *authorizationMiddleware) ListClients(ctx context.Context, session auth func (am *authorizationMiddleware) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.ListOp, - EntityIDs: auth.AnyIDs{}.Values(), + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ClientsType, + OptionalDomainID: session.DomainID, + Operation: auth.ListOp, + EntityID: auth.AnyIDs, }); err != nil { return clients.ClientsPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -176,13 +172,12 @@ func (am *authorizationMiddleware) ListUserClients(ctx context.Context, session func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{client.ID}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ClientsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: client.ID, }); err != nil { return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -204,13 +199,12 @@ func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Ses func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{client.ID}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ClientsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: client.ID, }); err != nil { return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -232,13 +226,12 @@ func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn func (am *authorizationMiddleware) UpdateSecret(ctx context.Context, session authn.Session, id, key string) (clients.Client, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ClientsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: id, }); err != nil { return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -259,13 +252,12 @@ func (am *authorizationMiddleware) UpdateSecret(ctx context.Context, session aut func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Session, id string) (clients.Client, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ClientsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: id, }); err != nil { return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -287,13 +279,12 @@ func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Ses func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Session, id string) (clients.Client, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ClientsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: id, }); err != nil { return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -314,13 +305,12 @@ func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Se func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Session, id string) error { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.DeleteOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ClientsType, + OptionalDomainID: session.DomainID, + Operation: auth.DeleteOp, + EntityID: id, }); err != nil { return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -341,13 +331,12 @@ func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Ses func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ClientsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: id, }); err != nil { return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -378,13 +367,12 @@ func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session a func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.DeleteOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.ClientsType, + OptionalDomainID: session.DomainID, + Operation: auth.DeleteOp, + EntityID: id, }); err != nil { return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } diff --git a/cmd/auth/main.go b/cmd/auth/main.go index c5ef6d2b5e..88dbd40b83 100644 --- a/cmd/auth/main.go +++ b/cmd/auth/main.go @@ -21,12 +21,12 @@ import ( authgrpcapi "github.com/absmach/supermq/auth/api/grpc/auth" tokengrpcapi "github.com/absmach/supermq/auth/api/grpc/token" httpapi "github.com/absmach/supermq/auth/api/http" - "github.com/absmach/supermq/auth/bolt" + "github.com/absmach/supermq/auth/cache" "github.com/absmach/supermq/auth/hasher" "github.com/absmach/supermq/auth/jwt" apostgres "github.com/absmach/supermq/auth/postgres" "github.com/absmach/supermq/auth/tracing" - boltclient "github.com/absmach/supermq/internal/clients/bolt" + redisclient "github.com/absmach/supermq/internal/clients/redis" smqlog "github.com/absmach/supermq/logger" "github.com/absmach/supermq/pkg/jaeger" "github.com/absmach/supermq/pkg/policies/spicedb" @@ -41,7 +41,7 @@ import ( "github.com/authzed/grpcutil" "github.com/caarlos0/env/v11" "github.com/jmoiron/sqlx" - "go.etcd.io/bbolt" + "github.com/redis/go-redis/v9" "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -75,6 +75,8 @@ type config struct { SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` + CacheURL string `env:"SMQ_AUTH_CACHE_URL" envDefault:"redis://localhost:6379/0"` + CacheKeyDuration time.Duration `env:"SMQ_AUTH_CACHE_KEY_DURATION" envDefault:"10m"` } func main() { @@ -107,6 +109,14 @@ func main() { logger.Error(err.Error()) } + cacheclient, err := redisclient.Connect(cfg.CacheURL) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer cacheclient.Close() + am := apostgres.Migration() db, err := pgclient.Setup(dbConfig, *am) if err != nil { @@ -136,22 +146,7 @@ func main() { return } - boltDBConfig := boltclient.Config{} - if err := env.ParseWithOptions(&boltDBConfig, env.Options{Prefix: envPrefixPATDB}); err != nil { - logger.Error(fmt.Sprintf("failed to parse bolt db config : %s\n", err.Error())) - exitCode = 1 - return - } - - bClient, err := boltclient.Connect(boltDBConfig, bolt.Init) - if err != nil { - logger.Error(fmt.Sprintf("failed to connect to bolt db : %s\n", err.Error())) - exitCode = 1 - return - } - defer bClient.Close() - - svc := newService(ctx, db, tracer, cfg, dbConfig, logger, spicedbclient, bClient, boltDBConfig) + svc := newService(ctx, db, tracer, cfg, dbConfig, logger, spicedbclient, cacheclient, cfg.CacheKeyDuration) grpcServerConfig := server.Config{Port: defSvcGRPCPort} if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixGrpc}); err != nil { @@ -231,10 +226,12 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch return nil } -func newService(_ context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, bClient *bbolt.DB, bConfig boltclient.Config) auth.Service { +func newService(_ context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration) auth.Service { + cache := cache.NewPatsCache(cacheClient, keyDuration) + database := pgclient.NewDatabase(db, dbConfig, tracer) keysRepo := apostgres.New(database) - patsRepo := bolt.NewPATSRepository(bClient, bConfig.Bucket) + patsRepo := apostgres.NewPatRepo(database, cache) hasher := hasher.New() idProvider := uuid.New() @@ -243,7 +240,7 @@ func newService(_ context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, t := jwt.New([]byte(cfg.SecretKey)) - svc := auth.New(keysRepo, patsRepo, hasher, idProvider, t, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration) + svc := auth.New(keysRepo, patsRepo, nil, hasher, idProvider, t, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration) svc = api.LoggingMiddleware(svc, logger) counter, latency := prometheus.MakeMetrics("auth", "api") svc = api.MetricsMiddleware(svc, counter, latency) diff --git a/docker/.env b/docker/.env index 3c17331b5e..d2029ab890 100644 --- a/docker/.env +++ b/docker/.env @@ -99,6 +99,8 @@ SMQ_AUTH_ACCESS_TOKEN_DURATION="1h" SMQ_AUTH_REFRESH_TOKEN_DURATION="24h" SMQ_AUTH_INVITATION_DURATION="168h" SMQ_AUTH_ADAPTER_INSTANCE_ID= +SMQ_AUTH_CACHE_URL=redis://auth-redis:${SMQ_REDIS_TCP_PORT}/0 +SMQ_AUTH_CACHE_KEY_DURATION=10m #### Auth Client Config SMQ_AUTH_URL=auth:9001 diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index a6f68ff1cd..ef45982a41 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -21,6 +21,7 @@ volumes: supermq-domains-db-volume: supermq-domains-redis-volume: supermq-ui-db-volume: + supermq-auth-redis-volume: services: spicedb: @@ -84,6 +85,15 @@ services: volumes: - supermq-auth-db-volume:/var/lib/postgresql/data + auth-redis: + image: redis:7.2.4-alpine + container_name: supermq-auth-redis + restart: on-failure + networks: + - supermq-base-net + volumes: + - supermq-auth-redis-volume:/data + auth: image: supermq/auth:${SMQ_RELEASE_TAG} container_name: supermq-auth @@ -130,6 +140,8 @@ services: SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} SMQ_AUTH_ADAPTER_INSTANCE_ID: ${SMQ_AUTH_ADAPTER_INSTANCE_ID} SMQ_ES_URL: ${SMQ_ES_URL} + SMQ_AUTH_CACHE_URL: ${SMQ_AUTH_CACHE_URL} + SMQ_AUTH_CACHE_KEY_DURATION: ${SMQ_AUTH_CACHE_KEY_DURATION} ports: - ${SMQ_AUTH_HTTP_PORT}:${SMQ_AUTH_HTTP_PORT} - ${SMQ_AUTH_GRPC_PORT}:${SMQ_AUTH_GRPC_PORT} diff --git a/docker/nginx/entrypoint.sh b/docker/nginx/entrypoint.sh index 66b2c51d3e..b221dc0549 100755 --- a/docker/nginx/entrypoint.sh +++ b/docker/nginx/entrypoint.sh @@ -13,6 +13,7 @@ fi envsubst ' ${SMQ_NGINX_SERVER_NAME} + ${SMQ_AUTH_HTTP_PORT} ${SMQ_DOMAINS_HTTP_PORT} ${SMQ_GROUPS_HTTP_PORT} ${SMQ_USERS_HTTP_PORT} diff --git a/docker/nginx/nginx-key.conf b/docker/nginx/nginx-key.conf index 4878705af3..d4ea0dbdcd 100644 --- a/docker/nginx/nginx-key.conf +++ b/docker/nginx/nginx-key.conf @@ -57,6 +57,13 @@ http { add_header Access-Control-Allow-Methods '*'; add_header Access-Control-Allow-Headers '*'; + # Proxy pass to auth service + location ~ ^/(pats) { + include snippets/proxy-headers.conf; + add_header Access-Control-Expose-Headers Location; + proxy_pass http://auth:${SMQ_AUTH_HTTP_PORT}; + } + # Proxy pass to domains service location ~ ^/(domains|invitations) { include snippets/proxy-headers.conf; diff --git a/docker/nginx/nginx-x509.conf b/docker/nginx/nginx-x509.conf index a26b2f530b..dadcb547a3 100644 --- a/docker/nginx/nginx-x509.conf +++ b/docker/nginx/nginx-x509.conf @@ -66,6 +66,13 @@ http { add_header Access-Control-Allow-Methods '*'; add_header Access-Control-Allow-Headers '*'; + # Proxy pass to auth service + location ~ ^/(pats) { + include snippets/proxy-headers.conf; + add_header Access-Control-Expose-Headers Location; + proxy_pass http://auth:${SMQ_AUTH_HTTP_PORT}; + } + # Proxy pass to domains service location ~ ^/(domains|invitations) { include snippets/proxy-headers.conf; diff --git a/go.mod b/go.mod index 1d1fadd3ac..233ebc3f92 100644 --- a/go.mod +++ b/go.mod @@ -41,7 +41,6 @@ require ( github.com/spf13/cobra v1.9.1 github.com/sqids/sqids-go v0.4.1 github.com/stretchr/testify v1.10.0 - go.etcd.io/bbolt v1.4.0 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 go.opentelemetry.io/otel v1.34.0 diff --git a/go.sum b/go.sum index abfaba4e13..b15c199939 100644 --- a/go.sum +++ b/go.sum @@ -425,8 +425,6 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= -go.etcd.io/bbolt v1.4.0 h1:TU77id3TnN/zKr7CO/uk+fBCwF2jGcMuw2B/FMAzYIk= -go.etcd.io/bbolt v1.4.0/go.mod h1:AsD+OCi/qPN1giOX1aiLAha3o1U8rAz65bvN4j0sRuk= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 h1:rgMkmiGfix9vFJDcDi1PK8WEQP4FLQwLDfhp5ZLpFeE= diff --git a/groups/middleware/authorization.go b/groups/middleware/authorization.go index 668baa01e6..31b1f0419c 100644 --- a/groups/middleware/authorization.go +++ b/groups/middleware/authorization.go @@ -84,13 +84,12 @@ func AuthorizationMiddleware(entityType string, svc groups.Service, repo groups. func (am *authorizationMiddleware) CreateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, []roles.RoleProvision, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.CreateOp, - EntityIDs: auth.AnyIDs{}.Values(), + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.GroupsType, + OptionalDomainID: session.DomainID, + Operation: auth.CreateOp, + EntityID: auth.AnyIDs, }); err != nil { return groups.Group{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -125,13 +124,12 @@ func (am *authorizationMiddleware) CreateGroup(ctx context.Context, session auth func (am *authorizationMiddleware) UpdateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{g.ID}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.GroupsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: g.ID, }); err != nil { return groups.Group{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -154,13 +152,12 @@ func (am *authorizationMiddleware) UpdateGroup(ctx context.Context, session auth func (am *authorizationMiddleware) ViewGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.ReadOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.GroupsType, + OptionalDomainID: session.DomainID, + Operation: auth.ReadOp, + EntityID: id, }); err != nil { return groups.Group{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -183,13 +180,12 @@ func (am *authorizationMiddleware) ViewGroup(ctx context.Context, session authn. func (am *authorizationMiddleware) ListGroups(ctx context.Context, session authn.Session, gm groups.PageMeta) (groups.Page, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.ListOp, - EntityIDs: auth.AnyIDs{}.Values(), + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.GroupsType, + OptionalDomainID: session.DomainID, + Operation: auth.ListOp, + EntityID: auth.AnyIDs, }); err != nil { return groups.Page{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -235,13 +231,12 @@ func (am *authorizationMiddleware) ListUserGroups(ctx context.Context, session a func (am *authorizationMiddleware) EnableGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.GroupsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: id, }); err != nil { return groups.Group{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -263,13 +258,12 @@ func (am *authorizationMiddleware) EnableGroup(ctx context.Context, session auth func (am *authorizationMiddleware) DisableGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.GroupsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: id, }); err != nil { return groups.Group{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -291,13 +285,12 @@ func (am *authorizationMiddleware) DisableGroup(ctx context.Context, session aut func (am *authorizationMiddleware) DeleteGroup(ctx context.Context, session authn.Session, id string) error { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.DeleteOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.GroupsType, + OptionalDomainID: session.DomainID, + Operation: auth.DeleteOp, + EntityID: id, }); err != nil { return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -319,13 +312,12 @@ func (am *authorizationMiddleware) DeleteGroup(ctx context.Context, session auth func (am *authorizationMiddleware) RetrieveGroupHierarchy(ctx context.Context, session authn.Session, id string, hm groups.HierarchyPageMeta) (groups.HierarchyPage, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.ListOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.GroupsType, + OptionalDomainID: session.DomainID, + Operation: auth.ListOp, + EntityID: id, }); err != nil { return groups.HierarchyPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -346,13 +338,12 @@ func (am *authorizationMiddleware) RetrieveGroupHierarchy(ctx context.Context, s func (am *authorizationMiddleware) AddParentGroup(ctx context.Context, session authn.Session, id, parentID string) error { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.GroupsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: id, }); err != nil { return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -383,13 +374,12 @@ func (am *authorizationMiddleware) AddParentGroup(ctx context.Context, session a func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.DeleteOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.GroupsType, + OptionalDomainID: session.DomainID, + Operation: auth.DeleteOp, + EntityID: id, }); err != nil { return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -427,13 +417,12 @@ func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, sessio func (am *authorizationMiddleware) AddChildrenGroups(ctx context.Context, session authn.Session, id string, childrenGroupIDs []string) error { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.GroupsType, + OptionalDomainID: session.DomainID, + Operation: auth.UpdateOp, + EntityID: id, }); err != nil { return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -467,13 +456,12 @@ func (am *authorizationMiddleware) AddChildrenGroups(ctx context.Context, sessio func (am *authorizationMiddleware) RemoveChildrenGroups(ctx context.Context, session authn.Session, id string, childrenGroupIDs []string) error { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.DeleteOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.GroupsType, + OptionalDomainID: session.DomainID, + Operation: auth.DeleteOp, + EntityID: id, }); err != nil { return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -495,13 +483,12 @@ func (am *authorizationMiddleware) RemoveChildrenGroups(ctx context.Context, ses func (am *authorizationMiddleware) RemoveAllChildrenGroups(ctx context.Context, session authn.Session, id string) error { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.DeleteOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.GroupsType, + OptionalDomainID: session.DomainID, + Operation: auth.DeleteOp, + EntityID: id, }); err != nil { return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -523,13 +510,12 @@ func (am *authorizationMiddleware) RemoveAllChildrenGroups(ctx context.Context, func (am *authorizationMiddleware) ListChildrenGroups(ctx context.Context, session authn.Session, id string, startLevel, endLevel int64, pm groups.PageMeta) (groups.Page, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.ListOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: auth.GroupsType, + OptionalDomainID: session.DomainID, + Operation: auth.ListOp, + EntityID: id, }); err != nil { return groups.Page{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } diff --git a/internal/clients/bolt/bolt.go b/internal/clients/bolt/bolt.go deleted file mode 100644 index 8e2afebf97..0000000000 --- a/internal/clients/bolt/bolt.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package bolt - -import ( - "io/fs" - "strconv" - "time" - - "github.com/absmach/supermq/pkg/errors" - "github.com/caarlos0/env/v11" - bolt "go.etcd.io/bbolt" -) - -var ( - errConfig = errors.New("failed to load BoltDB configuration") - errConnect = errors.New("failed to connect to BoltDB database") - errInit = errors.New("failed to initialize to BoltDB database") -) - -type FileMode fs.FileMode - -func (fm *FileMode) UnmarshalText(text []byte) error { - temp, err := strconv.ParseUint(string(text), 8, 32) - if err != nil { - return err - } - *fm = FileMode(temp) - return nil -} - -// Config contains BoltDB specific parameters. -type Config struct { - FileDirPath string `env:"FILE_DIR_PATH" envDefault:"./supermq-data"` - FileName string `env:"FILE_NAME" envDefault:"supermq-pat.db"` - FileMode FileMode `env:"FILE_MODE" envDefault:"0600"` - Bucket string `env:"BUCKET" envDefault:"supermq"` - Timeout time.Duration `env:"TIMEOUT" envDefault:"0"` -} - -// Setup load configuration from environment and creates new BoltDB. -func Setup(envPrefix string, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) { - return SetupDB(envPrefix, initFn) -} - -// SetupDB load configuration from environment,. -func SetupDB(envPrefix string, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) { - cfg := Config{} - if err := env.ParseWithOptions(&cfg, env.Options{Prefix: envPrefix}); err != nil { - return nil, errors.Wrap(errConfig, err) - } - bdb, err := Connect(cfg, initFn) - if err != nil { - return nil, err - } - - return bdb, nil -} - -// Connect establishes connection to the BoltDB. -func Connect(cfg Config, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) { - filePath := cfg.FileDirPath + "/" + cfg.FileName - db, err := bolt.Open(filePath, fs.FileMode(cfg.FileMode), nil) - if err != nil { - return nil, errors.Wrap(errConnect, err) - } - if initFn != nil { - if err := Init(db, cfg, initFn); err != nil { - return nil, err - } - } - return db, nil -} - -func Init(db *bolt.DB, cfg Config, initFn func(*bolt.Tx, string) error) error { - if err := db.Update(func(tx *bolt.Tx) error { - return initFn(tx, cfg.Bucket) - }); err != nil { - return errors.Wrap(errInit, err) - } - return nil -} diff --git a/internal/clients/bolt/doc.go b/internal/clients/bolt/doc.go deleted file mode 100644 index 24fc0f92a5..0000000000 --- a/internal/clients/bolt/doc.go +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package BoltDB contains the domain concept definitions needed to support -// Supermq BoltDB database functionality. -// -// It provides the abstraction of the BoltDB database service, which is used -// to configure, setup and connect to the BoltDB database. -package bolt diff --git a/internal/proto/auth/v1/auth.proto b/internal/proto/auth/v1/auth.proto index 7d28932d57..4aaea25b2d 100644 --- a/internal/proto/auth/v1/auth.proto +++ b/internal/proto/auth/v1/auth.proto @@ -39,13 +39,12 @@ message AuthZReq { } message AuthZPatReq { - string user_id = 1; // User id - string pat_id = 2; // Pat id - uint32 platform_entity_type = 3; // Platform entity type - string optional_domain_id = 4; // Optional domain id - uint32 optional_domain_entity_type = 5; // Optional domain entity type - uint32 operation = 6; // Operation - repeated string entity_ids = 7; // EntityIDs + string user_id = 1; // User id + string pat_id = 2; // Pat id + uint32 entity_type = 3; // Entity type + string optional_domain_id = 4; // Optional domain id + uint32 operation = 6; // Operation + string entity_id = 7; // EntityID } message AuthZRes { diff --git a/pkg/authz/authsvc/authz.go b/pkg/authz/authsvc/authz.go index f0d1de592b..55ef3718e9 100644 --- a/pkg/authz/authsvc/authz.go +++ b/pkg/authz/authsvc/authz.go @@ -124,13 +124,12 @@ func (a authorization) checkDomain(ctx context.Context, subjectType, subject, do func (a authorization) AuthorizePAT(ctx context.Context, pr authz.PatReq) error { req := grpcAuthV1.AuthZPatReq{ - UserId: pr.UserID, - PatId: pr.PatID, - PlatformEntityType: uint32(pr.PlatformEntityType), - OptionalDomainId: pr.OptionalDomainID, - OptionalDomainEntityType: uint32(pr.OptionalDomainEntityType), - Operation: uint32(pr.Operation), - EntityIds: pr.EntityIDs, + UserId: pr.UserID, + PatId: pr.PatID, + EntityType: uint32(pr.EntityType), + OptionalDomainId: pr.OptionalDomainID, + Operation: uint32(pr.Operation), + EntityId: pr.EntityID, } res, err := a.authSvcClient.AuthorizePAT(ctx, &req) if err != nil { diff --git a/pkg/authz/authz.go b/pkg/authz/authz.go index 93d807ac5e..02ac03ea6c 100644 --- a/pkg/authz/authz.go +++ b/pkg/authz/authz.go @@ -47,13 +47,12 @@ type PolicyReq struct { } type PatReq struct { - UserID string `json:"user_id,omitempty"` // UserID - PatID string `json:"pat_id,omitempty"` // UserID - PlatformEntityType auth.PlatformEntityType `json:"platform_entity_type,omitempty"` // Platform entity type - OptionalDomainID string `json:"optional_domainID,omitempty"` // Optional domain id - OptionalDomainEntityType auth.DomainEntityType `json:"optional_domain_entity_type,omitempty"` // Optional domain entity type - Operation auth.OperationType `json:"operation,omitempty"` // Operation - EntityIDs []string `json:"entityIDs,omitempty"` // EntityIDs + UserID string `json:"user_id,omitempty"` // UserID + PatID string `json:"pat_id,omitempty"` // UserID + EntityType auth.EntityType `json:"entity_type,omitempty"` // Entity type + OptionalDomainID string `json:"optional_domainID,omitempty"` // Optional domain id + Operation auth.Operation `json:"operation,omitempty"` // Operation + EntityID string `json:"entityID,omitempty"` // EntityID } // Authz is supermq authorization library. diff --git a/pkg/groups/events/consumer/streams.go b/pkg/groups/events/consumer/streams.go index 6e1b49b68c..fdf2e49a31 100644 --- a/pkg/groups/events/consumer/streams.go +++ b/pkg/groups/events/consumer/streams.go @@ -180,7 +180,6 @@ func (es *eventHandler) removeParentGroupHandler(ctx context.Context, data map[s if err != nil { return errors.Wrap(errRemoveParentGroupEvent, err) } - if err := es.repo.UnassignParentGroup(ctx, g.Parent, id); err != nil { return errors.Wrap(errRemoveParentGroupEvent, err) } diff --git a/users/middleware/authorization.go b/users/middleware/authorization.go index db84ad71bc..750367f0fc 100644 --- a/users/middleware/authorization.go +++ b/users/middleware/authorization.go @@ -46,12 +46,11 @@ func (am *authorizationMiddleware) Register(ctx context.Context, session authn.S func (am *authorizationMiddleware) View(ctx context.Context, session authn.Session, id string) (users.User, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.ReadOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: smqauth.UsersType, + Operation: smqauth.ReadOp, + EntityID: id, }); err != nil { return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -67,12 +66,11 @@ func (am *authorizationMiddleware) View(ctx context.Context, session authn.Sessi func (am *authorizationMiddleware) ViewProfile(ctx context.Context, session authn.Session) (users.User, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.ReadOp, - EntityIDs: []string{session.UserID}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: smqauth.UsersType, + Operation: smqauth.ReadOp, + EntityID: session.UserID, }); err != nil { return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -83,12 +81,11 @@ func (am *authorizationMiddleware) ViewProfile(ctx context.Context, session auth func (am *authorizationMiddleware) ListUsers(ctx context.Context, session authn.Session, pm users.Page) (users.UsersPage, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.ListOp, - EntityIDs: smqauth.AnyIDs{}.Values(), + UserID: session.UserID, + PatID: session.PatID, + EntityType: smqauth.UsersType, + Operation: smqauth.ListOp, + EntityID: smqauth.AnyIDs, }); err != nil { return users.UsersPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -107,12 +104,11 @@ func (am *authorizationMiddleware) SearchUsers(ctx context.Context, pm users.Pag func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Session, user users.User) (users.User, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{user.ID}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: smqauth.UsersType, + Operation: smqauth.UpdateOp, + EntityID: user.ID, }); err != nil { return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -128,12 +124,11 @@ func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Ses func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn.Session, user users.User) (users.User, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{user.ID}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: smqauth.UsersType, + Operation: smqauth.UpdateOp, + EntityID: user.ID, }); err != nil { return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -149,12 +144,11 @@ func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn func (am *authorizationMiddleware) UpdateEmail(ctx context.Context, session authn.Session, id, email string) (users.User, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: smqauth.UsersType, + Operation: smqauth.UpdateOp, + EntityID: id, }); err != nil { return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -169,12 +163,11 @@ func (am *authorizationMiddleware) UpdateEmail(ctx context.Context, session auth func (am *authorizationMiddleware) UpdateUsername(ctx context.Context, session authn.Session, id, username string) (users.User, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: smqauth.UsersType, + Operation: smqauth.UpdateOp, + EntityID: id, }); err != nil { return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -190,12 +183,11 @@ func (am *authorizationMiddleware) UpdateUsername(ctx context.Context, session a func (am *authorizationMiddleware) UpdateProfilePicture(ctx context.Context, session authn.Session, user users.User) (users.User, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{user.ID}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: smqauth.UsersType, + Operation: smqauth.UpdateOp, + EntityID: user.ID, }); err != nil { return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -215,12 +207,11 @@ func (am *authorizationMiddleware) GenerateResetToken(ctx context.Context, email func (am *authorizationMiddleware) UpdateSecret(ctx context.Context, session authn.Session, oldSecret, newSecret string) (users.User, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{session.UserID}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: smqauth.UsersType, + Operation: smqauth.UpdateOp, + EntityID: session.UserID, }); err != nil { return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -240,12 +231,11 @@ func (am *authorizationMiddleware) SendPasswordReset(ctx context.Context, host, func (am *authorizationMiddleware) UpdateRole(ctx context.Context, session authn.Session, user users.User) (users.User, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{user.ID}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: smqauth.UsersType, + Operation: smqauth.UpdateOp, + EntityID: user.ID, }); err != nil { return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -265,12 +255,11 @@ func (am *authorizationMiddleware) UpdateRole(ctx context.Context, session authn func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Session, id string) (users.User, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: smqauth.UsersType, + Operation: smqauth.UpdateOp, + EntityID: id, }); err != nil { return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -286,12 +275,11 @@ func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Ses func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Session, id string) (users.User, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: smqauth.UsersType, + Operation: smqauth.UpdateOp, + EntityID: id, }); err != nil { return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } @@ -307,12 +295,11 @@ func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Se func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Session, id string) error { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.PatID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.DeleteOp, - EntityIDs: []string{id}, + UserID: session.UserID, + PatID: session.PatID, + EntityType: smqauth.UsersType, + Operation: smqauth.DeleteOp, + EntityID: id, }); err != nil { return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) }