diff --git a/pkg/clients/postgres/clients.go b/pkg/clients/postgres/clients.go index 63cf14a5a2..6497951db5 100644 --- a/pkg/clients/postgres/clients.go +++ b/pkg/clients/postgres/clients.go @@ -13,13 +13,20 @@ import ( "github.com/absmach/magistrala/internal/api" "github.com/absmach/magistrala/pkg/clients" - "github.com/absmach/magistrala/pkg/errors" repoerr "github.com/absmach/magistrala/pkg/errors/repository" "github.com/absmach/magistrala/pkg/groups" "github.com/absmach/magistrala/pkg/postgres" "github.com/jackc/pgtype" ) +var ( + MsgErrMarshalMeta = "failed to marshal JSON metadata" + MsgErrUnmarshalMeta = "failed to unmarshal JSON metadata" + MsgErrFailedScan = "failed to scan DB value to struct" + MsgErrFailedQuery = "failed to execute DB query" + MsgErrFailedTotal = "failed to calculate total" +) + type Repository struct { DB postgres.Database } @@ -30,6 +37,7 @@ func (repo *Repository) Update(ctx context.Context, client clients.Client) (clie if client.Name != "" { query = append(query, "name = :name,") } + if client.Metadata != nil { query = append(query, "metadata = :metadata,") } @@ -95,14 +103,14 @@ func (repo *Repository) RetrieveByID(ctx context.Context, id string) (clients.Cl row, err := repo.DB.NamedQueryContext(ctx, q, dbc) if err != nil { - return clients.Client{}, errors.Wrap(repoerr.ErrViewEntity, err) + return clients.Client{}, repoerr.NewReadError(MsgErrFailedQuery, err) } defer row.Close() dbc = DBClient{} if row.Next() { if err := row.StructScan(&dbc); err != nil { - return clients.Client{}, errors.Wrap(repoerr.ErrViewEntity, err) + return clients.Client{}, repoerr.NewTypeError(MsgErrFailedScan, err) } return ToClient(dbc) @@ -129,7 +137,7 @@ func (repo *Repository) RetrieveByIdentity(ctx context.Context, identity string) dbc = DBClient{} if row.Next() { if err := row.StructScan(&dbc); err != nil { - return clients.Client{}, errors.Wrap(repoerr.ErrViewEntity, err) + return clients.Client{}, repoerr.NewTypeError(MsgErrFailedScan, err) } return ToClient(dbc) @@ -141,19 +149,18 @@ func (repo *Repository) RetrieveByIdentity(ctx context.Context, identity string) func (repo *Repository) RetrieveAll(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) { query, err := PageQuery(pm) if err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + return clients.ClientsPage{}, err } - q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.identity, c.metadata, COALESCE(c.domain_id, '') AS domain_id, c.status, - c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM clients c %s ORDER BY c.created_at LIMIT :limit OFFSET :offset;`, query) - dbPage, err := ToDBClientsPage(pm) if err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err) + return clients.ClientsPage{}, err } + q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.identity, c.metadata, COALESCE(c.domain_id, '') AS domain_id, c.status, + c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM clients c %s ORDER BY c.created_at LIMIT :limit OFFSET :offset;`, query) rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage) if err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err) + return clients.ClientsPage{}, repoerr.NewReadError("failed to read from the DB", err) } defer rows.Close() @@ -161,7 +168,7 @@ func (repo *Repository) RetrieveAll(ctx context.Context, pm clients.Page) (clien for rows.Next() { dbc := DBClient{} if err := rows.StructScan(&dbc); err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + return clients.ClientsPage{}, repoerr.NewTypeError(MsgErrFailedScan, err) } c, err := ToClient(dbc) @@ -175,7 +182,7 @@ func (repo *Repository) RetrieveAll(ctx context.Context, pm clients.Page) (clien total, err := postgres.Total(ctx, repo.DB, cq, dbPage) if err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + return clients.ClientsPage{}, repoerr.NewReadError(MsgErrFailedTotal, err) } page := clients.ClientsPage{ @@ -193,16 +200,15 @@ func (repo *Repository) RetrieveAll(ctx context.Context, pm clients.Page) (clien func (repo *Repository) RetrieveAllBasicInfo(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) { sq, tq := constructSearchQuery(pm) - q := fmt.Sprintf(`SELECT c.id, c.name, c.created_at, c.updated_at FROM clients c %s LIMIT :limit OFFSET :offset;`, sq) - dbPage, err := ToDBClientsPage(pm) if err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err) + return clients.ClientsPage{}, err } + q := fmt.Sprintf(`SELECT c.id, c.name, c.created_at, c.updated_at FROM clients c %s LIMIT :limit OFFSET :offset;`, sq) rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage) if err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err) + return clients.ClientsPage{}, repoerr.NewReadError(MsgErrFailedQuery, err) } defer rows.Close() @@ -210,7 +216,7 @@ func (repo *Repository) RetrieveAllBasicInfo(ctx context.Context, pm clients.Pag for rows.Next() { dbc := DBClient{} if err := rows.StructScan(&dbc); err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + return clients.ClientsPage{}, repoerr.NewTypeError(MsgErrFailedScan, err) } c, err := ToClient(dbc) @@ -224,7 +230,7 @@ func (repo *Repository) RetrieveAllBasicInfo(ctx context.Context, pm clients.Pag cq := fmt.Sprintf(`SELECT COUNT(*) FROM clients c %s;`, tq) total, err := postgres.Total(ctx, repo.DB, cq, dbPage) if err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + return clients.ClientsPage{}, repoerr.NewReadError(MsgErrFailedTotal, err) } page := clients.ClientsPage{ @@ -247,7 +253,7 @@ func (repo *Repository) RetrieveAllByIDs(ctx context.Context, pm clients.Page) ( } query, err := PageQuery(pm) if err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + return clients.ClientsPage{}, err } q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.identity, c.metadata, COALESCE(c.domain_id, '') AS domain_id, c.status, @@ -255,11 +261,11 @@ func (repo *Repository) RetrieveAllByIDs(ctx context.Context, pm clients.Page) ( dbPage, err := ToDBClientsPage(pm) if err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err) + return clients.ClientsPage{}, err } rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage) if err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err) + return clients.ClientsPage{}, err } defer rows.Close() @@ -267,7 +273,7 @@ func (repo *Repository) RetrieveAllByIDs(ctx context.Context, pm clients.Page) ( for rows.Next() { dbc := DBClient{} if err := rows.StructScan(&dbc); err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + return clients.ClientsPage{}, repoerr.NewTypeError(MsgErrFailedScan, err) } c, err := ToClient(dbc) @@ -281,7 +287,7 @@ func (repo *Repository) RetrieveAllByIDs(ctx context.Context, pm clients.Page) ( total, err := postgres.Total(ctx, repo.DB, cq, dbPage) if err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + return clients.ClientsPage{}, repoerr.NewReadError(MsgErrFailedTotal, err) } page := clients.ClientsPage{ @@ -299,7 +305,7 @@ func (repo *Repository) RetrieveAllByIDs(ctx context.Context, pm clients.Page) ( func (repo *Repository) update(ctx context.Context, client clients.Client, query string) (clients.Client, error) { dbc, err := ToDBClient(client) if err != nil { - return clients.Client{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + return clients.Client{}, err } row, err := repo.DB.NamedQueryContext(ctx, query, dbc) @@ -311,7 +317,7 @@ func (repo *Repository) update(ctx context.Context, client clients.Client, query dbc = DBClient{} if row.Next() { if err := row.StructScan(&dbc); err != nil { - return clients.Client{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + return clients.Client{}, repoerr.NewTypeError(MsgErrFailedScan, err) } return ToClient(dbc) @@ -355,7 +361,7 @@ func ToDBClient(c clients.Client) (DBClient, error) { if len(c.Metadata) > 0 { b, err := json.Marshal(c.Metadata) if err != nil { - return DBClient{}, errors.Wrap(repoerr.ErrMalformedEntity, err) + return DBClient{}, repoerr.NewTypeError(MsgErrMarshalMeta, err) } data = b } @@ -392,7 +398,7 @@ func ToClient(c DBClient) (clients.Client, error) { var metadata clients.Metadata if c.Metadata != nil { if err := json.Unmarshal([]byte(c.Metadata), &metadata); err != nil { - return clients.Client{}, errors.Wrap(errors.ErrMalformedEntity, err) + return clients.Client{}, repoerr.NewTypeError(MsgErrUnmarshalMeta, err) } } var tags []string @@ -432,7 +438,7 @@ func ToClient(c DBClient) (clients.Client, error) { func ToDBClientsPage(pm clients.Page) (dbClientsPage, error) { _, data, err := postgres.CreateMetadataQuery("", pm.Metadata) if err != nil { - return dbClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + return dbClientsPage{}, err } return dbClientsPage{ Name: pm.Name, @@ -465,7 +471,7 @@ type dbClientsPage struct { func PageQuery(pm clients.Page) (string, error) { mq, _, err := postgres.CreateMetadataQuery("", pm.Metadata) if err != nil { - return "", errors.Wrap(errors.ErrMalformedEntity, err) + return "", err } var query []string var emq string diff --git a/pkg/clients/postgres/clients_test.go b/pkg/clients/postgres/clients_test.go index 178fcdf4ab..cffd701178 100644 --- a/pkg/clients/postgres/clients_test.go +++ b/pkg/clients/postgres/clients_test.go @@ -345,7 +345,7 @@ func TestRetrieveAll(t *testing.T) { }, Clients: []mgclients.Client(nil), }, - err: repoerr.ErrViewEntity, + err: repoerr.NewTypeError(pgclients.MsgErrMarshalMeta, nil), }, { desc: "with name", @@ -909,7 +909,7 @@ func TestRetrieveByIDs(t *testing.T) { }, Clients: []mgclients.Client(nil), }, - err: errors.ErrMalformedEntity, + err: repoerr.NewTypeError(pgclients.MsgErrMarshalMeta, nil), }, } @@ -922,7 +922,7 @@ func TestRetrieveByIDs(t *testing.T) { assert.Equal(t, c.response.Offset, response.Offset) assert.ElementsMatch(t, response.Clients, c.response.Clients) default: - assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s\n", err, c.err)) + assert.True(t, errors.ContainsType(err, c.err), fmt.Sprintf("%s: expected %s to contain %s\n", c.desc, err, c.err)) } } } @@ -1342,7 +1342,7 @@ func TestUpdate(t *testing.T) { "update": make(chan int), }, }, - err: repoerr.ErrUpdateEntity, + err: repoerr.NewTypeError(pgclients.MsgErrMarshalMeta, nil), }, { desc: "update metadata for disabled client", diff --git a/pkg/errors/api/types.go b/pkg/errors/api/types.go new file mode 100644 index 0000000000..4666aff8a7 --- /dev/null +++ b/pkg/errors/api/types.go @@ -0,0 +1,20 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import "github.com/absmach/magistrala/pkg/errors" + +type ( + ContentTypeError struct { + *errors.CustomError + } + + ValidationError struct { + *errors.CustomError + } + + InvalidParamsError struct { + *errors.CustomError + } +) diff --git a/pkg/errors/auth/types.go b/pkg/errors/auth/types.go new file mode 100644 index 0000000000..0f4931334e --- /dev/null +++ b/pkg/errors/auth/types.go @@ -0,0 +1,24 @@ +package auth + +import "github.com/absmach/magistrala/pkg/errors" + +type ( + + // AuthenticationError indicates failure occurred while authenticating the entity. + AuthenticationError struct { + *errors.CustomError + } + + // AuthorizationError indicates failure occurred while authorizing the entity. + AuthorizationError struct { + *errors.CustomError + } +) + +func NewAuthNError(text string, err error) error { + return &AuthenticationError{errors.NewErr(text, err)} +} + +func NewAuthZError(text string, err error) error { + return &AuthenticationError{errors.NewErr(text, err)} +} diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index 6ca1637db4..3a63a9ab18 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -5,6 +5,7 @@ package errors import ( "encoding/json" + "reflect" ) // Error specifies an API that must be fullfiled by error type. @@ -15,30 +16,37 @@ type Error interface { // Msg returns error message. Msg() string - // Err returns wrapped error. - Err() Error + // Unwrap returns wrapped error. + Unwrap() error // MarshalJSON returns a marshaled error. MarshalJSON() ([]byte, error) } -var _ Error = (*customError)(nil) +var _ Error = (*CustomError)(nil) -// customError represents a Magistrala error. -type customError struct { +// CustomError represents a Magistrala error. +type CustomError struct { msg string err Error } // New returns an Error that formats as the given text. func New(text string) Error { - return &customError{ + return &CustomError{ msg: text, err: nil, } } -func (ce *customError) Error() string { +func NewErr(text string, err error) *CustomError { + return &CustomError{ + msg: text, + err: cast(err), + } +} + +func (ce *CustomError) Error() string { if ce == nil { return "" } @@ -48,18 +56,20 @@ func (ce *customError) Error() string { return ce.msg + " : " + ce.err.Error() } -func (ce *customError) Msg() string { +func (ce *CustomError) Msg() string { return ce.msg } -func (ce *customError) Err() Error { +func (ce *CustomError) Unwrap() error { return ce.err } -func (ce *customError) MarshalJSON() ([]byte, error) { +func (ce *CustomError) MarshalJSON() ([]byte, error) { var val string - if e := ce.Err(); e != nil { - val = e.Msg() + if e := ce.Unwrap(); e != nil { + if e1, ok := e.(Error); ok { + val = e1.Msg() + } } return json.Marshal(&struct { Err string `json:"error"` @@ -80,7 +90,25 @@ func Contains(e1, e2 error) bool { if ce.Msg() == e2.Error() { return true } - return Contains(ce.Err(), e2) + return Contains(ce.Unwrap(), e2) + } + return e1.Error() == e2.Error() +} + +// ContainsType inspects if e2 error is contained in any layer of e1 error. +func ContainsType(e1, e2 error) bool { + if e1 == nil || e2 == nil { + return e2 == e1 + } + + ce, ok := e1.(Error) + if ok { + v1 := reflect.ValueOf(e1) + v2 := reflect.ValueOf(e2) + if v1.Type() == v2.Type() && ce.Msg() == e2.Error() { + return true + } + return ContainsType(ce.Unwrap(), e2) } return e1.Error() == e2.Error() } @@ -91,12 +119,12 @@ func Wrap(wrapper, err error) error { return wrapper } if w, ok := wrapper.(Error); ok { - return &customError{ + return &CustomError{ msg: w.Msg(), err: cast(err), } } - return &customError{ + return &CustomError{ msg: wrapper.Error(), err: cast(err), } @@ -105,10 +133,10 @@ func Wrap(wrapper, err error) error { // Unwrap returns the wrapper and the error by separating the Wrapper from the error. func Unwrap(err error) (error, error) { if ce, ok := err.(Error); ok { - if ce.Err() == nil { + if ce.Unwrap() == nil { return nil, New(ce.Msg()) } - return New(ce.Msg()), ce.Err() + return New(ce.Msg()), ce.Unwrap() } return nil, err @@ -121,7 +149,7 @@ func cast(err error) Error { if e, ok := err.(Error); ok { return e } - return &customError{ + return &CustomError{ msg: err.Error(), err: nil, } diff --git a/pkg/errors/repository/errors.go b/pkg/errors/repository/errors.go new file mode 100644 index 0000000000..04109dd1ec --- /dev/null +++ b/pkg/errors/repository/errors.go @@ -0,0 +1,36 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package repository + +import "github.com/absmach/magistrala/pkg/errors" + +// Wrapper for Repository errors. +var ( + // ErrMalformedEntity indicates a malformed entity specification. + ErrMalformedEntity = &ConstraintError{errors.NewErr("malformed entity specification", nil)} + + // ErrNotFound indicates a non-existent entity request. + ErrNotFound = &ReadError{errors.NewErr("entity not found", nil)} + + // ErrConflict indicates that entity already exists. + ErrConflict = &WriteError{errors.NewErr("entity already exists", nil)} + + // ErrCreateEntity indicates error in creating entity or entities. + ErrCreateEntity = &WriteError{errors.NewErr("failed to create entity in the db", nil)} + + // ErrViewEntity indicates error in viewing entity or entities. + ErrViewEntity = &ReadError{errors.NewErr("view entity failed", nil)} + + // ErrUpdateEntity indicates error in updating entity or entities. + ErrUpdateEntity = &WriteError{errors.NewErr("update entity failed", nil)} + + // ErrRemoveEntity indicates error in removing entity. + ErrRemoveEntity = &WriteError{errors.NewErr("failed to remove entity", nil)} + + // ErrFailedOpDB indicates a failure in a database operation. + ErrFailedOpDB = &ConstraintError{errors.NewErr("operation on db element failed", nil)} + + // ErrFailedToRetrieveAllGroups failed to retrieve groups. + ErrFailedToRetrieveAllGroups = &ReadError{errors.NewErr("failed to retrieve all groups", nil)} +) diff --git a/pkg/errors/repository/types.go b/pkg/errors/repository/types.go index 9c33d083bc..c3f6c91586 100644 --- a/pkg/errors/repository/types.go +++ b/pkg/errors/repository/types.go @@ -3,34 +3,65 @@ package repository -import "github.com/absmach/magistrala/pkg/errors" +import ( + "github.com/absmach/magistrala/pkg/errors" +) // Wrapper for Repository errors. -var ( - // ErrMalformedEntity indicates a malformed entity specification. - ErrMalformedEntity = errors.New("malformed entity specification") +type ( + ConstraintError struct { + *errors.CustomError + } + + WriteError struct { + *errors.CustomError + } - // ErrNotFound indicates a non-existent entity request. - ErrNotFound = errors.New("entity not found") + ReadError struct { + *errors.CustomError + } - // ErrConflict indicates that entity already exists. - ErrConflict = errors.New("entity already exists") + RollbackError struct { + *errors.CustomError + } - // ErrCreateEntity indicates error in creating entity or entities. - ErrCreateEntity = errors.New("failed to create entity in the db") + TypeError struct { + *errors.CustomError + } - // ErrViewEntity indicates error in viewing entity or entities. - ErrViewEntity = errors.New("view entity failed") + OtherError struct { + *errors.CustomError + } - // ErrUpdateEntity indicates error in updating entity or entities. - ErrUpdateEntity = errors.New("update entity failed") + NotFoundError struct { + *errors.CustomError + } +) - // ErrRemoveEntity indicates error in removing entity. - ErrRemoveEntity = errors.New("failed to remove entity") +func NewConstraintError(text string, err error) *ConstraintError { + return &ConstraintError{errors.NewErr(text, err)} +} - // ErrFailedOpDB indicates a failure in a database operation. - ErrFailedOpDB = errors.New("operation on db element failed") +func NewWriteError(text string, err error) *WriteError { + return &WriteError{errors.NewErr(text, err)} +} - // ErrFailedToRetrieveAllGroups failed to retrieve groups. - ErrFailedToRetrieveAllGroups = errors.New("failed to retrieve all groups") -) +func NewReadError(text string, err error) *ReadError { + return &ReadError{errors.NewErr(text, err)} +} + +func NewRollbackError(rbErr, err error) *RollbackError { + return &RollbackError{errors.NewErr(rbErr.Error(), err)} +} + +func NewTypeError(text string, err error) *TypeError { + return &TypeError{errors.NewErr(text, err)} +} + +func NewOtherError(text string, err error) *OtherError { + return &OtherError{errors.NewErr(text, err)} +} + +func NewNotFoundError(text string, err error) *NotFoundError { + return &NotFoundError{errors.NewErr(text, err)} +} diff --git a/pkg/errors/sdk_errors.go b/pkg/errors/sdk_errors.go index 61535c9165..55647ab785 100644 --- a/pkg/errors/sdk_errors.go +++ b/pkg/errors/sdk_errors.go @@ -27,7 +27,7 @@ type SDKError interface { var _ SDKError = (*sdkError)(nil) type sdkError struct { - *customError + *CustomError statusCode int } @@ -35,10 +35,10 @@ func (ce *sdkError) Error() string { if ce == nil { return "" } - if ce.customError == nil { + if ce.CustomError == nil { return http.StatusText(ce.statusCode) } - return fmt.Sprintf("Status: %s: %s", http.StatusText(ce.statusCode), ce.customError.Error()) + return fmt.Sprintf("Status: %s: %s", http.StatusText(ce.statusCode), ce.CustomError.Error()) } func (ce *sdkError) StatusCode() int { @@ -54,14 +54,14 @@ func NewSDKError(err error) SDKError { if e, ok := err.(Error); ok { return &sdkError{ statusCode: 0, - customError: &customError{ + CustomError: &CustomError{ msg: e.Msg(), - err: cast(e.Err()), + err: cast(e.Unwrap()), }, } } return &sdkError{ - customError: &customError{ + CustomError: &CustomError{ msg: err.Error(), err: nil, }, @@ -78,15 +78,15 @@ func NewSDKErrorWithStatus(err error, statusCode int) SDKError { if e, ok := err.(Error); ok { return &sdkError{ statusCode: statusCode, - customError: &customError{ + CustomError: &CustomError{ msg: e.Msg(), - err: cast(e.Err()), + err: cast(e.Unwrap()), }, } } return &sdkError{ statusCode: statusCode, - customError: &customError{ + CustomError: &CustomError{ msg: err.Error(), err: nil, }, diff --git a/pkg/errors/service/errors.go b/pkg/errors/service/errors.go new file mode 100644 index 0000000000..c7afe37056 --- /dev/null +++ b/pkg/errors/service/errors.go @@ -0,0 +1,63 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package service + +import "github.com/absmach/magistrala/pkg/errors" + +// Common errors that can be found in service layer. +var ( + // ErrAuthentication indicates failure occurred while authenticating the entity. + ErrAuthentication = &AuthenticationError{errors.NewErr("failed to perform authentication over the entity", nil)} + + // ErrAuthorization indicates failure occurred while authorizing the entity. + ErrAuthorization = &AuthorizationError{errors.NewErr("failed to perform authorization over the entity", nil)} + + // ErrDomainAuthorization indicates failure occurred while authorizing the domain. + ErrDomainAuthorization = &AuthorizationError{errors.NewErr("failed to perform authorization over the domain", nil)} + + // ErrLogin indicates wrong login credentials. + ErrLogin = &AuthenticationError{errors.NewErr("invalid user id or secret", nil)} + + // ErrMalformedEntity indicates a malformed entity specification. + ErrMalformedEntity = &MalformedError{errors.NewErr("malformed entity specification", nil)} + + // ErrNotFound indicates a non-existent entity request. + ErrNotFound = &NotFoundError{errors.NewErr("entity not found", nil)} + + // ErrConflict indicates that entity already exists. + ErrConflict = &ConflictError{errors.NewErr("entity already exists", nil)} + + // ErrCreateEntity indicates error in creating entity or entities. + ErrCreateEntity = &OtherError{errors.NewErr("failed to create entity", nil)} + + // ErrRemoveEntity indicates error in removing entity. + ErrRemoveEntity = &OtherError{errors.NewErr("failed to remove entity", nil)} + + // ErrViewEntity indicates error in viewing entity or entities. + ErrViewEntity = &OtherError{errors.NewErr("view entity failed", nil)} + + // ErrUpdateEntity indicates error in updating entity or entities. + ErrUpdateEntity = &OtherError{errors.NewErr("update entity failed", nil)} + + // ErrInvalidStatus indicates an invalid status. + ErrInvalidStatus = &OtherError{errors.NewErr("invalid status", nil)} + + // ErrInvalidRole indicates that an invalid role. + ErrInvalidRole = &OtherError{errors.NewErr("invalid client role", nil)} + + // ErrInvalidPolicy indicates that an invalid policy. + ErrInvalidPolicy = &OtherError{errors.NewErr("invalid policy", nil)} + + // ErrEnableClient indicates error in enabling client. + ErrEnableClient = &OtherError{errors.NewErr("failed to enable client", nil)} + + // ErrDisableClient indicates error in disabling client. + ErrDisableClient = &OtherError{errors.NewErr("failed to disable client", nil)} + + // ErrAddPolicies indicates error in adding policies. + ErrAddPolicies = &OtherError{errors.NewErr("failed to add policies", nil)} + + // ErrDeletePolicies indicates error in removing policies. + ErrDeletePolicies = &OtherError{errors.NewErr("failed to remove policies", nil)} +) diff --git a/pkg/errors/service/types.go b/pkg/errors/service/types.go index 13ff3e58f3..6aae63c271 100644 --- a/pkg/errors/service/types.go +++ b/pkg/errors/service/types.go @@ -5,59 +5,44 @@ package service import "github.com/absmach/magistrala/pkg/errors" -// Wrapper for Service errors. -var ( - // ErrAuthentication indicates failure occurred while authenticating the entity. - ErrAuthentication = errors.New("failed to perform authentication over the entity") - - // ErrAuthorization indicates failure occurred while authorizing the entity. - ErrAuthorization = errors.New("failed to perform authorization over the entity") - - // ErrDomainAuthorization indicates failure occurred while authorizing the domain. - ErrDomainAuthorization = errors.New("failed to perform authorization over the domain") - - // ErrLogin indicates wrong login credentials. - ErrLogin = errors.New("invalid user id or secret") - - // ErrMalformedEntity indicates a malformed entity specification. - ErrMalformedEntity = errors.New("malformed entity specification") - - // ErrNotFound indicates a non-existent entity request. - ErrNotFound = errors.New("entity not found") - - // ErrConflict indicates that entity already exists. - ErrConflict = errors.New("entity already exists") - - // ErrCreateEntity indicates error in creating entity or entities. - ErrCreateEntity = errors.New("failed to create entity") - - // ErrRemoveEntity indicates error in removing entity. - ErrRemoveEntity = errors.New("failed to remove entity") - - // ErrViewEntity indicates error in viewing entity or entities. - ErrViewEntity = errors.New("view entity failed") - - // ErrUpdateEntity indicates error in updating entity or entities. - ErrUpdateEntity = errors.New("update entity failed") - - // ErrInvalidStatus indicates an invalid status. - ErrInvalidStatus = errors.New("invalid status") - - // ErrInvalidRole indicates that an invalid role. - ErrInvalidRole = errors.New("invalid client role") - - // ErrInvalidPolicy indicates that an invalid policy. - ErrInvalidPolicy = errors.New("invalid policy") - - // ErrEnableClient indicates error in enabling client. - ErrEnableClient = errors.New("failed to enable client") - - // ErrDisableClient indicates error in disabling client. - ErrDisableClient = errors.New("failed to disable client") +// Wrapper service type errors +type ( + + // AuthenticationError indicates failure occurred while authenticating the entity. + AuthenticationError struct { + *errors.CustomError + } + + // AuthorizationError indicates failure occurred while authorizing the entity. + AuthorizationError struct { + *errors.CustomError + } + + // MalformedEntityError indicates a malformed entity specification. + MalformedError struct { + *errors.CustomError + } + + // ConflictError indicates unique constraint violation. + ConflictError struct { + *errors.CustomError + } + + // NotFoundError indicates that resource is not found at the given location. + NotFoundError struct { + *errors.CustomError + } + + // OtherError indicates unknown error usually caused by internal. + OtherError struct { + *errors.CustomError + } +) - // ErrAddPolicies indicates error in adding policies. - ErrAddPolicies = errors.New("failed to add policies") +func NewAuthNError(err error) error { + return &AuthenticationError{errors.NewErr("failed to perform authentication over the user", err)} +} - // ErrDeletePolicies indicates error in removing policies. - ErrDeletePolicies = errors.New("failed to remove policies") -) +func NewAuthZError(err error) error { + return &AuthenticationError{errors.NewErr("failed to perform authorization over the user", err)} +} diff --git a/pkg/postgres/common.go b/pkg/postgres/common.go index 3f394f7721..4716b3f79c 100644 --- a/pkg/postgres/common.go +++ b/pkg/postgres/common.go @@ -7,6 +7,8 @@ import ( "context" "encoding/json" "fmt" + + repoerr "github.com/absmach/magistrala/pkg/errors/repository" ) // CreateMetadataQuery creates a query to filter by metadata. @@ -23,7 +25,7 @@ func CreateMetadataQuery(entity string, um map[string]interface{}) (string, []by param, err := json.Marshal(um) if err != nil { - return "", nil, err + return "", nil, repoerr.NewTypeError("failed to marshal JSON metadata", err) } query := fmt.Sprintf("%smetadata @> :metadata", entity) diff --git a/pkg/postgres/errors.go b/pkg/postgres/errors.go index 541f7f2eb1..035b6b9ebb 100644 --- a/pkg/postgres/errors.go +++ b/pkg/postgres/errors.go @@ -22,6 +22,22 @@ const ( // HandleError handles the error and returns a wrapped error. // It checks the error code and returns a specific error. +func HandleRepoError(text string, err error) error { + pqErr, ok := err.(*pgconn.PgError) + if ok { + switch pqErr.Code { + case errDuplicate: + return repoerr.NewWriteError(text, err) + case errInvalid, errInvalidChar, errTruncation, errUntranslatable: + return repoerr.NewTypeError(text, err) + case errFK: + return repoerr.NewConstraintError(text, err) + } + } + + return repoerr.NewOtherError(text, err) +} + func HandleError(wrapper, err error) error { pqErr, ok := err.(*pgconn.PgError) if ok { diff --git a/things/postgres/clients_test.go b/things/postgres/clients_test.go index 8cef3c24db..0bf64a5b44 100644 --- a/things/postgres/clients_test.go +++ b/things/postgres/clients_test.go @@ -296,7 +296,7 @@ func TestClientsSave(t *testing.T) { }, }, }, - err: errors.ErrMalformedEntity, + err: repoerr.NewTypeError("failed to marshal JSON metadata", nil), }, } for _, tc := range cases { diff --git a/twins/mongodb/twins.go b/twins/mongodb/twins.go index be50df55a3..84d624e5cc 100644 --- a/twins/mongodb/twins.go +++ b/twins/mongodb/twins.go @@ -34,7 +34,7 @@ func NewTwinRepository(db *mongo.Database) twins.TwinRepository { func (tr *twinRepository) Save(ctx context.Context, tw twins.Twin) (string, error) { if len(tw.Name) > maxNameSize { - return "", errors.ErrMalformedEntity + return "", repoerr.ErrMalformedEntity } coll := tr.db.Collection(twinsCollection) @@ -48,7 +48,7 @@ func (tr *twinRepository) Save(ctx context.Context, tw twins.Twin) (string, erro func (tr *twinRepository) Update(ctx context.Context, tw twins.Twin) error { if len(tw.Name) > maxNameSize { - return errors.ErrMalformedEntity + return repoerr.ErrMalformedEntity } coll := tr.db.Collection(twinsCollection) diff --git a/users/api/clients.go b/users/api/clients.go index 35862e5491..6a03143192 100644 --- a/users/api/clients.go +++ b/users/api/clients.go @@ -25,7 +25,6 @@ import ( var passRegex = regexp.MustCompile("^.{8,}$") -// MakeHandler returns a HTTP handler for API endpoints. func clientsHandler(svc users.Service, r *chi.Mux, logger *slog.Logger, pr *regexp.Regexp, providers ...oauth2.Provider) http.Handler { passRegex = pr diff --git a/users/api/groups.go b/users/api/groups.go index 8362ae37a4..701877574b 100644 --- a/users/api/groups.go +++ b/users/api/groups.go @@ -21,7 +21,6 @@ import ( "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) -// MakeHandler returns a HTTP handler for Groups API endpoints. func groupsHandler(svc groups.Service, r *chi.Mux, logger *slog.Logger) http.Handler { opts := []kithttp.ServerOption{ kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), diff --git a/users/errors.go b/users/errors.go new file mode 100644 index 0000000000..df5fbfdb7d --- /dev/null +++ b/users/errors.go @@ -0,0 +1,63 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 +package users + +import ( + "github.com/absmach/magistrala/pkg/errors" + autherr "github.com/absmach/magistrala/pkg/errors/auth" + svcerr "github.com/absmach/magistrala/pkg/errors/service" +) + +const ( + // Malformed errors + HashErr = "failed to hash user credentials" + StatusErr = "invalid status" + RoleErr = "invalid role" + + // AuthN and AuthZ errors + AuthNErr = "failed to authenticate user" + AuthZErr = "failed to authorize user" + InvalidPasswordErr = "invalid password" + ResetTokenErr = "failed to issue reset token" + DisabledUserRefreshErr = "failed to refresh token for disabled user" + + // Internal errors + ViewErr = "failed to fetch user" + IssueTokenErr = "failed to issue token" + RollbackErr = "failed to remove policies during rollback" + AddPoliciesErr = "failed to add user policies" + DeletePoliciesErr = "failed to delete user policies" + PermissionsListErr = "failed to list permissions" + UpdateErr = "failed to update user" + UpdateTagsErr = "failed to update user tags" + UpdateIdentityErr = "failed to update user identity" + UpdateSecretErr = "failed to update user secret" + UpdateRoleErr = "failed to update user role" + ListMembersErr = "failed to list members" + UserAddErr = "failed to add user" + UpdateStatus = "failed to update user status to " +) + +var ( + errNotAddedStatus = errors.New("response status is not added") + errNotDeletedStatus = errors.New("response status is not deleted") + errAuthZ = autherr.NewAuthZError(AuthZErr, nil) +) + +type ( + MalformedError = svcerr.MalformedError + NotFoundError = svcerr.NotFoundError + InternalError = svcerr.OtherError +) + +func newMalformedError(text string, err error) error { + return &MalformedError{CustomError: errors.NewErr(text, err)} +} + +func newNotFoundError(text string, err error) error { + return &NotFoundError{CustomError: errors.NewErr(text, err)} +} + +func newInternalError(text string, err error) error { + return &InternalError{CustomError: errors.NewErr(text, err)} +} diff --git a/users/postgres/clients.go b/users/postgres/clients.go index cc4c8e2ab6..138487cc95 100644 --- a/users/postgres/clients.go +++ b/users/postgres/clients.go @@ -51,12 +51,12 @@ func (repo clientRepo) Save(ctx context.Context, c mgclients.Client) (mgclients. RETURNING id, name, tags, identity, metadata, status, created_at` dbc, err := pgclients.ToDBClient(c) if err != nil { - return mgclients.Client{}, errors.Wrap(repoerr.ErrCreateEntity, err) + return mgclients.Client{}, repoerr.NewTypeError("failed to convert to DB structure", err) } row, err := repo.DB.NamedQueryContext(ctx, q, dbc) if err != nil { - return mgclients.Client{}, postgres.HandleError(repoerr.ErrCreateEntity, err) + return mgclients.Client{}, postgres.HandleRepoError("failed to create user", err) } defer row.Close() @@ -78,18 +78,18 @@ func (repo clientRepo) CheckSuperAdmin(ctx context.Context, adminID string) erro q := "SELECT 1 FROM clients WHERE id = $1 AND role = $2" rows, err := repo.DB.QueryContext(ctx, q, adminID, mgclients.AdminRole) if err != nil { - return postgres.HandleError(repoerr.ErrViewEntity, err) + return postgres.HandleRepoError("failed to fetch admin user", err) } defer rows.Close() if rows.Next() { if err := rows.Err(); err != nil { - return postgres.HandleError(repoerr.ErrViewEntity, err) + return postgres.HandleRepoError("failed process db response", err) } return nil } - return repoerr.ErrNotFound + return repoerr.NewNotFoundError("super admin not found", nil) } func (repo clientRepo) RetrieveByID(ctx context.Context, id string) (mgclients.Client, error) { @@ -102,14 +102,14 @@ func (repo clientRepo) RetrieveByID(ctx context.Context, id string) (mgclients.C rows, err := repo.DB.NamedQueryContext(ctx, q, dbc) if err != nil { - return mgclients.Client{}, postgres.HandleError(repoerr.ErrViewEntity, err) + return mgclients.Client{}, postgres.HandleRepoError("failed to retrieve user", err) } defer rows.Close() dbc = pgclients.DBClient{} if rows.Next() { if err = rows.StructScan(&dbc); err != nil { - return mgclients.Client{}, postgres.HandleError(repoerr.ErrViewEntity, err) + return mgclients.Client{}, postgres.HandleRepoError("failed to retrieve user", err) } client, err := pgclients.ToClient(dbc) @@ -120,7 +120,7 @@ func (repo clientRepo) RetrieveByID(ctx context.Context, id string) (mgclients.C return client, nil } - return mgclients.Client{}, repoerr.ErrNotFound + return mgclients.Client{}, repoerr.NewNotFoundError("user not found", nil) } func (repo clientRepo) RetrieveAll(ctx context.Context, pm mgclients.Page) (mgclients.ClientsPage, error) { @@ -134,11 +134,11 @@ func (repo clientRepo) RetrieveAll(ctx context.Context, pm mgclients.Page) (mgcl dbPage, err := pgclients.ToDBClientsPage(pm) if err != nil { - return mgclients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err) + return mgclients.ClientsPage{}, repoerr.NewTypeError("failed to convert users page", err) } rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage) if err != nil { - return mgclients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err) + return mgclients.ClientsPage{}, repoerr.NewReadError("failed to fetch users", err) } defer rows.Close() @@ -146,7 +146,7 @@ func (repo clientRepo) RetrieveAll(ctx context.Context, pm mgclients.Page) (mgcl for rows.Next() { dbc := pgclients.DBClient{} if err := rows.StructScan(&dbc); err != nil { - return mgclients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + return mgclients.ClientsPage{}, repoerr.NewTypeError("failed to convert db data", err) } c, err := pgclients.ToClient(dbc) @@ -160,7 +160,7 @@ func (repo clientRepo) RetrieveAll(ctx context.Context, pm mgclients.Page) (mgcl total, err := postgres.Total(ctx, repo.DB, cq, dbPage) if err != nil { - return mgclients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + return mgclients.ClientsPage{}, repoerr.NewReadError("failed to calculate total", err) } page := mgclients.ClientsPage{ @@ -182,21 +182,21 @@ func (repo clientRepo) UpdateRole(ctx context.Context, client mgclients.Client) dbc, err := pgclients.ToDBClient(client) if err != nil { - return mgclients.Client{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + return mgclients.Client{}, repoerr.NewTypeError("failed to convert to db format", err) } row, err := repo.DB.NamedQueryContext(ctx, query, dbc) if err != nil { - return mgclients.Client{}, postgres.HandleError(err, repoerr.ErrUpdateEntity) + return mgclients.Client{}, postgres.HandleRepoError("failed write update", err) } defer row.Close() if ok := row.Next(); !ok { - return mgclients.Client{}, errors.Wrap(repoerr.ErrNotFound, row.Err()) + return mgclients.Client{}, repoerr.NewNotFoundError("user to update not found", row.Err()) } dbc = pgclients.DBClient{} if err := row.StructScan(&dbc); err != nil { - return mgclients.Client{}, err + return mgclients.Client{}, repoerr.NewTypeError("failed to convert from db response", err) } return pgclients.ToClient(dbc) diff --git a/users/postgres/clients_test.go b/users/postgres/clients_test.go index dd86ca887e..a563edc394 100644 --- a/users/postgres/clients_test.go +++ b/users/postgres/clients_test.go @@ -27,6 +27,10 @@ var ( namesgen = namegenerator.NewGenerator() ) +const ( + createRepoErrMsg = "failed to create user" +) + func TestClientsSave(t *testing.T) { t.Cleanup(func() { _, err := db.Exec("DELETE FROM clients") @@ -70,7 +74,7 @@ func TestClientsSave(t *testing.T) { Metadata: mgclients.Metadata{}, Status: mgclients.EnabledStatus, }, - err: repoerr.ErrConflict, + err: repoerr.NewWriteError(createRepoErrMsg, nil), }, { desc: "add client with duplicate client name", @@ -84,7 +88,7 @@ func TestClientsSave(t *testing.T) { Metadata: mgclients.Metadata{}, Status: mgclients.EnabledStatus, }, - err: repoerr.ErrConflict, + err: repoerr.NewWriteError(createRepoErrMsg, nil), }, { desc: "add client with invalid client id", @@ -98,7 +102,7 @@ func TestClientsSave(t *testing.T) { Metadata: mgclients.Metadata{}, Status: mgclients.EnabledStatus, }, - err: errors.ErrMalformedEntity, + err: repoerr.NewTypeError(createRepoErrMsg, nil), }, { desc: "add client with invalid client name", @@ -112,7 +116,7 @@ func TestClientsSave(t *testing.T) { Metadata: mgclients.Metadata{}, Status: mgclients.EnabledStatus, }, - err: errors.ErrMalformedEntity, + err: repoerr.NewTypeError(createRepoErrMsg, nil), }, { desc: "add client with invalid client identity", @@ -126,7 +130,7 @@ func TestClientsSave(t *testing.T) { Metadata: mgclients.Metadata{}, Status: mgclients.EnabledStatus, }, - err: errors.ErrMalformedEntity, + err: repoerr.NewTypeError(createRepoErrMsg, nil), }, { desc: "add client with a missing client name", @@ -177,13 +181,13 @@ func TestClientsSave(t *testing.T) { "key": make(chan int), }, }, - err: errors.ErrMalformedEntity, + err: repoerr.NewTypeError("failed to marshal JSON metadata", nil), }, } for _, tc := range cases { rClient, err := repo.Save(context.Background(), tc.client) - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.True(t, errors.ContainsType(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if err == nil { rClient.Credentials.Secret = tc.client.Credentials.Secret assert.Equal(t, tc.client, rClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.client, rClient)) @@ -231,7 +235,7 @@ func TestIsPlatformAdmin(t *testing.T) { Status: mgclients.EnabledStatus, Role: mgclients.UserRole, }, - err: repoerr.ErrNotFound, + err: repoerr.NewNotFoundError("super admin not found", nil), }, } @@ -277,12 +281,12 @@ func TestRetrieveByID(t *testing.T) { { desc: "retrieve non-existing client", clientID: invalidName, - err: repoerr.ErrNotFound, + err: repoerr.NewNotFoundError("user not found", nil), }, { desc: "retrieve with empty client id", clientID: "", - err: repoerr.ErrNotFound, + err: repoerr.NewNotFoundError("user not found", nil), }, } @@ -738,7 +742,7 @@ func TestUpdateRole(t *testing.T) { desc: "update role with invalid client id", client: mgclients.Client{ID: invalidName}, newRole: mgclients.AdminRole, - err: repoerr.ErrNotFound, + err: repoerr.NewNotFoundError("user to update not found", nil), }, } diff --git a/users/service.go b/users/service.go index 9f1b1d1f5b..ed828cfb53 100644 --- a/users/service.go +++ b/users/service.go @@ -11,19 +11,12 @@ import ( "github.com/absmach/magistrala/auth" mgclients "github.com/absmach/magistrala/pkg/clients" "github.com/absmach/magistrala/pkg/errors" + autherr "github.com/absmach/magistrala/pkg/errors/auth" repoerr "github.com/absmach/magistrala/pkg/errors/repository" - svcerr "github.com/absmach/magistrala/pkg/errors/service" "github.com/absmach/magistrala/users/postgres" "golang.org/x/sync/errgroup" ) -var ( - errIssueToken = errors.New("failed to issue token") - errFailedPermissionsList = errors.New("failed to list permissions") - errRecoveryToken = errors.New("failed to generate password recovery token") - errLoginDisableUser = errors.New("failed to login in disabled user") -) - type service struct { clients postgres.Repository idProvider magistrala.IDProvider @@ -64,16 +57,16 @@ func (svc service) RegisterClient(ctx context.Context, token string, cli mgclien if cli.Credentials.Secret != "" { hash, err := svc.hasher.Hash(cli.Credentials.Secret) if err != nil { - return mgclients.Client{}, errors.Wrap(svcerr.ErrMalformedEntity, err) + return mgclients.Client{}, newMalformedError(HashErr, err) } cli.Credentials.Secret = hash } if cli.Status != mgclients.DisabledStatus && cli.Status != mgclients.EnabledStatus { - return mgclients.Client{}, errors.Wrap(svcerr.ErrMalformedEntity, svcerr.ErrInvalidStatus) + return mgclients.Client{}, newMalformedError(StatusErr, nil) } if cli.Role != mgclients.UserRole && cli.Role != mgclients.AdminRole { - return mgclients.Client{}, errors.Wrap(svcerr.ErrMalformedEntity, svcerr.ErrInvalidRole) + return mgclients.Client{}, newMalformedError(RoleErr, nil) } cli.ID = clientID cli.CreatedAt = time.Now() @@ -84,13 +77,13 @@ func (svc service) RegisterClient(ctx context.Context, token string, cli mgclien defer func() { if err != nil { if errRollback := svc.addClientPolicyRollback(ctx, cli.ID, cli.Role); errRollback != nil { - err = errors.Wrap(errors.Wrap(errors.ErrRollbackTx, errRollback), err) + err = newInternalError(err.Error(), errRollback) } } }() client, err := svc.clients.Save(ctx, cli) if err != nil { - return mgclients.Client{}, errors.Wrap(svcerr.ErrCreateEntity, err) + return mgclients.Client{}, newInternalError(UserAddErr, err) } return client, nil } @@ -98,10 +91,10 @@ func (svc service) RegisterClient(ctx context.Context, token string, cli mgclien func (svc service) IssueToken(ctx context.Context, identity, secret, domainID string) (*magistrala.Token, error) { dbUser, err := svc.clients.RetrieveByIdentity(ctx, identity) if err != nil { - return &magistrala.Token{}, errors.Wrap(svcerr.ErrAuthentication, err) + return &magistrala.Token{}, autherr.NewAuthNError(AuthNErr, err) } if err := svc.hasher.Compare(secret, dbUser.Credentials.Secret); err != nil { - return &magistrala.Token{}, errors.Wrap(svcerr.ErrLogin, err) + return &magistrala.Token{}, autherr.NewAuthNError(InvalidPasswordErr, err) } var d string @@ -111,7 +104,7 @@ func (svc service) IssueToken(ctx context.Context, identity, secret, domainID st token, err := svc.auth.Issue(ctx, &magistrala.IssueReq{UserId: dbUser.ID, DomainId: &d, Type: uint32(auth.AccessKey)}) if err != nil { - return &magistrala.Token{}, errors.Wrap(errIssueToken, err) + return &magistrala.Token{}, newInternalError(IssueTokenErr, err) } return token, err @@ -130,10 +123,10 @@ func (svc service) RefreshToken(ctx context.Context, refreshToken, domainID stri dbUser, err := svc.clients.RetrieveByID(ctx, tokenUserID) if err != nil { - return &magistrala.Token{}, errors.Wrap(svcerr.ErrAuthentication, err) + return &magistrala.Token{}, autherr.NewAuthNError(AuthNErr, err) } if dbUser.Status == mgclients.DisabledStatus { - return &magistrala.Token{}, errors.Wrap(svcerr.ErrAuthentication, errLoginDisableUser) + return &magistrala.Token{}, autherr.NewAuthZError(DisabledUserRefreshErr, nil) } return svc.auth.Refresh(ctx, &magistrala.RefreshReq{RefreshToken: refreshToken, DomainId: &d}) @@ -147,7 +140,7 @@ func (svc service) ViewClient(ctx context.Context, token, id string) (mgclients. client, err := svc.clients.RetrieveByID(ctx, id) if err != nil { - return mgclients.Client{}, errors.Wrap(svcerr.ErrViewEntity, err) + return mgclients.Client{}, newNotFoundError(ViewErr, err) } if tokenUserID != id { @@ -168,7 +161,7 @@ func (svc service) ViewProfile(ctx context.Context, token string) (mgclients.Cli } client, err := svc.clients.RetrieveByID(ctx, id) if err != nil { - return mgclients.Client{}, errors.Wrap(svcerr.ErrViewEntity, err) + return mgclients.Client{}, newNotFoundError(ViewErr, err) } client.Credentials.Secret = "" @@ -183,7 +176,7 @@ func (svc service) ListClients(ctx context.Context, token string, pm mgclients.P if err := svc.checkSuperAdmin(ctx, userID); err == nil { pg, err := svc.clients.RetrieveAll(ctx, pm) if err != nil { - return mgclients.ClientsPage{}, errors.Wrap(svcerr.ErrViewEntity, err) + return mgclients.ClientsPage{}, newNotFoundError(ViewErr, err) } return pg, err } @@ -198,7 +191,7 @@ func (svc service) ListClients(ctx context.Context, token string, pm mgclients.P } pg, err := svc.clients.RetrieveAll(ctx, p) if err != nil { - return mgclients.ClientsPage{}, errors.Wrap(svcerr.ErrViewEntity, err) + return mgclients.ClientsPage{}, newNotFoundError(ViewErr, err) } for i, c := range pg.Clients { @@ -230,7 +223,7 @@ func (svc service) UpdateClient(ctx context.Context, token string, cli mgclients client, err = svc.clients.Update(ctx, client) if err != nil { - return mgclients.Client{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + return mgclients.Client{}, newInternalError(UpdateErr, err) } return client, nil } @@ -255,7 +248,7 @@ func (svc service) UpdateClientTags(ctx context.Context, token string, cli mgcli } client, err = svc.clients.UpdateTags(ctx, client) if err != nil { - return mgclients.Client{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + return mgclients.Client{}, newInternalError(UpdateTagsErr, err) } return client, nil @@ -283,7 +276,7 @@ func (svc service) UpdateClientIdentity(ctx context.Context, token, clientID, id } cli, err = svc.clients.UpdateIdentity(ctx, cli) if err != nil { - return mgclients.Client{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + return mgclients.Client{}, newInternalError(UpdateIdentityErr, err) } return cli, nil } @@ -291,7 +284,7 @@ func (svc service) UpdateClientIdentity(ctx context.Context, token, clientID, id func (svc service) GenerateResetToken(ctx context.Context, email, host string) error { client, err := svc.clients.RetrieveByIdentity(ctx, email) if err != nil { - return errors.Wrap(svcerr.ErrViewEntity, err) + return newNotFoundError(ViewErr, err) } issueReq := &magistrala.IssueReq{ UserId: client.ID, @@ -299,7 +292,7 @@ func (svc service) GenerateResetToken(ctx context.Context, email, host string) e } token, err := svc.auth.Issue(ctx, issueReq) if err != nil { - return errors.Wrap(errRecoveryToken, err) + return autherr.NewAuthZError(ResetTokenErr, err) } return svc.SendPasswordReset(ctx, host, email, client.Name, token.AccessToken) @@ -312,12 +305,11 @@ func (svc service) ResetSecret(ctx context.Context, resetToken, secret string) e } c, err := svc.clients.RetrieveByID(ctx, id) if err != nil { - return errors.Wrap(svcerr.ErrViewEntity, err) + return newNotFoundError(ViewErr, err) } - secret, err = svc.hasher.Hash(secret) if err != nil { - return errors.Wrap(svcerr.ErrMalformedEntity, err) + return newMalformedError(HashErr, err) } c = mgclients.Client{ ID: c.ID, @@ -329,7 +321,7 @@ func (svc service) ResetSecret(ctx context.Context, resetToken, secret string) e UpdatedBy: id, } if _, err := svc.clients.UpdateSecret(ctx, c); err != nil { - return errors.Wrap(svcerr.ErrAuthorization, err) + return newInternalError(UpdateSecretErr, err) } return nil } @@ -341,14 +333,14 @@ func (svc service) UpdateClientSecret(ctx context.Context, token, oldSecret, new } dbClient, err := svc.clients.RetrieveByID(ctx, id) if err != nil { - return mgclients.Client{}, errors.Wrap(svcerr.ErrViewEntity, err) + return mgclients.Client{}, newNotFoundError(ViewErr, err) } if _, err := svc.IssueToken(ctx, dbClient.Credentials.Identity, oldSecret, ""); err != nil { return mgclients.Client{}, err } newSecret, err = svc.hasher.Hash(newSecret) if err != nil { - return mgclients.Client{}, errors.Wrap(svcerr.ErrMalformedEntity, err) + return mgclients.Client{}, newMalformedError(HashErr, err) } dbClient.Credentials.Secret = newSecret dbClient.UpdatedAt = time.Now() @@ -356,7 +348,7 @@ func (svc service) UpdateClientSecret(ctx context.Context, token, oldSecret, new dbClient, err = svc.clients.UpdateSecret(ctx, dbClient) if err != nil { - return mgclients.Client{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + return mgclients.Client{}, newInternalError(UpdateSecretErr, err) } return dbClient, nil @@ -395,9 +387,9 @@ func (svc service) UpdateClientRole(ctx context.Context, token string, cli mgcli if err != nil { // If failed to update role in DB, then revert back to platform admin policy in spicedb if errRollback := svc.updateClientPolicy(ctx, cli.ID, mgclients.UserRole); errRollback != nil { - return mgclients.Client{}, errors.Wrap(errRollback, err) + return mgclients.Client{}, newInternalError(err.Error(), errRollback) } - return mgclients.Client{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + return mgclients.Client{}, newInternalError(UpdateRoleErr, err) } return client, nil } @@ -410,7 +402,7 @@ func (svc service) EnableClient(ctx context.Context, token, id string) (mgclient } client, err := svc.changeClientStatus(ctx, token, client) if err != nil { - return mgclients.Client{}, errors.Wrap(mgclients.ErrEnableClient, err) + return mgclients.Client{}, err } return client, nil @@ -442,7 +434,7 @@ func (svc service) changeClientStatus(ctx context.Context, token string, client } dbClient, err := svc.clients.RetrieveByID(ctx, client.ID) if err != nil { - return mgclients.Client{}, errors.Wrap(svcerr.ErrViewEntity, err) + return mgclients.Client{}, newNotFoundError(ViewErr, err) } if dbClient.Status == client.Status { return mgclients.Client{}, errors.ErrStatusAlreadyAssigned @@ -451,7 +443,7 @@ func (svc service) changeClientStatus(ctx context.Context, token string, client client, err = svc.clients.ChangeStatus(ctx, client) if err != nil { - return mgclients.Client{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + return mgclients.Client{}, newInternalError(UpdateStatus+client.Status.String(), err) } return client, nil } @@ -492,7 +484,7 @@ func (svc service) ListMembers(ctx context.Context, token, objectKind, objectID } if _, err := svc.authorize(ctx, auth.UserType, auth.TokenKind, token, authzPerm, objectType, objectID); err != nil { - return mgclients.MembersPage{}, errors.Wrap(svcerr.ErrAuthorization, err) + return mgclients.MembersPage{}, autherr.NewAuthZError(AuthZErr, err) } duids, err := svc.auth.ListAllSubjects(ctx, &magistrala.ListSubjectsReq{ SubjectType: auth.UserType, @@ -501,7 +493,7 @@ func (svc service) ListMembers(ctx context.Context, token, objectKind, objectID ObjectType: objectType, }) if err != nil { - return mgclients.MembersPage{}, errors.Wrap(svcerr.ErrNotFound, err) + return mgclients.MembersPage{}, newInternalError(ListMembersErr, err) } if len(duids.Policies) == 0 { return mgclients.MembersPage{ @@ -519,7 +511,7 @@ func (svc service) ListMembers(ctx context.Context, token, objectKind, objectID cp, err := svc.clients.RetrieveAll(ctx, pm) if err != nil { - return mgclients.MembersPage{}, errors.Wrap(svcerr.ErrViewEntity, err) + return mgclients.MembersPage{}, newNotFoundError(ViewErr, err) } for i, c := range cp.Clients { @@ -557,7 +549,7 @@ func (svc service) retrieveObjectUsersPermissions(ctx context.Context, domainID, userID := auth.EncodeDomainUserID(domainID, client.ID) permissions, err := svc.listObjectUserPermission(ctx, userID, objectType, objectID) if err != nil { - return errors.Wrap(svcerr.ErrAuthorization, err) + return autherr.NewAuthZError(AuthZErr, err) } client.Permissions = permissions return nil @@ -571,7 +563,7 @@ func (svc service) listObjectUserPermission(ctx context.Context, userID, objectT ObjectType: objectType, }) if err != nil { - return []string{}, errors.Wrap(errFailedPermissionsList, err) + return []string{}, newInternalError(PermissionsListErr, err) } return lp.GetPermissions(), nil } @@ -579,9 +571,9 @@ func (svc service) listObjectUserPermission(ctx context.Context, userID, objectT func (svc *service) checkSuperAdmin(ctx context.Context, adminID string) error { if _, err := svc.authorize(ctx, auth.UserType, auth.UsersKind, adminID, auth.AdminPermission, auth.PlatformType, auth.MagistralaObject); err != nil { if err := svc.clients.CheckSuperAdmin(ctx, adminID); err != nil { - return errors.Wrap(svcerr.ErrAuthorization, err) + return autherr.NewAuthZError(AuthZErr, err) } - return errors.Wrap(svcerr.ErrAuthorization, err) + return autherr.NewAuthZError(AuthZErr, err) } return nil @@ -590,7 +582,7 @@ func (svc *service) checkSuperAdmin(ctx context.Context, adminID string) error { func (svc service) identify(ctx context.Context, token string) (*magistrala.IdentityRes, error) { res, err := svc.auth.Identify(ctx, &magistrala.IdentityReq{Token: token}) if err != nil { - return &magistrala.IdentityRes{}, errors.Wrap(svcerr.ErrAuthentication, err) + return &magistrala.IdentityRes{}, autherr.NewAuthNError(AuthNErr, err) } return res, nil } @@ -606,11 +598,11 @@ func (svc *service) authorize(ctx context.Context, subjType, subjKind, subj, per } res, err := svc.auth.Authorize(ctx, req) if err != nil { - return "", errors.Wrap(svcerr.ErrAuthorization, err) + return "", autherr.NewAuthZError(AuthZErr, err) } if !res.GetAuthorized() { - return "", svcerr.ErrAuthorization + return "", errAuthZ } return res.GetId(), nil } @@ -618,7 +610,7 @@ func (svc *service) authorize(ctx context.Context, subjType, subjKind, subj, per func (svc service) OAuthCallback(ctx context.Context, client mgclients.Client) (*magistrala.Token, error) { rclient, err := svc.clients.RetrieveByIdentity(ctx, client.Credentials.Identity) if err != nil { - switch errors.Contains(err, repoerr.ErrNotFound) { + switch errors.ContainsType(err, repoerr.ErrNotFound) { case true: rclient, err = svc.RegisterClient(ctx, "", client) if err != nil { @@ -644,11 +636,11 @@ func (svc service) OAuthCallback(ctx context.Context, client mgclients.Client) ( } func (svc service) Identify(ctx context.Context, token string) (string, error) { - user, err := svc.auth.Identify(ctx, &magistrala.IdentityReq{Token: token}) + res, err := svc.identify(ctx, token) if err != nil { - return "", errors.Wrap(svcerr.ErrAuthentication, err) + return "", err } - return user.GetUserId(), nil + return res.GetUserId(), nil } func (svc service) addClientPolicy(ctx context.Context, userID string, role mgclients.Role) error { @@ -673,10 +665,10 @@ func (svc service) addClientPolicy(ctx context.Context, userID string, role mgcl } resp, err := svc.auth.AddPolicies(ctx, &policies) if err != nil { - return errors.Wrap(svcerr.ErrAddPolicies, err) + return newInternalError(AddPoliciesErr, err) } if !resp.Added { - return svcerr.ErrAuthorization + return newInternalError(AddPoliciesErr, errNotAddedStatus) } return nil } @@ -703,10 +695,10 @@ func (svc service) addClientPolicyRollback(ctx context.Context, userID string, r } resp, err := svc.auth.DeletePolicies(ctx, &policies) if err != nil { - return errors.Wrap(svcerr.ErrDeletePolicies, err) + return newInternalError(RollbackErr, err) } if !resp.Deleted { - return svcerr.ErrAuthorization + return newInternalError(RollbackErr, errNotDeletedStatus) } return nil } @@ -722,10 +714,10 @@ func (svc service) updateClientPolicy(ctx context.Context, userID string, role m Object: auth.MagistralaObject, }) if err != nil { - return errors.Wrap(svcerr.ErrAddPolicies, err) + return newInternalError(AddPoliciesErr, err) } if !resp.Added { - return svcerr.ErrAuthorization + return newInternalError(AddPoliciesErr, errNotAddedStatus) } return nil case mgclients.UserRole: @@ -739,10 +731,10 @@ func (svc service) updateClientPolicy(ctx context.Context, userID string, role m Object: auth.MagistralaObject, }) if err != nil { - return errors.Wrap(svcerr.ErrDeletePolicies, err) + return newInternalError(DeletePoliciesErr, err) } if !resp.Deleted { - return svcerr.ErrAuthorization + return newInternalError(DeletePoliciesErr, errNotDeletedStatus) } return nil }