diff --git a/.mockery.yaml b/.mockery.yaml index 76888c1cc..04ecf7655 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -60,3 +60,6 @@ packages: TemplateDao: config: filename: "templates_mock.go" + ModuleStreamsDao: + config: + filename: "modules_streams_mock.go" diff --git a/api/docs.go b/api/docs.go index a2c3854d1..67b241319 100644 --- a/api/docs.go +++ b/api/docs.go @@ -2287,6 +2287,68 @@ const docTemplate = `{ } } }, + "/snapshots/module_streams/search": { + "post": { + "description": "List modules and their streams for snapshots", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "module_streams" + ], + "summary": "List modules and their streams for snapshots", + "operationId": "searchSnapshotModuleStreams", + "parameters": [ + { + "description": "request body", + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/api.SearchSnapshotModuleStreamsRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/api.SearchModuleStreams" + } + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/errors.ErrorResponse" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "$ref": "#/definitions/errors.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/errors.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/errors.ErrorResponse" + } + } + } + } + }, "/snapshots/package_groups/names": { "post": { "description": "This enables users to search for package groups in a given list of snapshots.", @@ -4582,6 +4644,22 @@ const docTemplate = `{ } } }, + "api.SearchModuleStreams": { + "type": "object", + "properties": { + "module_name": { + "description": "Module name", + "type": "string" + }, + "streams": { + "description": "A list of stream related information for the module", + "type": "array", + "items": { + "$ref": "#/definitions/api.Stream" + } + } + } + }, "api.SearchPackageGroupResponse": { "type": "object", "properties": { @@ -4619,6 +4697,37 @@ const docTemplate = `{ } } }, + "api.SearchSnapshotModuleStreamsRequest": { + "type": "object", + "required": [ + "rpm_names", + "uuids" + ], + "properties": { + "rpm_names": { + "description": "List of rpm names to restrict returned modules", + "type": "array", + "items": { + "type": "string" + } + }, + "search": { + "description": "Search string to search module names", + "type": "string" + }, + "sort_by": { + "description": "SortBy sets the sort order of the result", + "type": "string" + }, + "uuids": { + "description": "List of snapshot UUIDs to search", + "type": "array", + "items": { + "type": "string" + } + } + } + }, "api.SnapshotCollectionResponse": { "type": "object", "properties": { @@ -4845,6 +4954,45 @@ const docTemplate = `{ } } }, + "api.Stream": { + "type": "object", + "properties": { + "arch": { + "description": "The Architecture of the rpm", + "type": "string" + }, + "context": { + "description": "Context of the module", + "type": "string" + }, + "description": { + "description": "Module description", + "type": "string" + }, + "name": { + "description": "Name of the module", + "type": "string" + }, + "profiles": { + "description": "Module profile data", + "type": "object", + "additionalProperties": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "stream": { + "description": "Module stream version", + "type": "string" + }, + "version": { + "description": "The version of the rpm", + "type": "string" + } + } + }, "api.TaskInfoCollectionResponse": { "type": "object", "properties": { diff --git a/api/openapi.json b/api/openapi.json index a5c44302d..6e7cbfbf5 100644 --- a/api/openapi.json +++ b/api/openapi.json @@ -1039,6 +1039,22 @@ }, "type": "object" }, + "api.SearchModuleStreams": { + "properties": { + "module_name": { + "description": "Module name", + "type": "string" + }, + "streams": { + "description": "A list of stream related information for the module", + "items": { + "$ref": "#/components/schemas/api.Stream" + }, + "type": "array" + } + }, + "type": "object" + }, "api.SearchPackageGroupResponse": { "properties": { "description": { @@ -1076,6 +1092,37 @@ }, "type": "object" }, + "api.SearchSnapshotModuleStreamsRequest": { + "properties": { + "rpm_names": { + "description": "List of rpm names to restrict returned modules", + "items": { + "type": "string" + }, + "type": "array" + }, + "search": { + "description": "Search string to search module names", + "type": "string" + }, + "sort_by": { + "description": "SortBy sets the sort order of the result", + "type": "string" + }, + "uuids": { + "description": "List of snapshot UUIDs to search", + "items": { + "type": "string" + }, + "type": "array" + } + }, + "required": [ + "rpm_names", + "uuids" + ], + "type": "object" + }, "api.SnapshotCollectionResponse": { "properties": { "data": { @@ -1295,6 +1342,45 @@ }, "type": "object" }, + "api.Stream": { + "properties": { + "arch": { + "description": "The Architecture of the rpm", + "type": "string" + }, + "context": { + "description": "Context of the module", + "type": "string" + }, + "description": { + "description": "Module description", + "type": "string" + }, + "name": { + "description": "Name of the module", + "type": "string" + }, + "profiles": { + "additionalProperties": { + "items": { + "type": "string" + }, + "type": "array" + }, + "description": "Module profile data", + "type": "object" + }, + "stream": { + "description": "Module stream version", + "type": "string" + }, + "version": { + "description": "The version of the rpm", + "type": "string" + } + }, + "type": "object" + }, "api.TaskInfoCollectionResponse": { "properties": { "data": { @@ -4667,6 +4753,83 @@ ] } }, + "/snapshots/module_streams/search": { + "post": { + "description": "List modules and their streams for snapshots", + "operationId": "searchSnapshotModuleStreams", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/api.SearchSnapshotModuleStreamsRequest" + } + } + }, + "description": "request body", + "required": true, + "x-originalParamName": "body" + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/api.SearchModuleStreams" + }, + "type": "array" + } + } + }, + "description": "OK" + }, + "400": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/errors.ErrorResponse" + } + } + }, + "description": "Bad Request" + }, + "401": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/errors.ErrorResponse" + } + } + }, + "description": "Unauthorized" + }, + "404": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/errors.ErrorResponse" + } + } + }, + "description": "Not Found" + }, + "500": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/errors.ErrorResponse" + } + } + }, + "description": "Internal Server Error" + } + }, + "summary": "List modules and their streams for snapshots", + "tags": [ + "module_streams" + ] + } + }, "/snapshots/package_groups/names": { "post": { "description": "This enables users to search for package groups in a given list of snapshots.", diff --git a/pkg/api/module_streams.go b/pkg/api/module_streams.go new file mode 100644 index 000000000..fcea1c116 --- /dev/null +++ b/pkg/api/module_streams.go @@ -0,0 +1,23 @@ +package api + +type SearchSnapshotModuleStreamsRequest struct { + UUIDs []string `json:"uuids" validate:"required"` // List of snapshot UUIDs to search + RpmNames []string `json:"rpm_names" validate:"required"` // List of rpm names to restrict returned modules + SortBy string `json:"sort_by"` // SortBy sets the sort order of the result + Search string `json:"search"` // Search string to search module names +} + +type Stream struct { + Name string `json:"name"` // Name of the module + Stream string `json:"stream"` // Module stream version + Context string `json:"context"` // Context of the module + Arch string `json:"arch"` // The Architecture of the rpm + Version string `json:"version"` // The version of the rpm + Description string `json:"description"` // Module description + Profiles map[string][]string `json:"profiles"` // Module profile data +} + +type SearchModuleStreams struct { + ModuleName string `json:"module_name"` // Module name + Streams []Stream `json:"streams"` // A list of stream related information for the module +} diff --git a/pkg/dao/interfaces.go b/pkg/dao/interfaces.go index 879d6eb43..69ccc007d 100644 --- a/pkg/dao/interfaces.go +++ b/pkg/dao/interfaces.go @@ -25,6 +25,7 @@ type DaoRegistry struct { Environment EnvironmentDao Template TemplateDao Uploads UploadDao + ModuleStreams ModuleStreamsDao } func GetDaoRegistry(db *gorm.DB) *DaoRegistry { @@ -37,8 +38,9 @@ func GetDaoRegistry(db *gorm.DB) *DaoRegistry { Rpm: &rpmDaoImpl{ db: db, }, - Repository: repositoryDaoImpl{db: db}, - Metrics: metricsDaoImpl{db: db}, + ModuleStreams: &moduleStreamsImpl{db: db}, + Repository: repositoryDaoImpl{db: db}, + Metrics: metricsDaoImpl{db: db}, Snapshot: &snapshotDaoImpl{ db: db, pulpClient: pulp_client.GetPulpClientWithDomain(""), @@ -80,6 +82,10 @@ type RepositoryConfigDao interface { BulkImport(ctx context.Context, reposToImport []api.RepositoryRequest) ([]api.RepositoryImportResponse, []error) } +type ModuleStreamsDao interface { + SearchSnapshotModuleStreams(ctx context.Context, orgID string, request api.SearchSnapshotModuleStreamsRequest) ([]api.SearchModuleStreams, error) +} + type RpmDao interface { List(ctx context.Context, orgID string, uuidRepo string, limit int, offset int, search string, sortBy string) (api.RepositoryRpmCollectionResponse, int64, error) Search(ctx context.Context, orgID string, request api.ContentUnitSearchRequest) ([]api.SearchRpmResponse, error) diff --git a/pkg/dao/module_streams.go b/pkg/dao/module_streams.go new file mode 100644 index 000000000..898c13f60 --- /dev/null +++ b/pkg/dao/module_streams.go @@ -0,0 +1,97 @@ +package dao + +import ( + "context" + "fmt" + + "github.com/content-services/content-sources-backend/pkg/api" + "github.com/content-services/content-sources-backend/pkg/config" + ce "github.com/content-services/content-sources-backend/pkg/errors" + "github.com/content-services/tang/pkg/tangy" + "gorm.io/gorm" +) + +type moduleStreamsImpl struct { + db *gorm.DB +} + +func GetModuleStreamsDao(db *gorm.DB) ModuleStreamsDao { + // Return DAO instance + return &moduleStreamsImpl{ + db: db, + } +} + +func (r *moduleStreamsImpl) SearchSnapshotModuleStreams(ctx context.Context, orgID string, request api.SearchSnapshotModuleStreamsRequest) ([]api.SearchModuleStreams, error) { + if orgID == "" { + return []api.SearchModuleStreams{}, fmt.Errorf("orgID can not be an empty string") + } + + if request.RpmNames == nil { + request.RpmNames = []string{} + } + + if len(request.UUIDs) == 0 { + return []api.SearchModuleStreams{}, &ce.DaoError{ + BadValidation: true, + Message: "must contain at least 1 snapshot UUID", + } + } + + response := []api.SearchModuleStreams{} + + // Check that snapshot uuids exist + uuidsValid, uuid := checkForValidSnapshotUuids(ctx, request.UUIDs, r.db) + if !uuidsValid { + return []api.SearchModuleStreams{}, &ce.DaoError{ + NotFound: true, + Message: "Could not find snapshot with UUID: " + uuid, + } + } + + pulpHrefs := []string{} + res := readableSnapshots(r.db.WithContext(ctx), orgID).Where("snapshots.UUID in ?", UuidifyStrings(request.UUIDs)).Pluck("version_href", &pulpHrefs) + if res.Error != nil { + return []api.SearchModuleStreams{}, fmt.Errorf("failed to query the db for snapshots: %w", res.Error) + } + if config.Tang == nil { + return []api.SearchModuleStreams{}, fmt.Errorf("no tang configuration present") + } + + if len(pulpHrefs) == 0 { + return []api.SearchModuleStreams{}, nil + } + + pkgs, err := (*config.Tang).RpmRepositoryVersionModuleStreamsList(ctx, pulpHrefs, + tangy.ModuleStreamListFilters{RpmNames: request.RpmNames, Search: request.Search}, request.SortBy) + + if err != nil { + return []api.SearchModuleStreams{}, fmt.Errorf("error querying module streams in snapshots: %w", err) + } + + mappedModuleStreams := map[string][]api.Stream{} + + for _, pkg := range pkgs { + if mappedModuleStreams[pkg.Name] == nil { + mappedModuleStreams[pkg.Name] = []api.Stream{} + } + mappedModuleStreams[pkg.Name] = append(mappedModuleStreams[pkg.Name], api.Stream{ + Name: pkg.Name, + Stream: pkg.Stream, + Context: pkg.Context, + Arch: pkg.Arch, + Version: pkg.Version, + Description: pkg.Description, + Profiles: pkg.Profiles, + }) + } + + for key, moduleStream := range mappedModuleStreams { + response = append(response, api.SearchModuleStreams{ + ModuleName: key, + Streams: moduleStream, + }) + } + + return response, nil +} diff --git a/pkg/dao/module_streams_test.go b/pkg/dao/module_streams_test.go new file mode 100644 index 000000000..8e250a33a --- /dev/null +++ b/pkg/dao/module_streams_test.go @@ -0,0 +1,113 @@ +package dao + +import ( + "context" + "testing" + + "github.com/content-services/content-sources-backend/pkg/api" + "github.com/content-services/content-sources-backend/pkg/config" + "github.com/content-services/content-sources-backend/pkg/models" + "github.com/content-services/content-sources-backend/pkg/seeds" + "github.com/content-services/tang/pkg/tangy" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type ModuleStreamSuite struct { + *DaoSuite + repoConfig *models.RepositoryConfiguration + repo *models.Repository + repoPrivate *models.Repository +} + +func (s *ModuleStreamSuite) SetupTest() { + s.DaoSuite.SetupTest() + + repo := repoPublicTest.DeepCopy() + if err := s.tx.Create(repo).Error; err != nil { + s.FailNow("Preparing Repository record: %w", err) + } + s.repo = repo + + repoPrivate := repoPrivateTest.DeepCopy() + if err := s.tx.Create(repoPrivate).Error; err != nil { + s.FailNow("Preparing private Repository record: %w", err) + } + s.repoPrivate = repoPrivate + + repoConfig := repoConfigTest1.DeepCopy() + repoConfig.RepositoryUUID = repo.Base.UUID + if err := s.tx.Create(repoConfig).Error; err != nil { + s.FailNow("Preparing RepositoryConfiguration record: %w", err) + } + s.repoConfig = repoConfig +} + +func TestModuleStreamSuite(t *testing.T) { + m := DaoSuite{} + r := ModuleStreamSuite{DaoSuite: &m} + suite.Run(t, &r) +} + +func (s *RpmSuite) TestSearchModulesForSnapshots() { + orgId := seeds.RandomOrgId() + mTangy, origTangy := mockTangy(s.T()) + defer func() { config.Tang = origTangy }() + ctx := context.Background() + + hrefs := []string{"some_pulp_version_href"} + stream1 := tangy.ModuleStreams{ + Name: "Foodidly", + // add more + } + expected := []tangy.ModuleStreams{stream1} + + // Create a repo config, and snapshot, update its version_href to expected href + _, err := seeds.SeedRepositoryConfigurations(s.tx, 1, seeds.SeedOptions{ + OrgID: orgId, + BatchSize: 0, + }) + require.NoError(s.T(), err) + repoConfig := models.RepositoryConfiguration{} + res := s.tx.Where("org_id = ?", orgId).First(&repoConfig) + require.NoError(s.T(), res.Error) + snaps, err := seeds.SeedSnapshots(s.tx, repoConfig.UUID, 1) + require.NoError(s.T(), err) + res = s.tx.Model(&models.Snapshot{}).Where("repository_configuration_uuid = ?", repoConfig.UUID).Update("version_href", hrefs[0]) + require.NoError(s.T(), res.Error) + // pulpHrefs, request.Search, *request.Limit) + mTangy.On("RpmRepositoryVersionModuleStreamsList", ctx, hrefs, tangy.ModuleStreamListFilters{Search: "Foo", RpmNames: []string{}}, "").Return(expected, nil) + //ctx context.Context, hrefs []string, rpmNames []string, search string, pageOpts PageOption + dao := GetModuleStreamsDao(s.tx) + + resp, err := dao.SearchSnapshotModuleStreams(ctx, orgId, api.SearchSnapshotModuleStreamsRequest{ + UUIDs: []string{snaps[0].UUID}, + RpmNames: []string(nil), + Search: "Foo", + }) + + require.NoError(s.T(), err) + + assert.Equal(s.T(), + []api.SearchModuleStreams{{ModuleName: expected[0].Name, Streams: []api.Stream{{Name: stream1.Name}}}}, + resp, + ) + + // ensure error returned for invalid snapshot uuid + _, err = dao.SearchSnapshotModuleStreams(ctx, orgId, api.SearchSnapshotModuleStreamsRequest{ + UUIDs: []string{"blerg!"}, + Search: "Foo", + }) + + assert.Error(s.T(), err) + + // ensure error returned for no uuids + _, err = dao.SearchSnapshotModuleStreams(ctx, orgId, api.SearchSnapshotModuleStreamsRequest{ + UUIDs: []string{}, + RpmNames: []string{}, + Search: "Foo", + }) + + assert.Error(s.T(), err) +} diff --git a/pkg/dao/modules_streams_mock.go b/pkg/dao/modules_streams_mock.go new file mode 100644 index 000000000..4ac3d3f18 --- /dev/null +++ b/pkg/dao/modules_streams_mock.go @@ -0,0 +1,60 @@ +// Code generated by mockery. DO NOT EDIT. + +package dao + +import ( + context "context" + + api "github.com/content-services/content-sources-backend/pkg/api" + + mock "github.com/stretchr/testify/mock" +) + +// MockModuleStreamsDao is an autogenerated mock type for the ModuleStreamsDao type +type MockModuleStreamsDao struct { + mock.Mock +} + +// SearchSnapshotModuleStreams provides a mock function with given fields: ctx, orgID, request +func (_m *MockModuleStreamsDao) SearchSnapshotModuleStreams(ctx context.Context, orgID string, request api.SearchSnapshotModuleStreamsRequest) ([]api.SearchModuleStreams, error) { + ret := _m.Called(ctx, orgID, request) + + if len(ret) == 0 { + panic("no return value specified for SearchSnapshotModuleStreams") + } + + var r0 []api.SearchModuleStreams + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, api.SearchSnapshotModuleStreamsRequest) ([]api.SearchModuleStreams, error)); ok { + return rf(ctx, orgID, request) + } + if rf, ok := ret.Get(0).(func(context.Context, string, api.SearchSnapshotModuleStreamsRequest) []api.SearchModuleStreams); ok { + r0 = rf(ctx, orgID, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]api.SearchModuleStreams) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, api.SearchSnapshotModuleStreamsRequest) error); ok { + r1 = rf(ctx, orgID, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewMockModuleStreamsDao creates a new instance of MockModuleStreamsDao. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockModuleStreamsDao(t interface { + mock.TestingT + Cleanup(func()) +}) *MockModuleStreamsDao { + mock := &MockModuleStreamsDao{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/dao/registry_mock.go b/pkg/dao/registry_mock.go index faa8a4502..2d1559588 100644 --- a/pkg/dao/registry_mock.go +++ b/pkg/dao/registry_mock.go @@ -16,6 +16,7 @@ type MockDaoRegistry struct { PackageGroup MockPackageGroupDao Environment MockEnvironmentDao Template MockTemplateDao + ModuleStreams MockModuleStreamsDao } func (m *MockDaoRegistry) ToDaoRegistry() *DaoRegistry { @@ -31,6 +32,7 @@ func (m *MockDaoRegistry) ToDaoRegistry() *DaoRegistry { PackageGroup: &m.PackageGroup, Environment: &m.Environment, Template: &m.Template, + ModuleStreams: &m.ModuleStreams, } return &r } @@ -48,6 +50,7 @@ func GetMockDaoRegistry(t *testing.T) *MockDaoRegistry { PackageGroup: *NewMockPackageGroupDao(t), Environment: *NewMockEnvironmentDao(t), Template: *NewMockTemplateDao(t), + ModuleStreams: *NewMockModuleStreamsDao(t), } return ® } diff --git a/pkg/handler/api.go b/pkg/handler/api.go index d3101240d..9a9d1dc08 100644 --- a/pkg/handler/api.go +++ b/pkg/handler/api.go @@ -87,6 +87,7 @@ func RegisterRoutes(ctx context.Context, engine *echo.Echo) { RegisterTemplateRoutes(group, daoReg, &taskClient) RegisterPulpRoutes(group, daoReg) RegisterCandlepinRoutes(group, &cpClient, &ch) + RegisterModuleStreamsRoutes(group, daoReg) } data, err := json.MarshalIndent(engine.Routes(), "", " ") diff --git a/pkg/handler/module_streams.go b/pkg/handler/module_streams.go new file mode 100644 index 000000000..9a1c621da --- /dev/null +++ b/pkg/handler/module_streams.go @@ -0,0 +1,55 @@ +package handler + +import ( + "net/http" + + "github.com/content-services/content-sources-backend/pkg/api" + "github.com/content-services/content-sources-backend/pkg/dao" + ce "github.com/content-services/content-sources-backend/pkg/errors" + "github.com/content-services/content-sources-backend/pkg/rbac" + "github.com/labstack/echo/v4" +) + +type ModuleStreamsHandler struct { + Dao dao.DaoRegistry +} + +func RegisterModuleStreamsRoutes(engine *echo.Group, rDao *dao.DaoRegistry) { + rh := ModuleStreamsHandler{ + Dao: *rDao, + } + + addRepoRoute(engine, http.MethodPost, "/snapshots/module_streams/search", rh.searchSnapshotModuleStreams, rbac.RbacVerbRead) +} + +// searchSnapshotModuleStreams godoc +// @Summary List modules and their streams for snapshots +// @ID searchSnapshotModuleStreams +// @Description List modules and their streams for snapshots +// @Tags module_streams +// @Accept json +// @Produce json +// @Param body body api.SearchSnapshotModuleStreamsRequest true "request body" +// @Success 200 {object} []api.SearchModuleStreams +// @Failure 400 {object} ce.ErrorResponse +// @Failure 401 {object} ce.ErrorResponse +// @Failure 404 {object} ce.ErrorResponse +// @Failure 500 {object} ce.ErrorResponse +// @Router /snapshots/module_streams/search [post] +func (rh *ModuleStreamsHandler) searchSnapshotModuleStreams(c echo.Context) error { + _, orgId := getAccountIdOrgId(c) + + dataInput := api.SearchSnapshotModuleStreamsRequest{} + + if err := c.Bind(&dataInput); err != nil { + return ce.NewErrorResponse(http.StatusBadRequest, "Error binding parameters", err.Error()) + } + + apiResponse, err := rh.Dao.ModuleStreams.SearchSnapshotModuleStreams(c.Request().Context(), orgId, dataInput) + + if err != nil { + return ce.NewErrorResponse(ce.HttpCodeForDaoError(err), "Error searching modules streams", err.Error()) + } + + return c.JSON(200, apiResponse) +} diff --git a/pkg/handler/module_streams_test.go b/pkg/handler/module_streams_test.go new file mode 100644 index 000000000..de8495e7c --- /dev/null +++ b/pkg/handler/module_streams_test.go @@ -0,0 +1,166 @@ +package handler + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/content-services/content-sources-backend/pkg/api" + "github.com/content-services/content-sources-backend/pkg/config" + "github.com/content-services/content-sources-backend/pkg/dao" + "github.com/content-services/content-sources-backend/pkg/middleware" + test_handler "github.com/content-services/content-sources-backend/pkg/test/handler" + "github.com/labstack/echo/v4" + echo_middleware "github.com/labstack/echo/v4/middleware" + "github.com/redhatinsights/platform-go-middlewares/v2/identity" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type ModuleStreamsSuite struct { + suite.Suite + echo *echo.Echo + dao dao.MockDaoRegistry +} + +func TestModuleStreamsSuite(t *testing.T) { + suite.Run(t, new(ModuleStreamsSuite)) +} + +func (suite *ModuleStreamsSuite) SetupTest() { + suite.echo = echo.New() + suite.echo.Use(echo_middleware.RequestIDWithConfig(echo_middleware.RequestIDConfig{ + TargetHeader: "x-rh-insights-request-id", + })) + suite.echo.Use(middleware.WrapMiddlewareWithSkipper(identity.EnforceIdentity, middleware.SkipMiddleware)) + suite.dao = *dao.GetMockDaoRegistry(suite.T()) +} + +func (suite *ModuleStreamsSuite) TearDownTest() { + require.NoError(suite.T(), suite.echo.Shutdown(context.Background())) +} + +func (suite *ModuleStreamsSuite) serveModuleStreamsRouter(req *http.Request) (int, []byte, error) { + var ( + err error + ) + + router := echo.New() + router.Use(echo_middleware.RequestIDWithConfig(echo_middleware.RequestIDConfig{ + TargetHeader: "x-rh-insights-request-id", + })) + router.Use(middleware.WrapMiddlewareWithSkipper(identity.EnforceIdentity, middleware.SkipMiddleware)) + pathPrefix := router.Group(api.FullRootPath()) + + router.HTTPErrorHandler = config.CustomHTTPErrorHandler + + rh := ModuleStreamsHandler{ + Dao: *suite.dao.ToDaoRegistry(), + } + RegisterModuleStreamsRoutes(pathPrefix, &rh.Dao) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + response := rr.Result() + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + return response.StatusCode, body, err +} + +func (suite *ModuleStreamsSuite) TestSearchSnapshotModuleStreams() { + t := suite.T() + + config.Load() + config.Get().Features.Snapshots.Enabled = true + config.Get().Features.Snapshots.Accounts = &[]string{test_handler.MockAccountNumber} + defer resetFeatures() + + type TestCaseExpected struct { + Code int + Body string + } + + type TestCaseGiven struct { + Method string + Body string + } + + type TestCase struct { + Name string + Given TestCaseGiven + Expected TestCaseExpected + } + + var testCases []TestCase = []TestCase{ + { + Name: "Success scenario", + Given: TestCaseGiven{ + Method: http.MethodPost, + Body: `{"uuids":["abcd"],"rpm_names":[],"search":"demo"}`, + }, + Expected: TestCaseExpected{ + Code: http.StatusOK, + Body: "[]\n", + }, + }, + { + Name: "Evoke a StatusBadRequest response", + Given: TestCaseGiven{ + Method: http.MethodPost, + Body: "{", + }, + Expected: TestCaseExpected{ + Code: http.StatusBadRequest, + Body: "{\"errors\":[{\"status\":400,\"title\":\"Error binding parameters\",\"detail\":\"code=400, message=unexpected EOF, internal=unexpected EOF\"}]}\n", + }, + }, + } + + for _, testCase := range testCases { + t.Log(testCase.Name) + + path := fmt.Sprintf("%s/snapshots/module_streams/search", api.FullRootPath()) + switch { + case testCase.Expected.Code >= 200 && testCase.Expected.Code < 300: + { + var bodyRequest api.SearchSnapshotModuleStreamsRequest + err := json.Unmarshal([]byte(testCase.Given.Body), &bodyRequest) + require.NoError(t, err) + suite.dao.ModuleStreams.On("SearchSnapshotModuleStreams", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId, bodyRequest). + Return([]api.SearchModuleStreams{}, nil) + } + default: + { + } + } + + var bodyRequest io.Reader + if testCase.Given.Body == "" { + bodyRequest = nil + } else { + bodyRequest = strings.NewReader(testCase.Given.Body) + } + + // Prepare request + req := httptest.NewRequest(testCase.Given.Method, path, bodyRequest) + req.Header.Set(api.IdentityHeader, test_handler.EncodedIdentity(t)) + req.Header.Set("Content-Type", "application/json") + + // Execute the request + code, body, err := suite.serveModuleStreamsRouter(req) + + // Check results + assert.Equal(t, testCase.Expected.Code, code) + require.NoError(t, err) + assert.Equal(t, testCase.Expected.Body, string(body)) + } +}