diff --git a/.github/workflows/check-generated-files.yml b/.github/workflows/check-generated-files.yml index 99730f9546..49c4c00d95 100644 --- a/.github/workflows/check-generated-files.yml +++ b/.github/workflows/check-generated-files.yml @@ -50,6 +50,7 @@ jobs: - "auth/keys.go" - "auth/policies.go" - "pkg/events/events.go" + - "provision/service.go" - name: Set up protoc if: steps.changes.outputs.proto == 'true' @@ -109,6 +110,7 @@ jobs: mv ./auth/mocks/keys.go ./auth/mocks/keys.go.tmp mv ./pkg/events/mocks/publisher.go ./pkg/events/mocks/publisher.go.tmp mv ./pkg/events/mocks/subscriber.go ./pkg/events/mocks/subscriber.go.tmp + mv ./provision/mocks/service.go ./provision/mocks/service.go.tmp make mocks @@ -137,3 +139,4 @@ jobs: check_mock_changes ./auth/mocks/keys.go "Auth Keys ./auth/mocks/keys.go" check_mock_changes ./pkg/events/mocks/publisher.go "ES Publisher ./pkg/events/mocks/publisher.go" check_mock_changes ./pkg/events/mocks/subscriber.go "EE Subscriber ./pkg/events/mocks/subscriber.go" + check_mock_changes ./provision/mocks/service.go "Provision Service ./provision/mocks/service.go" diff --git a/internal/api/common.go b/internal/api/common.go index 90b086c4b1..a3c92ac118 100644 --- a/internal/api/common.go +++ b/internal/api/common.go @@ -114,6 +114,7 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) { errors.Contains(err, apiutil.ErrMissingMemberType), errors.Contains(err, apiutil.ErrMissingMemberKind), errors.Contains(err, apiutil.ErrLimitSize), + errors.Contains(err, apiutil.ErrBearerKey), errors.Contains(err, apiutil.ErrNameSize): w.WriteHeader(http.StatusBadRequest) case errors.Contains(err, svcerr.ErrAuthentication), diff --git a/provision/api/endpoint.go b/provision/api/endpoint.go index 2b9dc7b889..495ea6f3f9 100644 --- a/provision/api/endpoint.go +++ b/provision/api/endpoint.go @@ -18,11 +18,10 @@ func doProvision(svc provision.Service) endpoint.Endpoint { if err := req.validate(); err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) } - token := req.token - res, err := svc.Provision(token, req.Name, req.ExternalID, req.ExternalKey) + res, err := svc.Provision(req.token, req.Name, req.ExternalID, req.ExternalKey) if err != nil { - return provisionRes{Error: err.Error()}, nil + return nil, err } provisionResponse := provisionRes{ @@ -44,6 +43,7 @@ func getMapping(svc provision.Service) endpoint.Endpoint { if err := req.validate(); err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) } + return svc.Mapping(req.token) } } diff --git a/provision/api/endpoint_test.go b/provision/api/endpoint_test.go new file mode 100644 index 0000000000..be9cbd1b76 --- /dev/null +++ b/provision/api/endpoint_test.go @@ -0,0 +1,210 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api_test + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/absmach/magistrala/internal/apiutil" + "github.com/absmach/magistrala/internal/testsutil" + mglog "github.com/absmach/magistrala/logger" + "github.com/absmach/magistrala/pkg/errors" + "github.com/absmach/magistrala/provision" + "github.com/absmach/magistrala/provision/api" + "github.com/absmach/magistrala/provision/mocks" + "github.com/stretchr/testify/assert" +) + +var ( + validToken = "valid" + validContenType = "application/json" + validID = testsutil.GenerateUUID(&testing.T{}) +) + +type testRequest struct { + client *http.Client + method string + url string + token string + contentType string + body io.Reader +} + +func (tr testRequest) make() (*http.Response, error) { + req, err := http.NewRequest(tr.method, tr.url, tr.body) + if err != nil { + return nil, err + } + + if tr.token != "" { + req.Header.Set("Authorization", apiutil.BearerPrefix+tr.token) + } + + if tr.contentType != "" { + req.Header.Set("Content-Type", tr.contentType) + } + + return tr.client.Do(req) +} + +func newProvisionServer() (*httptest.Server, *mocks.Service) { + svc := new(mocks.Service) + + logger := mglog.NewMock() + mux := api.MakeHandler(svc, logger, "test") + return httptest.NewServer(mux), svc +} + +func TestProvision(t *testing.T) { + is, svc := newProvisionServer() + + cases := []struct { + desc string + token string + data string + contentType string + status int + svcErr error + }{ + { + desc: "valid request", + token: validToken, + data: fmt.Sprintf(`{"name": "test", "external_id": "%s", "external_key": "%s"}`, validID, validID), + status: http.StatusCreated, + contentType: validContenType, + svcErr: nil, + }, + { + desc: "request with empty external id", + token: validToken, + data: fmt.Sprintf(`{"name": "test", "external_key": "%s"}`, validID), + status: http.StatusBadRequest, + contentType: validContenType, + svcErr: nil, + }, + { + desc: "request with empty external key", + token: validToken, + data: fmt.Sprintf(`{"name": "test", "external_id": "%s"}`, validID), + status: http.StatusBadRequest, + contentType: validContenType, + svcErr: nil, + }, + { + desc: "empty token", + token: "", + data: fmt.Sprintf(`{"name": "test", "external_id": "%s", "external_key": "%s"}`, validID, validID), + status: http.StatusCreated, + contentType: validContenType, + svcErr: nil, + }, + { + desc: "invalid content type", + token: validToken, + data: fmt.Sprintf(`{"name": "test", "external_id": "%s", "external_key": "%s"}`, validID, validID), + status: http.StatusUnsupportedMediaType, + contentType: "text/plain", + svcErr: nil, + }, + { + desc: "invalid request", + token: validToken, + data: `data`, + status: http.StatusBadRequest, + contentType: validContenType, + svcErr: nil, + }, + { + desc: "service error", + token: validToken, + data: fmt.Sprintf(`{"name": "test", "external_id": "%s", "external_key": "%s"}`, validID, validID), + status: http.StatusForbidden, + contentType: validContenType, + svcErr: errors.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repocall := svc.On("Provision", tc.token, "test", validID, validID).Return(provision.Result{}, tc.svcErr) + req := testRequest{ + client: is.Client(), + method: http.MethodPost, + url: is.URL + "/mapping", + token: tc.token, + contentType: tc.contentType, + body: strings.NewReader(tc.data), + } + + resp, err := req.make() + assert.Nil(t, err, tc.desc) + assert.Equal(t, tc.status, resp.StatusCode, tc.desc) + repocall.Unset() + }) + } +} + +func TestMapping(t *testing.T) { + is, svc := newProvisionServer() + + cases := []struct { + desc string + token string + contentType string + status int + svcErr error + }{ + { + desc: "valid request", + token: validToken, + status: http.StatusOK, + contentType: validContenType, + svcErr: nil, + }, + { + desc: "empty token", + token: "", + status: http.StatusUnauthorized, + contentType: validContenType, + svcErr: nil, + }, + { + desc: "invalid content type", + token: validToken, + status: http.StatusUnsupportedMediaType, + contentType: "text/plain", + svcErr: nil, + }, + { + desc: "service error", + token: validToken, + status: http.StatusForbidden, + contentType: validContenType, + svcErr: errors.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repocall := svc.On("Mapping", tc.token).Return(map[string]interface{}{}, tc.svcErr) + req := testRequest{ + client: is.Client(), + method: http.MethodGet, + url: is.URL + "/mapping", + token: tc.token, + contentType: tc.contentType, + } + + resp, err := req.make() + assert.Nil(t, err, tc.desc) + assert.Equal(t, tc.status, resp.StatusCode, tc.desc) + repocall.Unset() + }) + } +} diff --git a/provision/api/requests_test.go b/provision/api/requests_test.go index d1dfed0ff8..68f5158c24 100644 --- a/provision/api/requests_test.go +++ b/provision/api/requests_test.go @@ -8,33 +8,79 @@ import ( "testing" "github.com/absmach/magistrala/internal/apiutil" + "github.com/absmach/magistrala/internal/testsutil" "github.com/absmach/magistrala/pkg/errors" "github.com/stretchr/testify/assert" ) -func TestValidate(t *testing.T) { - cases := map[string]struct { - ExternalID string - ExternalKey string - err error +func TestProvisioReq(t *testing.T) { + cases := []struct { + desc string + req provisionReq + err error }{ - "mac address for device": { - ExternalID: "11:22:33:44:55:66", - ExternalKey: "key12345678", - err: nil, + { + desc: "valid request", + req: provisionReq{ + token: "token", + Name: "name", + ExternalID: testsutil.GenerateUUID(t), + ExternalKey: testsutil.GenerateUUID(t), + }, + err: nil, }, - "external id for device empty": { + { + desc: "empty external id", + req: provisionReq{ + token: "token", + Name: "name", + ExternalID: "", + ExternalKey: testsutil.GenerateUUID(t), + }, err: apiutil.ErrMissingID, }, + { + desc: "empty external key", + req: provisionReq{ + token: "token", + Name: "name", + ExternalID: testsutil.GenerateUUID(t), + ExternalKey: "", + }, + err: apiutil.ErrBearerKey, + }, + } + + for _, tc := range cases { + err := tc.req.validate() + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected `%v` got `%v`", tc.desc, tc.err, err)) } +} - for desc, tc := range cases { - req := provisionReq{ - ExternalID: tc.ExternalID, - ExternalKey: tc.ExternalKey, - } +func TestMappingReq(t *testing.T) { + cases := []struct { + desc string + req mappingReq + err error + }{ + { + desc: "valid request", + req: mappingReq{ + token: "token", + }, + err: nil, + }, + { + desc: "empty token", + req: mappingReq{ + token: "", + }, + err: apiutil.ErrBearerToken, + }, + } - err := req.validate() - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected `%v` got `%v`", desc, tc.err, err)) + for _, tc := range cases { + err := tc.req.validate() + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected `%v` got `%v`", tc.desc, tc.err, err)) } } diff --git a/provision/api/responses.go b/provision/api/responses.go index e8550986c0..4acd358d72 100644 --- a/provision/api/responses.go +++ b/provision/api/responses.go @@ -6,9 +6,12 @@ package api import ( "net/http" + "github.com/absmach/magistrala" sdk "github.com/absmach/magistrala/pkg/sdk/go" ) +var _ magistrala.Response = (*provisionRes)(nil) + type provisionRes struct { Things []sdk.Thing `json:"things"` Channels []sdk.Channel `json:"channels"` @@ -16,7 +19,6 @@ type provisionRes struct { ClientKey map[string]string `json:"client_key,omitempty"` CACert string `json:"ca_cert,omitempty"` Whitelisted map[string]bool `json:"whitelisted,omitempty"` - Error string `json:"error,omitempty"` } func (res provisionRes) Code() int { diff --git a/provision/api/transport.go b/provision/api/transport.go index 89c95172bd..6ed3a07284 100644 --- a/provision/api/transport.go +++ b/provision/api/transport.go @@ -9,6 +9,7 @@ import ( "net/http" "github.com/absmach/magistrala" + "github.com/absmach/magistrala/internal/api" "github.com/absmach/magistrala/internal/apiutil" mglog "github.com/absmach/magistrala/logger" "github.com/absmach/magistrala/pkg/errors" @@ -25,24 +26,25 @@ const ( // MakeHandler returns a HTTP handler for API endpoints. func MakeHandler(svc provision.Service, logger mglog.Logger, instanceID string) http.Handler { opts := []kithttp.ServerOption{ - kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, encodeError)), + kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), } r := chi.NewRouter() - r.Post("/mapping", kithttp.NewServer( - doProvision(svc), - decodeProvisionRequest, - encodeResponse, - opts..., - ).ServeHTTP) - - r.Get("/mapping", kithttp.NewServer( - getMapping(svc), - decodeMappingRequest, - encodeResponse, - opts..., - ).ServeHTTP) + r.Route("/mapping", func(r chi.Router) { + r.Post("/", kithttp.NewServer( + doProvision(svc), + decodeProvisionRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + r.Get("/", kithttp.NewServer( + getMapping(svc), + decodeMappingRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + }) r.Handle("/metrics", promhttp.Handler()) r.Get("/health", magistrala.Health("provision", instanceID)) @@ -50,24 +52,6 @@ func MakeHandler(svc provision.Service, logger mglog.Logger, instanceID string) return r } -func encodeResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { - w.Header().Set("Content-Type", contentType) - - if ar, ok := response.(magistrala.Response); ok { - for k, v := range ar.Headers() { - w.Header().Set(k, v) - } - - w.WriteHeader(ar.Code()) - - if ar.Empty() { - return nil - } - } - - return json.NewEncoder(w).Encode(response) -} - func decodeProvisionRequest(_ context.Context, r *http.Request) (interface{}, error) { if r.Header.Get("Content-Type") != contentType { return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType) @@ -75,7 +59,7 @@ func decodeProvisionRequest(_ context.Context, r *http.Request) (interface{}, er req := provisionReq{token: apiutil.ExtractBearerToken(r)} if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity)) } return req, nil @@ -90,37 +74,3 @@ func decodeMappingRequest(_ context.Context, r *http.Request) (interface{}, erro return req, nil } - -func encodeError(_ context.Context, err error, w http.ResponseWriter) { - var wrapper error - if errors.Contains(err, apiutil.ErrValidation) { - wrapper, err = errors.Unwrap(err) - } - - switch { - case errors.Contains(err, errors.ErrAuthentication), - errors.Contains(err, apiutil.ErrBearerToken): - w.WriteHeader(http.StatusUnauthorized) - case errors.Contains(err, apiutil.ErrUnsupportedContentType): - w.WriteHeader(http.StatusUnsupportedMediaType) - case errors.Contains(err, errors.ErrMalformedEntity), - errors.Contains(err, apiutil.ErrMissingID), - errors.Contains(err, apiutil.ErrBearerKey): - w.WriteHeader(http.StatusBadRequest) - case errors.Contains(err, errors.ErrConflict): - w.WriteHeader(http.StatusConflict) - default: - w.WriteHeader(http.StatusInternalServerError) - } - - if wrapper != nil { - err = errors.Wrap(wrapper, err) - } - - if errorVal, ok := err.(errors.Error); ok { - w.Header().Set("Content-Type", contentType) - if err := json.NewEncoder(w).Encode(errorVal); err != nil { - w.WriteHeader(http.StatusInternalServerError) - } - } -} diff --git a/provision/mocks/service.go b/provision/mocks/service.go new file mode 100644 index 0000000000..24a21fab5c --- /dev/null +++ b/provision/mocks/service.go @@ -0,0 +1,122 @@ +// Code generated by mockery v2.38.0. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import ( + provision "github.com/absmach/magistrala/provision" + mock "github.com/stretchr/testify/mock" +) + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +// Cert provides a mock function with given fields: token, thingID, duration +func (_m *Service) Cert(token string, thingID string, duration string) (string, string, error) { + ret := _m.Called(token, thingID, duration) + + if len(ret) == 0 { + panic("no return value specified for Cert") + } + + var r0 string + var r1 string + var r2 error + if rf, ok := ret.Get(0).(func(string, string, string) (string, string, error)); ok { + return rf(token, thingID, duration) + } + if rf, ok := ret.Get(0).(func(string, string, string) string); ok { + r0 = rf(token, thingID, duration) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(string, string, string) string); ok { + r1 = rf(token, thingID, duration) + } else { + r1 = ret.Get(1).(string) + } + + if rf, ok := ret.Get(2).(func(string, string, string) error); ok { + r2 = rf(token, thingID, duration) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// Mapping provides a mock function with given fields: token +func (_m *Service) Mapping(token string) (map[string]interface{}, error) { + ret := _m.Called(token) + + if len(ret) == 0 { + panic("no return value specified for Mapping") + } + + var r0 map[string]interface{} + var r1 error + if rf, ok := ret.Get(0).(func(string) (map[string]interface{}, error)); ok { + return rf(token) + } + if rf, ok := ret.Get(0).(func(string) map[string]interface{}); ok { + r0 = rf(token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]interface{}) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(token) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Provision provides a mock function with given fields: token, name, externalID, externalKey +func (_m *Service) Provision(token string, name string, externalID string, externalKey string) (provision.Result, error) { + ret := _m.Called(token, name, externalID, externalKey) + + if len(ret) == 0 { + panic("no return value specified for Provision") + } + + var r0 provision.Result + var r1 error + if rf, ok := ret.Get(0).(func(string, string, string, string) (provision.Result, error)); ok { + return rf(token, name, externalID, externalKey) + } + if rf, ok := ret.Get(0).(func(string, string, string, string) provision.Result); ok { + r0 = rf(token, name, externalID, externalKey) + } else { + r0 = ret.Get(0).(provision.Result) + } + + if rf, ok := ret.Get(1).(func(string, string, string, string) error); ok { + r1 = rf(token, name, externalID, externalKey) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewService(t interface { + mock.TestingT + Cleanup(func()) +}) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/provision/service.go b/provision/service.go index 6b1e35594b..83b7a1c4e1 100644 --- a/provision/service.go +++ b/provision/service.go @@ -46,6 +46,8 @@ var ( var _ Service = (*provisionService)(nil) // Service specifies Provision service API. +// +//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines" type Service interface { // Provision is the only method this API specifies. Depending on the configuration, // the following actions will can be executed: