From b97d93f7e96dee73b688c464b66db90c3dea1c4b Mon Sep 17 00:00:00 2001 From: Marius Goetze Date: Tue, 30 Apr 2024 13:24:38 +0200 Subject: [PATCH 1/3] fix: fix database schema - wrong column name for origin uri, should be snake case - timestamp is mandatory --- pkg/models/notification.go | 16 ++++++++-------- .../migrations/000001_init_schema.up.sql | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pkg/models/notification.go b/pkg/models/notification.go index a3c5e12..558112f 100644 --- a/pkg/models/notification.go +++ b/pkg/models/notification.go @@ -5,12 +5,12 @@ package models type Notification struct { - Id string `json:"id" db:"id" readonly:"true"` - Origin string `json:"origin" db:"origin" binding:"required"` - OriginUri string `json:"originUri,omitempty" db:"origin"` // can be used to provide a link to the origin - Timestamp string `json:"timestamp" db:"timestamp" binding:"required" format:"date-time"` - Title string `json:"title" db:"title" binding:"required"` // can also be seen as the 'type' - Detail string `json:"detail" db:"detail" binding:"required"` - Level string `json:"level" db:"level" binding:"required" enums:"info,warning,error,critical"` - CustomFields map[string]any `json:"customFields,omitempty" db:"custom_fields"` // can contain arbitrary structured information about the notification + Id string `json:"id" readonly:"true"` + Origin string `json:"origin" binding:"required"` + OriginUri string `json:"originUri,omitempty"` // can be used to provide a link to the origin + Timestamp string `json:"timestamp" binding:"required" format:"date-time"` + Title string `json:"title" binding:"required"` // can also be seen as the 'type' + Detail string `json:"detail" binding:"required"` + Level string `json:"level" binding:"required" enums:"info,warning,error,critical"` + CustomFields map[string]any `json:"customFields,omitempty"` // can contain arbitrary structured information about the notification } diff --git a/pkg/repository/migrations/000001_init_schema.up.sql b/pkg/repository/migrations/000001_init_schema.up.sql index e1ec862..735557d 100644 --- a/pkg/repository/migrations/000001_init_schema.up.sql +++ b/pkg/repository/migrations/000001_init_schema.up.sql @@ -5,8 +5,8 @@ COMMENT ON SCHEMA notification_service IS 'Notification Service schema'; CREATE TABLE notification_service.notifications ( "id" UUID DEFAULT gen_random_uuid() PRIMARY KEY, "origin" TEXT NOT NULL, - "originUri" TEXT NOT NULL, - "timestamp" TIMESTAMP, + "origin_uri" TEXT, + "timestamp" TIMESTAMP NOT NULL, "title" TEXT NOT NULL, "detail" TEXT NOT NULL, "level" VARCHAR(255) NOT NULL, From 63c4e950b3a99535a4599d9032bce51c3b93e418 Mon Sep 17 00:00:00 2001 From: Marius Goetze Date: Fri, 10 May 2024 17:33:09 +0200 Subject: [PATCH 2/3] add packages from vi backend ultimately those components should be extracted to the opensight-golang-libraries --- pkg/errs/errors.go | 45 ++++++ pkg/restErrorHandler/rest_error_handler.go | 48 +++++++ .../rest_error_handler_test.go | 102 ++++++++++++++ pkg/web/helper/parameter_parsing.go | 133 ++++++++++++++++++ 4 files changed, 328 insertions(+) create mode 100644 pkg/errs/errors.go create mode 100644 pkg/restErrorHandler/rest_error_handler.go create mode 100644 pkg/restErrorHandler/rest_error_handler_test.go create mode 100644 pkg/web/helper/parameter_parsing.go diff --git a/pkg/errs/errors.go b/pkg/errs/errors.go new file mode 100644 index 0000000..ea9963a --- /dev/null +++ b/pkg/errs/errors.go @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: 2024 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package errs + +import ( + "errors" + "fmt" +) + +// ErrItemNotFound is the error used when looking up an item by ID, e.g. +// OID for VTs, fails because the item cannot be found. +var ErrItemNotFound = errors.New("item not found") + +// embed this error to mark an error as retryable +var ErrRetryable = errors.New("(retryable error)") + +// ErrConflict indicates a conflict. If there are certain fields conflicting which are meaningful to the client, +// set the individual error message for a property via `Errors`, otherwise just set `Message`. +type ErrConflict struct { + Message string + Errors map[string]string // maps property to specific error message +} + +func (e *ErrConflict) Error() string { + message := e.Message + if len(e.Errors) > 0 { + message += fmt.Sprintf(", specific errors: %v", e.Errors) + } + return message +} + +type ErrValidation struct { + Message string + Errors map[string]string // maps property to specific error message +} + +func (e *ErrValidation) Error() string { + message := e.Message + if len(e.Errors) > 0 { + message += fmt.Sprintf(", specific errors: %v", e.Errors) + } + return message +} diff --git a/pkg/restErrorHandler/rest_error_handler.go b/pkg/restErrorHandler/rest_error_handler.go new file mode 100644 index 0000000..f9a629c --- /dev/null +++ b/pkg/restErrorHandler/rest_error_handler.go @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: 2024 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package restErrorHandler + +import ( + "errors" + "net/http" + + "github.com/greenbone/opensight-notification-service/pkg/errs" + + "github.com/greenbone/opensight-golang-libraries/pkg/errorResponses" + + "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" +) + +// ErrorHandler determines the appropriate error response and code from the error type. It relies on the types defined in [errs]. +// The default case is an internal server error hiding the implementation details from the client. In this case a log message is issued containing the error. +// A log message for context can be provided via parameter internalErrorLogMessage. +func ErrorHandler(gc *gin.Context, internalErrorLogMessage string, err error) { + var errConflict *errs.ErrConflict + var errValidation *errs.ErrValidation + switch { + case errors.Is(err, errs.ErrItemNotFound): + gc.JSON(http.StatusNotFound, errorResponses.NewErrorGenericResponse(err.Error())) + case errors.As(err, &errConflict): + gc.JSON(http.StatusConflict, ErrConflictToResponse(*errConflict)) + case errors.As(err, &errValidation): + gc.JSON(http.StatusBadRequest, ErrValidationToResponse(*errValidation)) + default: + log.Err(err).Str("endpoint", gc.Request.Method+" "+gc.Request.URL.Path).Msg(internalErrorLogMessage) + gc.JSON(http.StatusInternalServerError, errorResponses.ErrorInternalResponse) + } +} + +func ErrValidationToResponse(err errs.ErrValidation) errorResponses.ErrorResponse { + return errorResponses.NewErrorValidationResponse(err.Message, "", err.Errors) +} + +func ErrConflictToResponse(err errs.ErrConflict) errorResponses.ErrorResponse { + return errorResponses.ErrorResponse{ + Type: errorResponses.ErrorTypeGeneric, + Title: err.Message, + Errors: err.Errors, + } +} diff --git a/pkg/restErrorHandler/rest_error_handler_test.go b/pkg/restErrorHandler/rest_error_handler_test.go new file mode 100644 index 0000000..94ed0ee --- /dev/null +++ b/pkg/restErrorHandler/rest_error_handler_test.go @@ -0,0 +1,102 @@ +// SPDX-FileCopyrightText: 2024 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package restErrorHandler + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/greenbone/opensight-notification-service/pkg/errs" + "github.com/greenbone/opensight-notification-service/pkg/helper" + + "github.com/greenbone/opensight-golang-libraries/pkg/errorResponses" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +var someValidationError = errs.ErrValidation{ + Message: "some validation error", + Errors: map[string]string{"field1": "issue with field1", "field2": "issue with field2"}, +} +var someConflictError = errs.ErrConflict{Errors: map[string]string{"test": "value already exists", "test2": "value already exists"}} + +func TestErrorHandler(t *testing.T) { + tests := []struct { + name string + err error + wantStatusCode int + wantErrorResponse *errorResponses.ErrorResponse + }{ + { + name: "hide internal errors from rest clients", + err: errors.New("some internal error"), + wantStatusCode: http.StatusInternalServerError, + wantErrorResponse: &errorResponses.ErrorInternalResponse, + }, + { + name: "not found error", + err: fmt.Errorf("wrapped: %w", errs.ErrItemNotFound), + wantStatusCode: http.StatusNotFound, + }, + { + name: "conflict error", + err: fmt.Errorf("wrapped: %w", &someConflictError), + wantStatusCode: http.StatusConflict, + wantErrorResponse: helper.ToPtr(ErrConflictToResponse(someConflictError)), + }, + { + name: "validation error", + err: fmt.Errorf("wrapped: %w", &someValidationError), + wantStatusCode: http.StatusBadRequest, + wantErrorResponse: helper.ToPtr(ErrValidationToResponse(someValidationError)), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotResponse := httptest.NewRecorder() + gc, _ := gin.CreateTestContext(gotResponse) + // pretend a request has happened + gc.Request = httptest.NewRequest(http.MethodGet, "/some/path", nil) + + ErrorHandler(gc, "some specific log message", tt.err) + + // compare status code + assert.Equal(t, tt.wantStatusCode, gotResponse.Code) + + if tt.wantErrorResponse != nil { // compare responses + gotErrorResponse, err := io.ReadAll(gotResponse.Body) + if err != nil { + t.Error("could not read response body: %w", err) + return + } + wantResponseJson, err := json.Marshal(*tt.wantErrorResponse) + if err != nil { + t.Error("could not parse error response to json: %w", err) + } + assert.JSONEq(t, string(wantResponseJson), string(gotErrorResponse)) + } + }) + } +} + +func TestErrConflictToResponse(t *testing.T) { + errConflictResponse := ErrConflictToResponse(someConflictError) + + assert.Equal(t, errorResponses.ErrorTypeGeneric, errConflictResponse.Type) + assert.Equal(t, someConflictError.Errors, errConflictResponse.Errors) +} + +func TestErrValidationToResponse(t *testing.T) { + errValidationResponse := ErrValidationToResponse(someValidationError) + + assert.Equal(t, errorResponses.ErrorTypeValidation, errValidationResponse.Type) + assert.Equal(t, someValidationError.Errors, errValidationResponse.Errors) +} diff --git a/pkg/web/helper/parameter_parsing.go b/pkg/web/helper/parameter_parsing.go new file mode 100644 index 0000000..65a3f2c --- /dev/null +++ b/pkg/web/helper/parameter_parsing.go @@ -0,0 +1,133 @@ +// SPDX-FileCopyrightText: 2024 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package helper + +import ( + "fmt" + "strings" + + "github.com/gin-gonic/gin" + "github.com/greenbone/opensight-golang-libraries/pkg/query" + "github.com/greenbone/opensight-golang-libraries/pkg/query/filter" + "github.com/greenbone/opensight-golang-libraries/pkg/query/paging" + "github.com/greenbone/opensight-golang-libraries/pkg/query/sorting" + "github.com/greenbone/opensight-notification-service/pkg/errs" + + "github.com/rs/zerolog/log" + "github.com/samber/lo" +) + +const ( + DefaultLimit uint64 = 50 +) + +// PrepareResultSelector converts the common query parameters of a List endpoint to a ResultSelector. +// Parameters are verified where possible and defaults are set. +func PrepareResultSelector(gc *gin.Context, filterOptions []filter.RequestOption, allowedSortFields []string, defaults query.ResultSelector) (resultSelector query.ResultSelector, err error) { + resultSelector = query.ResultSelector{} + if err = gc.ShouldBindJSON(&resultSelector); err != nil { + return resultSelector, &errs.ErrValidation{Message: fmt.Sprintf("can't parse body: %v", err)} + } + + //apply defaults + resultSelector = applyDefaults(resultSelector, defaults) + if resultSelector.Sorting != nil { // TODO: implement sorting, and remove this if-branch + resultSelector.Sorting = nil + log.Warn().Msg("sorting not yet supported, dropping sort property from request") + } + + err = validate(resultSelector, filterOptions, allowedSortFields) + if err != nil { + return resultSelector, &errs.ErrValidation{Message: fmt.Sprintf("error validating result selector %v", err)} + } + + return resultSelector, nil +} + +// ResultSelectorDefaults holds default result selectors +func ResultSelectorDefaults() query.ResultSelector { + return query.ResultSelector{ + Paging: &paging.Request{ + PageIndex: 0, + PageSize: int(DefaultLimit), + }, + Sorting: &sorting.Request{ // TODO: implement sorting, then we need a different default per endpoint instead of this dummy + SortColumn: "", + SortDirection: "", + }, + } +} + +func applyDefaults(resultSelector query.ResultSelector, defaults query.ResultSelector) query.ResultSelector { + if resultSelector.Paging == nil { + resultSelector.Paging = defaults.Paging + } + if resultSelector.Sorting == nil { + resultSelector.Sorting = defaults.Sorting + } else { + if resultSelector.Sorting.SortColumn == "" { + resultSelector.Sorting.SortColumn = defaults.Sorting.SortColumn + resultSelector.Sorting.SortDirection = defaults.Sorting.SortDirection + } + + if resultSelector.Sorting.SortDirection == "" { + resultSelector.Sorting.SortDirection = defaults.Sorting.SortDirection + } + } + return resultSelector +} + +func validate(resultSelector query.ResultSelector, filterOptions []filter.RequestOption, allowedSortFields []string) error { + err := filter.ValidateFilter(resultSelector.Filter, filterOptions) + if err != nil { + return err + } + + err = validateSorting(resultSelector.Sorting, allowedSortFields) + if err != nil { + return err + } + + err = validatePaging(resultSelector.Paging) + if err != nil { + return err + } + + return nil +} + +func validateSorting(sortingRequest *sorting.Request, allowedSortFields []string) error { + if sortingRequest == nil { + return nil + } + + err := sorting.ValidateSortingRequest(sortingRequest) + if err != nil { + return err + } + + if !lo.Contains(allowedSortFields, sortingRequest.SortColumn) { + return sorting.NewSortingError(fmt.Sprintf("%s is no valid sort column, possible values: %s", + sortingRequest.SortColumn, strings.Join(allowedSortFields, ", "))) + } + + return nil +} + +func validatePaging(pagingRequest *paging.Request) error { + if pagingRequest == nil { + return nil + } + + if pagingRequest.PageIndex < 0 { + return paging.NewPagingError("%d is no valid page index, it must be >= 0", pagingRequest.PageIndex) + } + + if pagingRequest.PageSize <= 0 { + return paging.NewPagingError("%d is no valid page size, it must be > 0", pagingRequest.PageSize) + } + + return nil +} From 7ae85d4a9d308e70a56a2798baef7212a10f47ee Mon Sep 17 00:00:00 2001 From: Marius Goetze Date: Tue, 30 Apr 2024 13:25:08 +0200 Subject: [PATCH 3/3] implement create and list notification endpoints --- .github/workflows/push.yml | 2 +- .mockery.yaml | 10 + Dockerfile | 8 +- Makefile | 6 + go.mod | 6 + go.sum | 5 + pkg/helper/helper.go | 18 ++ pkg/helper/helper_test.go | 43 ++++ pkg/port/mocks/notification_service.go | 162 +++++++++++++ .../notificationRepository.go | 55 ++++- .../notification_db_models.go | 72 ++++++ .../notificationService.go | 4 +- pkg/web/api.go | 7 + pkg/web/cors_middleware.go | 24 ++ .../notificationController.go | 60 ++++- .../notificationController_test.go | 216 ++++++++++++++++++ pkg/web/testhelper/helper.go | 48 ++++ pkg/web/web.go | 2 +- 18 files changed, 738 insertions(+), 10 deletions(-) create mode 100644 .mockery.yaml create mode 100644 pkg/helper/helper.go create mode 100644 pkg/helper/helper_test.go create mode 100644 pkg/port/mocks/notification_service.go create mode 100644 pkg/repository/notificationrepository/notification_db_models.go create mode 100644 pkg/web/cors_middleware.go create mode 100644 pkg/web/notificationcontroller/notificationController_test.go create mode 100644 pkg/web/testhelper/helper.go diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 23d0844..bd0f3c9 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -18,7 +18,7 @@ jobs: name: OpenSight Notification Service uses: greenbone/workflows/.github/workflows/helm-container-build-push-3rd-gen.yml@main with: - # helm-chart: opensight-notification-service # TODO: reference relevant helm chart as soon as it exists + helm-chart: opensight-notification-service image-url: ${{ github.repository }} image-labels: | org.opencontainers.image.vendor=Greenbone diff --git a/.mockery.yaml b/.mockery.yaml new file mode 100644 index 0000000..5bfd657 --- /dev/null +++ b/.mockery.yaml @@ -0,0 +1,10 @@ +with-expecter: true +dir: "{{.InterfaceDir}}/mocks" +outpkg: "mocks" +mockname: "{{.InterfaceName}}" +filename: "{{.InterfaceNameSnake}}.go" + +packages: + github.com/greenbone/opensight-notification-service/pkg/port: + interfaces: + NotificationService: diff --git a/Dockerfile b/Dockerfile index 6b2a4c7..d8c717b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -21,9 +21,15 @@ COPY pkg/web pkg/web COPY pkg/models pkg/models RUN make api-docs -# copy rest of the source files and build +# copy rest of the source files COPY cmd cmd COPY pkg pkg + +# (re)generate mocks +COPY .mockery.yaml .mockery.yaml +RUN make generate-code + +# test and build RUN make test RUN make build diff --git a/Makefile b/Makefile index 3383b58..61e4cfc 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,7 @@ all: api-docs build test SWAG = github.com/swaggo/swag/cmd/swag@v1.16.2 +MOCKERY = github.com/vektra/mockery/v2@v2.40.2 GOLANGCI-LINT = github.com/golangci/golangci-lint/cmd/golangci-lint@latest .PHONY: lint @@ -10,6 +11,11 @@ lint: .PHONY: install-code-generation-tools install-code-generation-tools: go install $(SWAG) + go install $(MOCKERY) + +.PHONY: generate-code +generate-code: # create mocks + go run $(MOCKERY) .PHONY: api-docs api-docs: diff --git a/go.mod b/go.mod index 763c466..5352a02 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,15 @@ module github.com/greenbone/opensight-notification-service go 1.22.2 require ( + github.com/gin-contrib/cors v1.7.2 github.com/gin-contrib/logger v1.1.2 github.com/gin-gonic/gin v1.10.0 github.com/golang-migrate/migrate/v4 v4.17.1 github.com/greenbone/opensight-golang-libraries v1.3.1 github.com/jmoiron/sqlx v1.4.0 github.com/rs/zerolog v1.32.0 + github.com/samber/lo v1.39.0 + github.com/stretchr/testify v1.9.0 github.com/swaggo/files v1.0.1 github.com/swaggo/gin-swagger v1.6.0 github.com/swaggo/swag v1.16.3 @@ -20,6 +23,7 @@ require ( github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect @@ -51,6 +55,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect @@ -58,6 +63,7 @@ require ( github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/viper v1.18.2 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect diff --git a/go.sum b/go.sum index f8692c0..b3c0e8c 100644 --- a/go.sum +++ b/go.sum @@ -39,6 +39,8 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw= +github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E= github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4= github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk= github.com/gin-contrib/logger v1.1.2 h1:+y8VHqn5zAsAFnW6y/6GF93eXaCPFv/XPdqUCFeBMRg= @@ -163,6 +165,8 @@ github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6ke github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA= +github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= @@ -176,6 +180,7 @@ github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMV github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/pkg/helper/helper.go b/pkg/helper/helper.go new file mode 100644 index 0000000..87b4cc6 --- /dev/null +++ b/pkg/helper/helper.go @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: 2024 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package helper + +func ToPtr[T any](val T) *T { + return &val +} + +// SafeDereference return the value ptr points to. If ptr is nil, it returns the default value if the type instead. +func SafeDereference[T any](ptr *T) T { + var zeroT T + if ptr == nil { + return zeroT + } + return *ptr +} diff --git a/pkg/helper/helper_test.go b/pkg/helper/helper_test.go new file mode 100644 index 0000000..ae47632 --- /dev/null +++ b/pkg/helper/helper_test.go @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: 2024 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package helper + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSafeDereference(t *testing.T) { + type someStruct struct { + foo, bar string + } + + zeroValue := someStruct{} + nonZeroValue := someStruct{foo: "aa", bar: "bb"} + + tests := []struct { + name string + input *someStruct + want someStruct + }{ + { + name: "yields zero value with nil pointer", + input: nil, + want: zeroValue, + }, + { + name: "yields object pointed on for non-nil pointer", + input: &nonZeroValue, + want: nonZeroValue, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SafeDereference(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/port/mocks/notification_service.go b/pkg/port/mocks/notification_service.go new file mode 100644 index 0000000..0161cf1 --- /dev/null +++ b/pkg/port/mocks/notification_service.go @@ -0,0 +1,162 @@ +// Code generated by mockery v2.40.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + models "github.com/greenbone/opensight-notification-service/pkg/models" + mock "github.com/stretchr/testify/mock" + + query "github.com/greenbone/opensight-golang-libraries/pkg/query" +) + +// NotificationService is an autogenerated mock type for the NotificationService type +type NotificationService struct { + mock.Mock +} + +type NotificationService_Expecter struct { + mock *mock.Mock +} + +func (_m *NotificationService) EXPECT() *NotificationService_Expecter { + return &NotificationService_Expecter{mock: &_m.Mock} +} + +// CreateNotification provides a mock function with given fields: ctx, notificationIn +func (_m *NotificationService) CreateNotification(ctx context.Context, notificationIn models.Notification) (models.Notification, error) { + ret := _m.Called(ctx, notificationIn) + + if len(ret) == 0 { + panic("no return value specified for CreateNotification") + } + + var r0 models.Notification + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, models.Notification) (models.Notification, error)); ok { + return rf(ctx, notificationIn) + } + if rf, ok := ret.Get(0).(func(context.Context, models.Notification) models.Notification); ok { + r0 = rf(ctx, notificationIn) + } else { + r0 = ret.Get(0).(models.Notification) + } + + if rf, ok := ret.Get(1).(func(context.Context, models.Notification) error); ok { + r1 = rf(ctx, notificationIn) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NotificationService_CreateNotification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateNotification' +type NotificationService_CreateNotification_Call struct { + *mock.Call +} + +// CreateNotification is a helper method to define mock.On call +// - ctx context.Context +// - notificationIn models.Notification +func (_e *NotificationService_Expecter) CreateNotification(ctx interface{}, notificationIn interface{}) *NotificationService_CreateNotification_Call { + return &NotificationService_CreateNotification_Call{Call: _e.mock.On("CreateNotification", ctx, notificationIn)} +} + +func (_c *NotificationService_CreateNotification_Call) Run(run func(ctx context.Context, notificationIn models.Notification)) *NotificationService_CreateNotification_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(models.Notification)) + }) + return _c +} + +func (_c *NotificationService_CreateNotification_Call) Return(notification models.Notification, err error) *NotificationService_CreateNotification_Call { + _c.Call.Return(notification, err) + return _c +} + +func (_c *NotificationService_CreateNotification_Call) RunAndReturn(run func(context.Context, models.Notification) (models.Notification, error)) *NotificationService_CreateNotification_Call { + _c.Call.Return(run) + return _c +} + +// ListNotifications provides a mock function with given fields: ctx, resultSelector +func (_m *NotificationService) ListNotifications(ctx context.Context, resultSelector query.ResultSelector) ([]models.Notification, uint64, error) { + ret := _m.Called(ctx, resultSelector) + + if len(ret) == 0 { + panic("no return value specified for ListNotifications") + } + + var r0 []models.Notification + var r1 uint64 + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, query.ResultSelector) ([]models.Notification, uint64, error)); ok { + return rf(ctx, resultSelector) + } + if rf, ok := ret.Get(0).(func(context.Context, query.ResultSelector) []models.Notification); ok { + r0 = rf(ctx, resultSelector) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]models.Notification) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, query.ResultSelector) uint64); ok { + r1 = rf(ctx, resultSelector) + } else { + r1 = ret.Get(1).(uint64) + } + + if rf, ok := ret.Get(2).(func(context.Context, query.ResultSelector) error); ok { + r2 = rf(ctx, resultSelector) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// NotificationService_ListNotifications_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListNotifications' +type NotificationService_ListNotifications_Call struct { + *mock.Call +} + +// ListNotifications is a helper method to define mock.On call +// - ctx context.Context +// - resultSelector query.ResultSelector +func (_e *NotificationService_Expecter) ListNotifications(ctx interface{}, resultSelector interface{}) *NotificationService_ListNotifications_Call { + return &NotificationService_ListNotifications_Call{Call: _e.mock.On("ListNotifications", ctx, resultSelector)} +} + +func (_c *NotificationService_ListNotifications_Call) Run(run func(ctx context.Context, resultSelector query.ResultSelector)) *NotificationService_ListNotifications_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(query.ResultSelector)) + }) + return _c +} + +func (_c *NotificationService_ListNotifications_Call) Return(notifications []models.Notification, totalResult uint64, err error) *NotificationService_ListNotifications_Call { + _c.Call.Return(notifications, totalResult, err) + return _c +} + +func (_c *NotificationService_ListNotifications_Call) RunAndReturn(run func(context.Context, query.ResultSelector) ([]models.Notification, uint64, error)) *NotificationService_ListNotifications_Call { + _c.Call.Return(run) + return _c +} + +// NewNotificationService creates a new instance of NotificationService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewNotificationService(t interface { + mock.TestingT + Cleanup(func()) +}) *NotificationService { + mock := &NotificationService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/repository/notificationrepository/notificationRepository.go b/pkg/repository/notificationrepository/notificationRepository.go index 6ca0e2c..f90f2d0 100644 --- a/pkg/repository/notificationrepository/notificationRepository.go +++ b/pkg/repository/notificationrepository/notificationRepository.go @@ -7,6 +7,7 @@ package notificationrepository import ( "context" "errors" + "fmt" "github.com/greenbone/opensight-golang-libraries/pkg/query" "github.com/greenbone/opensight-notification-service/pkg/models" @@ -28,10 +29,60 @@ func NewNotificationRepository(db *sqlx.DB) (port.NotificationRepository, error) return client, nil } -func (r *NotificationRepository) ListNotifications(ctx context.Context, resultSelector query.ResultSelector) (notifications []models.Notification, totalResult uint64, err error) { +func (r *NotificationRepository) ListNotifications(ctx context.Context, resultSelector query.ResultSelector) (notifications []models.Notification, totalResults uint64, err error) { + var rows []notificationRow + query := unfilteredListNotificationsQuery + + if resultSelector.Paging != nil { // TODO: add support for filtering and sorting + limit := resultSelector.Paging.PageSize + offset := resultSelector.Paging.PageIndex * resultSelector.Paging.PageSize + query += fmt.Sprint(` LIMIT `, limit, ` OFFSET `, offset) + } + + err = r.client.SelectContext(ctx, &rows, query) + if err != nil { + err = fmt.Errorf("error getting notifications from database: %w", err) + return + } + + err = r.client.QueryRowxContext(ctx, `SELECT count(*) FROM `+notificationsTable).Scan(&totalResults) + if err != nil { + err = fmt.Errorf("error getting total results: %w", err) + return + } + + notifications = make([]models.Notification, 0, len(rows)) + for _, row := range rows { + notification, err := row.ToNotificationModel() + if err != nil { + return nil, 0, fmt.Errorf("failed to transform notification db entry: %w", err) + } + notifications = append(notifications, notification) + } return } func (r *NotificationRepository) CreateNotification(ctx context.Context, notificationIn models.Notification) (notification models.Notification, err error) { - return + insertRow, err := toNotificationRow(notificationIn) + if err != nil { + return notification, fmt.Errorf("invalid argument for inserting notification into database: %w", err) + } + + createNotificationStatement, err := r.client.PrepareNamedContext(ctx, createNotificationQuery) + if err != nil { + return notification, fmt.Errorf("could not prepare sql statement: %w", err) + } + + var row notificationRow + err = createNotificationStatement.QueryRowxContext(ctx, insertRow).StructScan(&row) + if err != nil { + return notification, fmt.Errorf("could not insert into database: %w", err) + } + + notification, err = row.ToNotificationModel() + if err != nil { + return notification, fmt.Errorf("failed to transform notification db entry to model: %w", err) + } + + return notification, nil } diff --git a/pkg/repository/notificationrepository/notification_db_models.go b/pkg/repository/notificationrepository/notification_db_models.go new file mode 100644 index 0000000..ebd1033 --- /dev/null +++ b/pkg/repository/notificationrepository/notification_db_models.go @@ -0,0 +1,72 @@ +// SPDX-FileCopyrightText: 2024 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package notificationrepository + +import ( + "encoding/json" + + "github.com/greenbone/opensight-notification-service/pkg/helper" + "github.com/greenbone/opensight-notification-service/pkg/models" +) + +const notificationsTable = "notification_service.notifications" + +const createNotificationQuery = `INSERT INTO ` + notificationsTable + ` (origin, origin_uri, timestamp, title, detail, level, custom_fields) VALUES (:origin, :origin_uri, :timestamp, :title, :detail, :level, :custom_fields) RETURNING *` +const unfilteredListNotificationsQuery = `SELECT * FROM ` + notificationsTable + +type notificationRow struct { + Id string `db:"id"` + Origin string `db:"origin"` + OriginUri *string `db:"origin_uri"` + Timestamp string `db:"timestamp"` + Title string `db:"title"` + Detail string `db:"detail"` + Level string `db:"level"` + CustomFields []byte `db:"custom_fields"` +} + +func toNotificationRow(n models.Notification) (notificationRow, error) { + var empty notificationRow + + customFieldsSerialized, err := json.Marshal(n.CustomFields) + if err != nil { + return empty, err // TODO: return validation error ? + } + + notificationRow := notificationRow{ + Id: n.Id, + Origin: n.Origin, + OriginUri: &n.OriginUri, + Timestamp: n.Timestamp, + Title: n.Title, + Detail: n.Detail, + Level: n.Level, + CustomFields: customFieldsSerialized, + } + + return notificationRow, nil +} + +func (n *notificationRow) ToNotificationModel() (models.Notification, error) { + var empty models.Notification + + notification := models.Notification{ + Id: n.Id, + Origin: n.Origin, + OriginUri: helper.SafeDereference(n.OriginUri), + Timestamp: n.Timestamp, + Title: n.Title, + Detail: n.Detail, + Level: n.Level, + // CustomFields is set below + } + + err := json.Unmarshal(n.CustomFields, ¬ification.CustomFields) + if err != nil { + return empty, err + } + + return notification, nil +} diff --git a/pkg/services/notificationservice/notificationService.go b/pkg/services/notificationservice/notificationService.go index eb4cd36..4282ce4 100644 --- a/pkg/services/notificationservice/notificationService.go +++ b/pkg/services/notificationservice/notificationService.go @@ -21,9 +21,9 @@ func NewNotificationService(store port.NotificationRepository) *NotificationServ } func (s *NotificationService) ListNotifications(ctx context.Context, resultSelector query.ResultSelector) (notifications []models.Notification, totalResult uint64, err error) { - return + return s.store.ListNotifications(ctx, resultSelector) } func (s *NotificationService) CreateNotification(ctx context.Context, notificationIn models.Notification) (notification models.Notification, err error) { - return + return s.store.CreateNotification(ctx, notificationIn) } diff --git a/pkg/web/api.go b/pkg/web/api.go index 4f7562c..23f07bf 100644 --- a/pkg/web/api.go +++ b/pkg/web/api.go @@ -23,6 +23,13 @@ import ( // @externalDocs.description OpenAPI // @externalDocs.url https://swagger.io/resources/open-api/ +const ( + // APIVersion is the current version of the web API + APIVersion = "1.0" +) + +const APIVersionKey = "api-version" + func RegisterSwaggerDocsRoute(docsRouter gin.IRouter) { apiDocsHandler := swagger.GetApiDocsHandler(docs.SwaggerInfonotificationservice) docsRouter.GET("/notification-service/*any", apiDocsHandler) diff --git a/pkg/web/cors_middleware.go b/pkg/web/cors_middleware.go new file mode 100644 index 0000000..4bbb800 --- /dev/null +++ b/pkg/web/cors_middleware.go @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: 2024 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package web + +import ( + "time" + + "github.com/gin-contrib/cors" + "github.com/gin-gonic/gin" +) + +func getCors() gin.HandlerFunc { + CORSHandler := cors.New(cors.Config{ + AllowAllOrigins: true, // TODO: should be more restrictive + AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS"}, + AllowHeaders: []string{"Origin", "Content-Length", "Content-Type", "Authorization"}, + AllowCredentials: false, + MaxAge: 12 * time.Hour, + }) + + return CORSHandler +} diff --git a/pkg/web/notificationcontroller/notificationController.go b/pkg/web/notificationcontroller/notificationController.go index 8d51804..2e42bf6 100644 --- a/pkg/web/notificationcontroller/notificationController.go +++ b/pkg/web/notificationcontroller/notificationController.go @@ -5,12 +5,20 @@ package notificationcontroller import ( + "fmt" "net/http" "github.com/gin-gonic/gin" + "github.com/greenbone/opensight-golang-libraries/pkg/query" _ "github.com/greenbone/opensight-golang-libraries/pkg/query" + "github.com/greenbone/opensight-golang-libraries/pkg/query/filter" + "github.com/greenbone/opensight-notification-service/pkg/errs" + "github.com/greenbone/opensight-notification-service/pkg/models" _ "github.com/greenbone/opensight-notification-service/pkg/models" "github.com/greenbone/opensight-notification-service/pkg/port" + "github.com/greenbone/opensight-notification-service/pkg/restErrorHandler" + "github.com/greenbone/opensight-notification-service/pkg/web" + "github.com/greenbone/opensight-notification-service/pkg/web/helper" ) type NotificationController struct { @@ -47,7 +55,21 @@ func (c *NotificationController) registerRoutes(router gin.IRouter) { // @Header all {string} api-version "API version" // @Router /notifications [post] func (c *NotificationController) CreateNotification(gc *gin.Context) { - gc.Status(http.StatusNotImplemented) + gc.Header(web.APIVersionKey, web.APIVersion) + + notification, err := parseAndValidateNotification(gc) + if err != nil { + restErrorHandler.ErrorHandler(gc, "could not get notification", err) + return + } + + notificationNew, err := c.notificationService.CreateNotification(gc, notification) + if err != nil { + restErrorHandler.ErrorHandler(gc, "could not create notification", err) + return + } + + gc.JSON(http.StatusCreated, query.ResponseWithMetadata[models.Notification]{Data: notificationNew}) } // ListNotifications @@ -63,7 +85,24 @@ func (c *NotificationController) CreateNotification(gc *gin.Context) { // @Header all {string} api-version "API version" // @Router /notifications [put] func (c *NotificationController) ListNotifications(gc *gin.Context) { - gc.Status(http.StatusNotImplemented) + gc.Header(web.APIVersionKey, web.APIVersion) + + resultSelector, err := helper.PrepareResultSelector(gc, []filter.RequestOption{}, []string{}, helper.ResultSelectorDefaults()) + if err != nil { + restErrorHandler.ErrorHandler(gc, "could not prepare result selector", err) + return + } + + notifications, totalResults, err := c.notificationService.ListNotifications(gc, resultSelector) + if err != nil { + restErrorHandler.ErrorHandler(gc, "could not list notifications", err) + return + } + + gc.JSON(http.StatusOK, query.ResponseListWithMetadata[models.Notification]{ + Metadata: query.NewMetadata(resultSelector, totalResults), + Data: notifications, + }) } // GetOptions @@ -77,5 +116,20 @@ func (c *NotificationController) ListNotifications(gc *gin.Context) { // @Header all {string} api-version "API version" // @Router /notifications/options [get] func (c *NotificationController) GetOptions(gc *gin.Context) { - gc.Status(http.StatusNotImplemented) + gc.Header(web.APIVersionKey, web.APIVersion) + + // for now we don't support filtering + permittedFilters := []query.FilterOption{} + + gc.JSON(http.StatusOK, query.ResponseWithMetadata[[]query.FilterOption]{Data: permittedFilters}) +} + +func parseAndValidateNotification(gc *gin.Context) (notification models.Notification, err error) { // TODO: refine + var empty models.Notification + err = gc.ShouldBindJSON(¬ification) + if err != nil { + return empty, &errs.ErrValidation{Message: fmt.Sprintf("can't parse body: %v", err)} + } + + return notification, nil } diff --git a/pkg/web/notificationcontroller/notificationController_test.go b/pkg/web/notificationcontroller/notificationController_test.go new file mode 100644 index 0000000..d537a87 --- /dev/null +++ b/pkg/web/notificationcontroller/notificationController_test.go @@ -0,0 +1,216 @@ +// SPDX-FileCopyrightText: 2024 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package notificationcontroller + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/greenbone/opensight-golang-libraries/pkg/query" + "github.com/greenbone/opensight-golang-libraries/pkg/query/filter" + "github.com/greenbone/opensight-golang-libraries/pkg/query/paging" + "github.com/greenbone/opensight-notification-service/pkg/helper" + "github.com/greenbone/opensight-notification-service/pkg/models" + "github.com/greenbone/opensight-notification-service/pkg/port/mocks" + "github.com/greenbone/opensight-notification-service/pkg/web/testhelper" + "github.com/stretchr/testify/mock" +) + +func getNotification() models.Notification { + return models.Notification{ + Id: "57fe22b8-89a4-445f-b6c7-ef9ea724ea48", + Timestamp: time.Time{}.Format(time.RFC3339Nano), + Origin: "Example Task XY", + Title: "Example Task XY failed", + Detail: "Example Task XY failed because ...", + Level: "error", + } +} + +func TestListNotifications(t *testing.T) { + someNotification := getNotification() + + type mockReturn struct { + items []models.Notification + totalResults uint64 + err error + } + type want struct { + serviceCall bool + serviceArg query.ResultSelector + responseCode int + responseParsed query.ResponseListWithMetadata[models.Notification] + } + + resultSelectorWithoutFilter := query.ResultSelector{ + Filter: &filter.Request{Operator: filter.LogicOperatorAnd}, + Paging: &paging.Request{PageSize: 10}, + } + + tests := []struct { + name string + requestBody query.ResultSelector + mockReturn mockReturn + want want + }{ + { + name: "service is called with correct result selector", + requestBody: resultSelectorWithoutFilter, + mockReturn: mockReturn{ + items: []models.Notification{someNotification}, + totalResults: 1, + err: nil, + }, + want: want{ + serviceCall: true, + serviceArg: resultSelectorWithoutFilter, + responseCode: http.StatusOK, + responseParsed: query.ResponseListWithMetadata[models.Notification]{ + Data: []models.Notification{someNotification}, + Metadata: query.NewMetadata(resultSelectorWithoutFilter, 1), + }, + }, + }, + { + name: "return internal server error on service failure", + requestBody: resultSelectorWithoutFilter, + mockReturn: mockReturn{err: errors.New("internal service error")}, + want: want{ + serviceCall: true, + serviceArg: resultSelectorWithoutFilter, + responseCode: http.StatusInternalServerError, + }, + }, + { + name: "return bad request on invalid input", + requestBody: query.ResultSelector{Paging: &paging.Request{PageSize: -1}}, // invalid page size + want: want{ + serviceCall: false, + responseCode: http.StatusBadRequest, + }, + }, + } + + requestUrl := "/notifications" + + gin.SetMode(gin.TestMode) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockNotificationService := mocks.NewNotificationService(t) + + // Create a new engine for testing + engine := gin.Default() + // constructor registers the routes + _ = NewNotificationController(&engine.RouterGroup, mockNotificationService) + + if tt.want.serviceCall { + mockNotificationService.EXPECT().ListNotifications(mock.Anything, tt.want.serviceArg). + Return(tt.mockReturn.items, tt.mockReturn.totalResults, tt.mockReturn.err). + Once() + } + + req, err := testhelper.NewJSONRequest(http.MethodPut, requestUrl, tt.requestBody) + if err != nil { + t.Error("could not build request", err) + return + } + + resp := httptest.NewRecorder() + engine.ServeHTTP(resp, req) + + testhelper.VerifyResponseWithMetadata(t, tt.want.responseCode, tt.want.responseParsed, resp) + }) + } +} + +func TestCreateNotification(t *testing.T) { + someNotification := getNotification() + someNotification.Id = "to be ignored" + + wantNotification := getNotification() + wantNotification.Id = "new id" + + type mockServiceReturn struct { + item models.Notification + err error + } + type want struct { + notificationServiceArg *models.Notification + responseCode int + responseParsed query.ResponseWithMetadata[models.Notification] + } + + tests := []struct { + name string + notificationToCreate models.Notification + mockServiceReturn mockServiceReturn + want want + }{ + { + name: "services are called with the correct parameters (read only fields don't affect outcome)", + notificationToCreate: someNotification, + mockServiceReturn: mockServiceReturn{item: wantNotification}, + want: want{ + notificationServiceArg: helper.ToPtr(someNotification), + responseCode: http.StatusCreated, + responseParsed: query.ResponseWithMetadata[models.Notification]{Data: wantNotification}, + }, + }, + { + name: "return internal server error on service failure", + notificationToCreate: someNotification, + mockServiceReturn: mockServiceReturn{item: models.Notification{}, err: errors.New("some internal error")}, + want: want{ + notificationServiceArg: helper.ToPtr(someNotification), + responseCode: http.StatusInternalServerError, + }, + }, + { + name: "don't create a notification if validation fails", + notificationToCreate: models.Notification{}, // invalid: mandatory parameters not set + want: want{ + responseCode: http.StatusBadRequest, + }, + }, + } + + httpMethod := http.MethodPost + requestUrl := "/notifications" + + gin.SetMode(gin.TestMode) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockNotificationService := mocks.NewNotificationService(t) + + // Create a new engine for testing + engine := gin.Default() + // constructor registers the routes + _ = NewNotificationController(&engine.RouterGroup, mockNotificationService) + + if tt.want.notificationServiceArg != nil { + mockNotificationService.EXPECT().CreateNotification(mock.Anything, *tt.want.notificationServiceArg). + Return(tt.mockServiceReturn.item, tt.mockServiceReturn.err). + Once() + } + + req, err := testhelper.NewJSONRequest(httpMethod, requestUrl, tt.notificationToCreate) + if err != nil { + t.Error("could not build request", err) + return + } + + resp := httptest.NewRecorder() + engine.ServeHTTP(resp, req) + + testhelper.VerifyResponseWithMetadata(t, tt.want.responseCode, tt.want.responseParsed, resp) + }) + } +} diff --git a/pkg/web/testhelper/helper.go b/pkg/web/testhelper/helper.go new file mode 100644 index 0000000..8ccea40 --- /dev/null +++ b/pkg/web/testhelper/helper.go @@ -0,0 +1,48 @@ +package testhelper + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/greenbone/opensight-golang-libraries/pkg/errorResponses" + "github.com/stretchr/testify/assert" +) + +func VerifyResponseWithMetadata[T any]( + t *testing.T, + wantResponseCode int, wantResponseParsed T, + gotResponse *httptest.ResponseRecorder) { + + assert.Equal(t, wantResponseCode, gotResponse.Code) + + if wantResponseCode >= 200 && wantResponseCode < 300 { + var gotBodyParsed T + err := json.NewDecoder(gotResponse.Body).Decode(&gotBodyParsed) + if err != nil { + t.Error("response is not valid json: %w", err) + return + } + assert.Equal(t, wantResponseParsed, gotBodyParsed) + } else if wantResponseCode == http.StatusInternalServerError { + var gotBodyParsed errorResponses.ErrorResponse + err := json.NewDecoder(gotResponse.Body).Decode(&gotBodyParsed) + if err != nil { + t.Error("response is not valid json: %w", err) + return + } + assert.Equal(t, errorResponses.ErrorInternalResponse, gotBodyParsed) + } +} + +// NewJSONRequest wraps [http.NewRequest] and sets the passed struct as body +func NewJSONRequest(method, url string, bodyAsStruct any) (*http.Request, error) { + body, err := json.Marshal(bodyAsStruct) + if err != nil { + return nil, fmt.Errorf("could not parse struct to json: %w", err) + } + return http.NewRequest(method, url, bytes.NewReader(body)) +} diff --git a/pkg/web/web.go b/pkg/web/web.go index 7c96462..474a683 100644 --- a/pkg/web/web.go +++ b/pkg/web/web.go @@ -11,6 +11,6 @@ import ( func NewWebEngine() *gin.Engine { ginWebEngine := gin.New() - ginWebEngine.Use(logger.SetLogger(), gin.Recovery()) + ginWebEngine.Use(logger.SetLogger(), gin.Recovery(), getCors()) return ginWebEngine }