diff --git a/cmd/external-repos/main.go b/cmd/external-repos/main.go index c7db89afa..99def9c7a 100644 --- a/cmd/external-repos/main.go +++ b/cmd/external-repos/main.go @@ -201,7 +201,7 @@ func enqueueSnapshotRepos(urls *[]string) error { } c := client.NewTaskClient(&q) - repoConfigDao := dao.GetRepositoryConfigDao(db.DB) + repoConfigDao := dao.GetRepositoryConfigDao(db.DB, pulp_client.GetPulpClientWithDomain(context.Background(), "")) var filter *dao.ListRepoFilter if urls != nil { filter = &dao.ListRepoFilter{ diff --git a/hammer.py b/hammer.py deleted file mode 100644 index 3171c35f5..000000000 --- a/hammer.py +++ /dev/null @@ -1,10 +0,0 @@ -import requests - -def sendReq(): - url = "http://localhost:8000/api/content-sources/v1.0/repositories/6f202e47-a1d1-4813-8fc9-d506ef67cf7d/introspect/" - headers = { "x-rh-identity":"eyJpZGVudGl0eSI6eyJ0eXBlIjoiVXNlciIsInVzZXIiOnsidXNlcm5hbWUiOiJqZG9lIn0sImludGVybmFsIjp7Im9yZ19pZCI6IjEyMyJ9fX0K" } - x = requests.post(url, headers = headers) - -for i in range(100): - sendReq() - diff --git a/pkg/dao/interfaces.go b/pkg/dao/interfaces.go index 004167f5b..a0d40b870 100644 --- a/pkg/dao/interfaces.go +++ b/pkg/dao/interfaces.go @@ -24,16 +24,22 @@ type DaoRegistry struct { func GetDaoRegistry(db *gorm.DB) *DaoRegistry { reg := DaoRegistry{ RepositoryConfig: &repositoryConfigDaoImpl{ - db: db, - yumRepo: &yum.Repository{}, + db: db, + yumRepo: &yum.Repository{}, + pulpClient: pulp_client.GetPulpClientWithDomain(context.Background(), ""), + ctx: context.Background(), }, Rpm: rpmDaoImpl{db: db}, Repository: repositoryDaoImpl{db: db}, Metrics: metricsDaoImpl{db: db}, - Snapshot: &snapshotDaoImpl{db: db}, - TaskInfo: taskInfoDaoImpl{db: db}, - AdminTask: adminTaskInfoDaoImpl{db: db, pulpClient: pulp_client.GetGlobalPulpClient(context.Background())}, - Domain: domainDaoImpl{db: db}, + Snapshot: &snapshotDaoImpl{ + db: db, + pulpClient: pulp_client.GetPulpClientWithDomain(context.Background(), ""), + ctx: context.Background(), + }, + TaskInfo: taskInfoDaoImpl{db: db}, + AdminTask: adminTaskInfoDaoImpl{db: db, pulpClient: pulp_client.GetGlobalPulpClient(context.Background())}, + Domain: domainDaoImpl{db: db}, } return ® } @@ -55,7 +61,7 @@ type RepositoryConfigDao interface { InternalOnly_FetchRepoConfigsForRepoUUID(uuid string) []api.RepositoryResponse UpdateLastSnapshotTask(taskUUID string, orgID string, repoUUID string) error InternalOnly_RefreshRedHatRepo(request api.RepositoryRequest) (*api.RepositoryResponse, error) - InitializePulpClient(ctx context.Context, orgID string) error + WithContext(ctx context.Context) RepositoryConfigDao } //go:generate mockery --name RpmDao --filename rpms_mock.go --inpackage @@ -84,7 +90,7 @@ type SnapshotDao interface { Delete(snapUUID string) error FetchLatestSnapshot(repoConfigUUID string) (api.SnapshotResponse, error) GetRepositoryConfigurationFile(orgID, snapshotUUID, repoConfigUUID string) (string, error) - InitializePulpClient(ctx context.Context, orgID string) error + WithContext(ctx context.Context) SnapshotDao } //go:generate mockery --name MetricsDao --filename metrics_mock.go --inpackage diff --git a/pkg/dao/mock_helpers.go b/pkg/dao/mock_helpers.go new file mode 100644 index 000000000..96173169d --- /dev/null +++ b/pkg/dao/mock_helpers.go @@ -0,0 +1,13 @@ +package dao + +import "github.com/stretchr/testify/mock" + +func (m *MockRepositoryConfigDao) WithContextMock() *MockRepositoryConfigDao { + m.On("WithContext", mock.AnythingOfType("*context.valueCtx")).Return(m) + return m +} + +func (m *MockSnapshotDao) WithContextMock() *MockSnapshotDao { + m.On("WithContext", mock.AnythingOfType("*context.valueCtx")).Return(m) + return m +} diff --git a/pkg/dao/repository_configs.go b/pkg/dao/repository_configs.go index f0859ef5f..dc8c07063 100644 --- a/pkg/dao/repository_configs.go +++ b/pkg/dao/repository_configs.go @@ -28,16 +28,10 @@ type repositoryConfigDaoImpl struct { db *gorm.DB yumRepo yum.YumRepository pulpClient pulp_client.PulpClient + ctx context.Context } -func GetRepositoryConfigDao(db *gorm.DB) RepositoryConfigDao { - return &repositoryConfigDaoImpl{ - db: db, - yumRepo: &yum.Repository{}, - } -} - -func GetRepositoryConfigDaoWithPulpClient(db *gorm.DB, pulpClient pulp_client.PulpClient) RepositoryConfigDao { +func GetRepositoryConfigDao(db *gorm.DB, pulpClient pulp_client.PulpClient) RepositoryConfigDao { return &repositoryConfigDaoImpl{ db: db, yumRepo: &yum.Repository{}, @@ -79,20 +73,10 @@ func DBErrorToApi(e error) *ce.DaoError { } } -func (r *repositoryConfigDaoImpl) InitializePulpClient(ctx context.Context, orgID string) error { - if !config.Get().Features.Snapshots.Enabled { - return nil - } - - dDao := GetDomainDao(r.db) - domainName, err := dDao.Fetch(orgID) - if err != nil { - return err - } - - pulpClient := pulp_client.GetPulpClientWithDomain(context.TODO(), domainName) - r.pulpClient = pulpClient - return nil +func (r *repositoryConfigDaoImpl) WithContext(ctx context.Context) RepositoryConfigDao { + cpy := *r + cpy.ctx = ctx + return &cpy } func (r repositoryConfigDaoImpl) Create(newRepoReq api.RepositoryRequest) (api.RepositoryResponse, error) { @@ -265,7 +249,6 @@ func (r repositoryConfigDaoImpl) List( ) (api.RepositoryCollectionResponse, int64, error) { var totalRepos int64 repoConfigs := make([]models.RepositoryConfiguration, 0) - var err error var contentPath string filteredDB := r.filteredDbForList(OrgID, r.db, filterData) @@ -304,8 +287,14 @@ func (r repositoryConfigDaoImpl) List( return api.RepositoryCollectionResponse{}, totalRepos, filteredDB.Error } - if r.pulpClient != nil && config.Get().Features.Snapshots.Enabled { - contentPath, err = r.pulpClient.GetContentPath() + if config.Get().Features.Snapshots.Enabled { + dDao := domainDaoImpl{db: r.db} + domain, err := dDao.Fetch(OrgID) + if err != nil { + return api.RepositoryCollectionResponse{}, totalRepos, err + } + + contentPath, err = r.pulpClient.WithContext(r.ctx).WithDomain(domain).GetContentPath() if err != nil { return api.RepositoryCollectionResponse{}, totalRepos, err } @@ -404,10 +393,12 @@ func (r repositoryConfigDaoImpl) Fetch(orgID string, uuid string) (api.Repositor ModelToApiFields(repoConfig, &repo) if repoConfig.LastSnapshot != nil && config.Get().Features.Snapshots.Enabled { - if r.pulpClient == nil { - return api.RepositoryResponse{}, fmt.Errorf("pulpClient cannot be nil") + dDao := domainDaoImpl{db: r.db} + domainName, err := dDao.Fetch(orgID) + if err != nil { + return api.RepositoryResponse{}, err } - contentPath, err := r.pulpClient.GetContentPath() + contentPath, err := r.pulpClient.WithContext(r.ctx).WithDomain(domainName).GetContentPath() if err != nil { return api.RepositoryResponse{}, err } diff --git a/pkg/dao/repository_configs_mock.go b/pkg/dao/repository_configs_mock.go index 9fc4891e7..3fef48529 100644 --- a/pkg/dao/repository_configs_mock.go +++ b/pkg/dao/repository_configs_mock.go @@ -147,20 +147,6 @@ func (_m *MockRepositoryConfigDao) FetchByRepoUuid(orgID string, repoUuid string return r0, r1 } -// InitializePulpClient provides a mock function with given fields: ctx, orgID -func (_m *MockRepositoryConfigDao) InitializePulpClient(ctx context.Context, orgID string) error { - ret := _m.Called(ctx, orgID) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { - r0 = rf(ctx, orgID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // InternalOnly_FetchRepoConfigsForRepoUUID provides a mock function with given fields: uuid func (_m *MockRepositoryConfigDao) InternalOnly_FetchRepoConfigsForRepoUUID(uuid string) []api.RepositoryResponse { ret := _m.Called(uuid) @@ -350,6 +336,22 @@ func (_m *MockRepositoryConfigDao) ValidateParameters(orgId string, params api.R return r0, r1 } +// WithContext provides a mock function with given fields: ctx +func (_m *MockRepositoryConfigDao) WithContext(ctx context.Context) RepositoryConfigDao { + ret := _m.Called(ctx) + + var r0 RepositoryConfigDao + if rf, ok := ret.Get(0).(func(context.Context) RepositoryConfigDao); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(RepositoryConfigDao) + } + } + + return r0 +} + // NewMockRepositoryConfigDao creates a new instance of MockRepositoryConfigDao. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockRepositoryConfigDao(t interface { diff --git a/pkg/dao/repository_configs_test.go b/pkg/dao/repository_configs_test.go index 3e8751b15..b48b610cb 100644 --- a/pkg/dao/repository_configs_test.go +++ b/pkg/dao/repository_configs_test.go @@ -1,6 +1,7 @@ package dao import ( + "context" "fmt" "strconv" "strings" @@ -27,6 +28,7 @@ import ( type RepositoryConfigSuite struct { *DaoSuite + mockPulpClient *pulp_client.MockPulpClient } func (suite *RepositoryConfigSuite) SetupTest() { @@ -43,7 +45,7 @@ func (suite *RepositoryConfigSuite) SetupTest() { func TestRepositoryConfigSuite(t *testing.T) { m := DaoSuite{} - r := RepositoryConfigSuite{DaoSuite: &m} + r := RepositoryConfigSuite{DaoSuite: &m, mockPulpClient: pulp_client.NewMockPulpClient(t)} suite.Run(t, &r) } @@ -80,7 +82,7 @@ func (suite *RepositoryConfigSuite) TestCreate() { MetadataVerification: &metadataVerification, } - dao := GetRepositoryConfigDao(tx) + dao := GetRepositoryConfigDao(tx, suite.mockPulpClient) created, err := dao.Create(toCreate) assert.Nil(t, err) @@ -100,11 +102,11 @@ func (suite *RepositoryConfigSuite) TestCreateTwiceWithNoSlash() { config.El9, }, } - dao := GetRepositoryConfigDao(suite.tx) + dao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient) _, err := dao.Create(toCreate) assert.ErrorContains(suite.T(), err, "Name cannot be blank") - dao = GetRepositoryConfigDao(suite.tx) + dao = GetRepositoryConfigDao(suite.tx, suite.mockPulpClient) _, err = dao.Create(toCreate) assert.ErrorContains(suite.T(), err, "Name cannot be blank") } @@ -120,7 +122,7 @@ func (suite *RepositoryConfigSuite) TestCreateRedHatRepository() { config.El9, }, } - dao := GetRepositoryConfigDao(suite.tx) + dao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient) _, err := dao.Create(toCreate) assert.ErrorContains(suite.T(), err, "Creating of Red Hat repositories is not permitted") } @@ -149,7 +151,7 @@ func (suite *RepositoryConfigSuite) TestRepositoryCreateAlreadyExists() { // Force failure on creating duplicate tx.SavePoint("before") - _, err = GetRepositoryConfigDao(tx).Create(api.RepositoryRequest{ + _, err = GetRepositoryConfigDao(tx, suite.mockPulpClient).Create(api.RepositoryRequest{ Name: &found.Name, URL: &found.Repository.URL, OrgID: &found.OrgID, @@ -167,7 +169,7 @@ func (suite *RepositoryConfigSuite) TestRepositoryCreateAlreadyExists() { tx.RollbackTo("before") // Force failure on creating duplicate url - _, err = GetRepositoryConfigDao(tx).Create(api.RepositoryRequest{ + _, err = GetRepositoryConfigDao(tx, suite.mockPulpClient).Create(api.RepositoryRequest{ Name: pointy.Pointer("new name"), URL: &found.Repository.URL, OrgID: &found.OrgID, @@ -225,7 +227,7 @@ func (suite *RepositoryConfigSuite) TestRepositoryCreateBlank() { } tx.SavePoint("testrepositorycreateblanktest") for i := 0; i < len(blankItems); i++ { - _, err := GetRepositoryConfigDao(tx).Create(blankItems[i].given) + _, err := GetRepositoryConfigDao(tx, suite.mockPulpClient).Create(blankItems[i].given) assert.NotNil(t, err) if blankItems[i].expected == "" { assert.NoError(t, err) @@ -263,7 +265,7 @@ func (suite *RepositoryConfigSuite) TestBulkCreateCleanupURL() { }, } - rr, errs := GetRepositoryConfigDao(tx).BulkCreate(request) + rr, errs := GetRepositoryConfigDao(tx, suite.mockPulpClient).BulkCreate(request) require.Empty(t, errs) assert.Equal(t, repository.URL, rr[0].URL) } @@ -287,7 +289,7 @@ func (suite *RepositoryConfigSuite) TestBulkCreate() { } } - rr, errs := GetRepositoryConfigDao(tx).BulkCreate(requests) + rr, errs := GetRepositoryConfigDao(tx, suite.mockPulpClient).BulkCreate(requests) assert.Empty(t, errs) assert.Equal(t, amountToCreate, len(rr)) @@ -324,7 +326,7 @@ func (suite *RepositoryConfigSuite) TestBulkCreateOneFails() { }, } - rr, errs := GetRepositoryConfigDao(tx).BulkCreate(requests) + rr, errs := GetRepositoryConfigDao(tx, suite.mockPulpClient).BulkCreate(requests) assert.NotEmpty(t, errs) assert.Empty(t, rr) @@ -367,14 +369,14 @@ func (suite *RepositoryConfigSuite) updateTest(url string) { t := suite.T() var err error - createResp, err := GetRepositoryConfigDao(suite.tx).Create(api.RepositoryRequest{ + createResp, err := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).Create(api.RepositoryRequest{ Name: pointy.String("NotUpdated"), URL: &url, OrgID: pointy.String("MyGreatOrg"), }) assert.Nil(t, err) - _, err = GetRepositoryConfigDao(suite.tx).Update(createResp.OrgID, createResp.UUID, + _, err = GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).Update(createResp.OrgID, createResp.UUID, api.RepositoryRequest{ Name: &name, URL: &url, @@ -398,7 +400,7 @@ func (suite *RepositoryConfigSuite) TestUpdateDuplicateVersions() { assert.Nil(t, err) found := models.RepositoryConfiguration{} suite.tx.First(&found) - _, err = GetRepositoryConfigDao(suite.tx).Update(found.OrgID, found.UUID, + _, err = GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).Update(found.OrgID, found.UUID, api.RepositoryRequest{ DistributionVersions: &duplicateVersions, }) @@ -447,7 +449,7 @@ func (suite *RepositoryConfigSuite) TestUpdateEmpty() { assert.NotEmpty(t, found.Arch) // Update the RepositoryConfiguration record using dao method - _, err = GetRepositoryConfigDao(tx).Update(found.OrgID, found.UUID, + _, err = GetRepositoryConfigDao(tx, suite.mockPulpClient).Update(found.OrgID, found.UUID, api.RepositoryRequest{ Name: &name, DistributionArch: &arch, @@ -476,7 +478,7 @@ func (suite *RepositoryConfigSuite) TestDuplicateUpdate() { var created1 api.RepositoryResponse var created2 api.RepositoryResponse - created1, err = GetRepositoryConfigDao(suite.tx). + created1, err = GetRepositoryConfigDao(suite.tx, suite.mockPulpClient). Create(api.RepositoryRequest{ OrgID: &repoConfig.OrgID, AccountID: &repoConfig.AccountID, @@ -485,7 +487,7 @@ func (suite *RepositoryConfigSuite) TestDuplicateUpdate() { }) assert.NoError(t, err) - created2, err = GetRepositoryConfigDao(suite.tx). + created2, err = GetRepositoryConfigDao(suite.tx, suite.mockPulpClient). Create(api.RepositoryRequest{ OrgID: &created1.OrgID, AccountID: &created1.AccountID, @@ -493,7 +495,7 @@ func (suite *RepositoryConfigSuite) TestDuplicateUpdate() { URL: &url}) assert.NoError(t, err) - _, err = GetRepositoryConfigDao(tx).Update( + _, err = GetRepositoryConfigDao(tx, suite.mockPulpClient).Update( created2.OrgID, created2.UUID, api.RepositoryRequest{ @@ -521,7 +523,7 @@ func (suite *RepositoryConfigSuite) TestUpdateNotFound() { Error require.NoError(t, err) - _, err = GetRepositoryConfigDao(suite.tx).Update("Wrong OrgID!! zomg hacker", found.UUID, + _, err = GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).Update("Wrong OrgID!! zomg hacker", found.UUID, api.RepositoryRequest{ Name: &name, URL: &name, @@ -581,7 +583,7 @@ func (suite *RepositoryConfigSuite) TestUpdateBlank() { } tx.SavePoint("updateblanktest") for i := 0; i < len(blankItems); i++ { - _, err := GetRepositoryConfigDao(tx).Update(orgID, found.UUID, blankItems[i].given) + _, err := GetRepositoryConfigDao(tx, suite.mockPulpClient).Update(orgID, found.UUID, blankItems[i].given) assert.Error(t, err) if blankItems[i].expected == "" { assert.NoError(t, err) @@ -611,8 +613,7 @@ func (suite *RepositoryConfigSuite) TestFetch() { Error assert.NoError(t, err) - mockPulpClient := pulp_client.NewMockPulpClient(t) - rDao := repositoryConfigDaoImpl{db: tx, pulpClient: mockPulpClient} + repoConfigDao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).WithContext(context.Background()) snap := models.Snapshot{ Base: models.Base{UUID: uuid.NewString()}, @@ -635,10 +636,10 @@ func (suite *RepositoryConfigSuite) TestFetch() { assert.NoError(t, err) if config.Get().Features.Snapshots.Enabled { - mockPulpClient.On("GetContentPath").Return(testContentPath, nil) + suite.mockPulpForListOrFetch(1) } - fetched, err := rDao.Fetch(found.OrgID, found.UUID) + fetched, err := repoConfigDao.Fetch(found.OrgID, found.UUID) assert.Nil(t, err) assert.Equal(t, found.UUID, fetched.UUID) assert.Equal(t, found.Name, fetched.Name) @@ -665,7 +666,7 @@ func (suite *RepositoryConfigSuite) TestFetchByRepo() { Error assert.NoError(t, err) - fetched, err := GetRepositoryConfigDao(suite.tx).FetchByRepoUuid(found.OrgID, found.RepositoryUUID) + fetched, err := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).FetchByRepoUuid(found.OrgID, found.RepositoryUUID) assert.Nil(t, err) assert.Equal(t, found.UUID, fetched.UUID) assert.Equal(t, found.Name, fetched.Name) @@ -685,13 +686,15 @@ func (suite *RepositoryConfigSuite) TestFetchNotFound() { Error assert.NoError(t, err) - _, err = GetRepositoryConfigDao(suite.tx).Fetch("bad org id", found.UUID) + repoConfigDao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).WithContext(context.Background()) + + _, err = repoConfigDao.Fetch("bad org id", found.UUID) assert.NotNil(t, err) daoError, ok := err.(*ce.DaoError) assert.True(t, ok) assert.True(t, daoError.NotFound) - _, err = GetRepositoryConfigDao(suite.tx).Fetch(orgID, "bad uuid") + _, err = repoConfigDao.Fetch(orgID, "bad uuid") assert.NotNil(t, err) daoError, ok = err.(*ce.DaoError) assert.True(t, ok) @@ -718,7 +721,7 @@ func (suite *RepositoryConfigSuite) TestInternalOnly_FetchRepoConfigsForRepoUUID assert.Nil(t, err) } - results := GetRepositoryConfigDao(suite.tx).InternalOnly_FetchRepoConfigsForRepoUUID(repoConfig.RepositoryUUID) + results := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).InternalOnly_FetchRepoConfigsForRepoUUID(repoConfig.RepositoryUUID) // Confirm all 10 repoConfigs are returned assert.Equal(t, numberOfRepos, len(results)) @@ -774,13 +777,12 @@ func (suite *RepositoryConfigSuite) TestList() { Error assert.NoError(t, err) - mockPulpClient := pulp_client.NewMockPulpClient(t) - rDao := repositoryConfigDaoImpl{db: suite.tx, pulpClient: mockPulpClient} + rDao := repositoryConfigDaoImpl{db: suite.tx, pulpClient: suite.mockPulpClient} if config.Get().Features.Snapshots.Enabled { - mockPulpClient.On("GetContentPath").Return(testContentPath, nil) + suite.mockPulpForListOrFetch(1) } - response, total, err := rDao.List(orgID, pageData, filterData) + response, total, err := rDao.WithContext(context.Background()).List(orgID, pageData, filterData) assert.Nil(t, err) assert.Equal(t, int64(1), total) assert.Equal(t, 1, len(response.Data)) @@ -822,7 +824,9 @@ func (suite *RepositoryConfigSuite) TestListPageDataLimit0() { assert.Nil(t, result.Error) assert.Equal(t, int64(1), total) - response, total, err := GetRepositoryConfigDao(suite.tx).List(orgID, pageData, filterData) + repoConfigDao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).WithContext(context.Background()) + suite.mockPulpForListOrFetch(1) + response, total, err := repoConfigDao.List(orgID, pageData, filterData) assert.Nil(t, err) assert.Equal(t, int64(1), total) assert.Equal(t, 0, len(response.Data)) // We have limited the data to 0, so response.data will return 0 @@ -847,7 +851,10 @@ func (suite *RepositoryConfigSuite) TestListNoRepositories() { assert.Nil(t, result.Error) assert.Equal(t, int64(0), total) - response, total, err := GetRepositoryConfigDao(suite.tx).List(orgID, pageData, filterData) + repoConfigDao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).WithContext(context.Background()) + suite.mockPulpForListOrFetch(1) + + response, total, err := repoConfigDao.List(orgID, pageData, filterData) assert.Nil(t, err) assert.Empty(t, response.Data) assert.Equal(t, int64(0), total) @@ -878,7 +885,10 @@ func (suite *RepositoryConfigSuite) TestListPageLimit() { assert.Nil(t, result.Error) assert.Equal(t, int64(20), total) - response, total, err := GetRepositoryConfigDao(suite.tx).List(orgID, pageData, filterData) + repoConfigDao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).WithContext(context.Background()) + suite.mockPulpForListOrFetch(1) + + response, total, err := repoConfigDao.List(orgID, pageData, filterData) assert.Nil(t, err) assert.Equal(t, len(response.Data), pageData.Limit) @@ -896,11 +906,16 @@ func (suite *RepositoryConfigSuite) TestListFilterName() { filterData := api.FilterData{} assert.Nil(t, seeds.SeedRepositoryConfigurations(suite.tx, 2, seeds.SeedOptions{OrgID: orgID, Versions: &[]string{config.El9}})) - allRepoResp, _, err := GetRepositoryConfigDao(suite.tx).List(orgID, api.PaginationData{Limit: -1}, api.FilterData{}) + + repoConfigDao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).WithContext(context.Background()) + suite.mockPulpForListOrFetch(1) + suite.mockPulpForListOrFetch(1) + + allRepoResp, _, err := repoConfigDao.List(orgID, api.PaginationData{Limit: -1}, api.FilterData{}) assert.NoError(t, err) filterData.Name = allRepoResp.Data[0].Name - response, total, err := GetRepositoryConfigDao(suite.tx).List(orgID, api.PaginationData{Limit: -1}, filterData) + response, total, err := repoConfigDao.List(orgID, api.PaginationData{Limit: -1}, filterData) assert.Nil(t, err) assert.Equal(t, 1, len(response.Data)) assert.Equal(t, 1, int(total)) @@ -911,14 +926,18 @@ func (suite *RepositoryConfigSuite) TestListFilterName() { func (suite *RepositoryConfigSuite) TestListFilterUrl() { t := suite.T() orgID := seeds.RandomOrgId() + + repoConfigDao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).WithContext(context.Background()) + suite.mockPulpForListOrFetch(3) + filterData := api.FilterData{} assert.Nil(t, seeds.SeedRepositoryConfigurations(suite.tx, 2, seeds.SeedOptions{OrgID: orgID, Versions: &[]string{config.El9}})) - allRepoResp, _, err := GetRepositoryConfigDao(suite.tx).List(orgID, api.PaginationData{Limit: -1}, api.FilterData{}) + allRepoResp, _, err := repoConfigDao.List(orgID, api.PaginationData{Limit: -1}, api.FilterData{}) assert.NoError(t, err) filterData.URL = allRepoResp.Data[0].URL - response, total, err := GetRepositoryConfigDao(suite.tx).List(orgID, api.PaginationData{Limit: -1}, filterData) + response, total, err := repoConfigDao.List(orgID, api.PaginationData{Limit: -1}, filterData) assert.Nil(t, err) assert.Equal(t, 1, len(response.Data)) assert.Equal(t, 1, int(total)) @@ -927,7 +946,7 @@ func (suite *RepositoryConfigSuite) TestListFilterUrl() { // Test that it works with urls missing a trailing slash filterData.URL = filterData.URL[:len(filterData.URL)-1] - response, total, err = GetRepositoryConfigDao(suite.tx).List(orgID, api.PaginationData{Limit: -1}, filterData) + response, total, err = repoConfigDao.List(orgID, api.PaginationData{Limit: -1}, filterData) assert.Nil(t, err) assert.Equal(t, 1, len(response.Data)) assert.Equal(t, 1, int(total)) @@ -952,7 +971,11 @@ func (suite *RepositoryConfigSuite) TestListFilterVersion() { quantity := 20 assert.Nil(t, seeds.SeedRepositoryConfigurations(suite.tx, quantity, seeds.SeedOptions{OrgID: orgID, Versions: &[]string{config.El9}})) - response, total, err := GetRepositoryConfigDao(suite.tx).List(orgID, pageData, filterData) + + repoConfigDao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).WithContext(context.Background()) + suite.mockPulpForListOrFetch(1) + + response, total, err := repoConfigDao.List(orgID, pageData, filterData) assert.Nil(t, err) assert.Equal(t, quantity, len(response.Data)) @@ -995,7 +1018,10 @@ func (suite *RepositoryConfigSuite) TestListFilterArch() { assert.Nil(t, result.Error) assert.Equal(t, int64(quantity), total) - response, total, err := GetRepositoryConfigDao(tx).List(orgID, pageData, filterData) + repoConfigDao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).WithContext(context.Background()) + suite.mockPulpForListOrFetch(1) + + response, total, err := repoConfigDao.List(orgID, pageData, filterData) assert.Nil(t, err) assert.Equal(t, quantity, len(response.Data)) @@ -1038,14 +1064,17 @@ func (suite *RepositoryConfigSuite) TestListFilterOrigin() { assert.Nil(t, result.Error) assert.Equal(t, int64(quantity), total) - response, total, err := GetRepositoryConfigDao(tx).List(orgID, pageData, filterData) + repoConfigDao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).WithContext(context.Background()) + suite.mockPulpForListOrFetch(2) + + response, total, err := repoConfigDao.List(orgID, pageData, filterData) assert.Nil(t, err) assert.Equal(t, quantity, len(response.Data)) assert.Equal(t, int64(quantity), total) filterData.Origin = fmt.Sprintf("%v,%v", config.OriginExternal, "notarealorigin") - response, total, err = GetRepositoryConfigDao(tx).List(orgID, pageData, filterData) + response, total, err = repoConfigDao.List(orgID, pageData, filterData) assert.Nil(t, err) assert.Equal(t, quantity, len(response.Data)) @@ -1074,7 +1103,10 @@ func (suite *RepositoryConfigSuite) TestListFilterContentType() { err = seeds.SeedRepositoryConfigurations(tx, quantity, seeds.SeedOptions{OrgID: orgID, ContentType: pointy.Pointer("SomeOther")}) assert.Nil(t, err) - response, total, err := GetRepositoryConfigDao(tx).List(orgID, pageData, filterData) + repoConfigDao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).WithContext(context.Background()) + + suite.mockPulpForListOrFetch(1) + response, total, err := repoConfigDao.List(orgID, pageData, filterData) assert.Nil(t, err) assert.Equal(t, quantity, len(response.Data)) @@ -1106,7 +1138,10 @@ func (suite *RepositoryConfigSuite) TestListFilterStatus() { assert.Nil(t, seeds.SeedRepositoryConfigurations(suite.tx, quantity/3, seeds.SeedOptions{OrgID: orgID, Status: pointy.String(config.StatusPending)})) - response, count, err := GetRepositoryConfigDao(suite.tx).List(orgID, pageData, filterData) + repoConfigDao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).WithContext(context.Background()) + + suite.mockPulpForListOrFetch(1) + response, count, err := repoConfigDao.List(orgID, pageData, filterData) assert.Nil(t, err) assert.Equal(t, 20, len(response.Data)) @@ -1141,7 +1176,10 @@ func (suite *RepositoryConfigSuite) TestListFilterMultipleArch() { assert.Nil(t, seeds.SeedRepositoryConfigurations(suite.tx, 10, seeds.SeedOptions{OrgID: orgID, Arch: &s390xref})) assert.Nil(t, seeds.SeedRepositoryConfigurations(suite.tx, 30, seeds.SeedOptions{OrgID: orgID, Arch: &x86ref})) - response, count, err := GetRepositoryConfigDao(suite.tx).List(orgID, pageData, filterData) + repoConfigDao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).WithContext(context.Background()) + + suite.mockPulpForListOrFetch(1) + response, count, err := repoConfigDao.List(orgID, pageData, filterData) assert.Nil(t, err) assert.Equal(t, quantity, len(response.Data)) @@ -1182,7 +1220,10 @@ func (suite *RepositoryConfigSuite) TestListFilterMultipleVersions() { assert.Nil(t, seeds.SeedRepositoryConfigurations(suite.tx, quantity, seeds.SeedOptions{OrgID: "kdksfkdf", Versions: &[]string{config.El7, config.El8, config.El9}})) - response, count, err := GetRepositoryConfigDao(suite.tx).List(orgID, pageData, filterData) + repoConfigDao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).WithContext(context.Background()) + + suite.mockPulpForListOrFetch(1) + response, count, err := repoConfigDao.List(orgID, pageData, filterData) assert.Nil(t, err) assert.Equal(t, quantity, len(response.Data)) @@ -1218,7 +1259,9 @@ func (suite *RepositoryConfigSuite) TestListFilterSearch() { Version: "", } - _, err := GetRepositoryConfigDao(tx).Create(api.RepositoryRequest{ + repoConfigDao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).WithContext(context.Background()) + + _, err := repoConfigDao.Create(api.RepositoryRequest{ OrgID: &orgID, AccountID: &accountID, Name: &name, @@ -1233,7 +1276,8 @@ func (suite *RepositoryConfigSuite) TestListFilterSearch() { assert.Nil(t, result.Error) assert.Equal(t, quantity, total) - response, total, err := GetRepositoryConfigDao(tx).List(orgID, pageData, filterData) + suite.mockPulpForListOrFetch(1) + response, total, err := repoConfigDao.List(orgID, pageData, filterData) assert.Nil(t, err) assert.Equal(t, int(quantity), len(response.Data)) @@ -1250,7 +1294,7 @@ func (suite *RepositoryConfigSuite) TestSavePublicUrls() { } // Create the two Repository records - err := GetRepositoryConfigDao(tx).SavePublicRepos(repoUrls) + err := GetRepositoryConfigDao(tx, suite.mockPulpClient).SavePublicRepos(repoUrls) require.NoError(t, err) repo := []models.Repository{} err = tx. @@ -1263,7 +1307,7 @@ func (suite *RepositoryConfigSuite) TestSavePublicUrls() { assert.Equal(t, int64(len(repo)), count) // Repeat to check clause on conflict - err = GetRepositoryConfigDao(suite.tx).SavePublicRepos(repoUrls) + err = GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).SavePublicRepos(repoUrls) assert.NoError(t, err) err = tx. Model(&models.Repository{}). @@ -1290,7 +1334,7 @@ func (suite *RepositoryConfigSuite) TestDelete() { Error require.NoError(t, err) - err = GetRepositoryConfigDao(tx).SoftDelete(repoConfig.OrgID, repoConfig.UUID) + err = GetRepositoryConfigDao(tx, suite.mockPulpClient).SoftDelete(repoConfig.OrgID, repoConfig.UUID) assert.NoError(t, err) repoConfig2 := models.RepositoryConfiguration{} @@ -1315,7 +1359,7 @@ func (suite *RepositoryConfigSuite) TestDeleteNotFound() { Error require.NoError(t, err) - err = GetRepositoryConfigDao(suite.tx).SoftDelete("bad org id", found.UUID) + err = GetRepositoryConfigDao(suite.tx, suite.mockPulpClient).SoftDelete("bad org id", found.UUID) assert.Error(t, err) daoError, ok := err.(*ce.DaoError) assert.True(t, ok) @@ -1329,7 +1373,7 @@ func (suite *RepositoryConfigSuite) TestDeleteNotFound() { func (suite *RepositoryConfigSuite) TestBulkDelete() { t := suite.T() - dao := GetRepositoryConfigDao(suite.tx) + dao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient) orgID := seeds.RandomOrgId() repoConfigCount := 5 @@ -1352,7 +1396,7 @@ func (suite *RepositoryConfigSuite) TestBulkDelete() { func (suite *RepositoryConfigSuite) TestUpdateLastSnapshotTask() { t := suite.T() - dao := GetRepositoryConfigDao(suite.tx) + dao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient) orgID := seeds.RandomOrgId() repoConfigCount := 1 @@ -1377,7 +1421,7 @@ func (suite *RepositoryConfigSuite) TestUpdateLastSnapshotTask() { func (suite *RepositoryConfigSuite) TestBulkDeleteOneNotFound() { t := suite.T() - dao := GetRepositoryConfigDao(suite.tx) + dao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient) orgID := seeds.RandomOrgId() repoConfigCount := 5 @@ -1402,7 +1446,7 @@ func (suite *RepositoryConfigSuite) TestBulkDeleteOneNotFound() { func (suite *RepositoryConfigSuite) TestBulkDeleteRedhatRepository() { t := suite.T() - dao := GetRepositoryConfigDao(suite.tx) + dao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient) orgID := config.RedHatOrg repoConfigCount := 5 @@ -1421,7 +1465,7 @@ func (suite *RepositoryConfigSuite) TestBulkDeleteRedhatRepository() { func (suite *RepositoryConfigSuite) TestBulkDeleteMultipleNotFound() { t := suite.T() - dao := GetRepositoryConfigDao(suite.tx) + dao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient) orgID := seeds.RandomOrgId() repoConfigCount := 5 @@ -1723,7 +1767,7 @@ func (suite *RepositoryConfigSuite) TestListReposToSnapshot() { }() t := suite.T() - dao := GetRepositoryConfigDao(suite.tx) + dao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient) repo, err := dao.Create(api.RepositoryRequest{ Name: pointy.Pointer("name"), @@ -1807,7 +1851,7 @@ func (suite *RepositoryConfigSuite) TestListReposToSnapshot() { } func (suite *RepositoryConfigSuite) TestRefreshRedHatRepo() { - dao := GetRepositoryConfigDao(suite.tx) + dao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient) rhRepo := api.RepositoryRequest{ UUID: nil, Name: pointy.Pointer("Some redhat repo"), @@ -1834,3 +1878,9 @@ func (suite *RepositoryConfigSuite) TestRefreshRedHatRepo() { assert.Equal(suite.T(), *rhRepo.Name, response.Name) } + +func (suite *RepositoryConfigSuite) mockPulpForListOrFetch(times int) { + if config.Get().Features.Snapshots.Enabled { + suite.mockPulpClient.WithContextMock().WithDomainMock().On("GetContentPath").Return(testContentPath, nil).Times(times) + } +} diff --git a/pkg/dao/snapshots.go b/pkg/dao/snapshots.go index eaa3d2827..be62f2f5c 100644 --- a/pkg/dao/snapshots.go +++ b/pkg/dao/snapshots.go @@ -16,14 +16,22 @@ import ( type snapshotDaoImpl struct { db *gorm.DB pulpClient pulp_client.PulpClient + ctx context.Context } func GetSnapshotDao(db *gorm.DB) SnapshotDao { return &snapshotDaoImpl{ - db: db, + db: db, + ctx: context.Background(), } } +func (sDao *snapshotDaoImpl) WithContext(ctx context.Context) SnapshotDao { + cpy := *sDao + cpy.ctx = ctx + return &cpy +} + // Create records a snapshot of a repository func (sDao *snapshotDaoImpl) Create(s *models.Snapshot) error { trans := sDao.db.Create(s) @@ -109,7 +117,7 @@ func (sDao *snapshotDaoImpl) List( return api.SnapshotCollectionResponse{Data: []api.SnapshotResponse{}}, totalSnaps, nil } - pulpContentPath, err := sDao.pulpClient.GetContentPath() + pulpContentPath, err := sDao.pulpClient.WithContext(sDao.ctx).GetContentPath() if err != nil { return api.SnapshotCollectionResponse{}, 0, err } @@ -146,7 +154,8 @@ func (sDao *snapshotDaoImpl) GetRepositoryConfigurationFile(orgID, snapshotUUID, return "", err } - contentPath, err := sDao.pulpClient.GetContentPath() + pc := sDao.pulpClient.WithContext(sDao.ctx) + contentPath, err := pc.GetContentPath() if err != nil { return "", err } @@ -167,18 +176,6 @@ func (sDao *snapshotDaoImpl) GetRepositoryConfigurationFile(orgID, snapshotUUID, return fileConfig, nil } -func (sDao *snapshotDaoImpl) InitializePulpClient(ctx context.Context, orgID string) error { - dDao := GetDomainDao(sDao.db) - domainName, err := dDao.Fetch(orgID) - if err != nil { - return err - } - - pulpClient := pulp_client.GetPulpClientWithDomain(context.TODO(), domainName) - sDao.pulpClient = pulpClient - return nil -} - func (sDao *snapshotDaoImpl) FetchForRepoConfigUUID(repoConfigUUID string) ([]models.Snapshot, error) { var snaps []models.Snapshot result := sDao.db.Model(&models.Snapshot{}). diff --git a/pkg/dao/snapshots_mock.go b/pkg/dao/snapshots_mock.go index c370f1c83..c88336667 100644 --- a/pkg/dao/snapshots_mock.go +++ b/pkg/dao/snapshots_mock.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.33.0. DO NOT EDIT. +// Code generated by mockery v2.32.0. DO NOT EDIT. package dao @@ -119,20 +119,6 @@ func (_m *MockSnapshotDao) GetRepositoryConfigurationFile(orgID string, snapshot return r0, r1 } -// InitializePulpClient provides a mock function with given fields: ctx, orgID -func (_m *MockSnapshotDao) InitializePulpClient(ctx context.Context, orgID string) error { - ret := _m.Called(ctx, orgID) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { - r0 = rf(ctx, orgID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // List provides a mock function with given fields: orgID, repoConfigUuid, paginationData, filterData func (_m *MockSnapshotDao) List(orgID string, repoConfigUuid string, paginationData api.PaginationData, filterData api.FilterData) (api.SnapshotCollectionResponse, int64, error) { ret := _m.Called(orgID, repoConfigUuid, paginationData, filterData) @@ -164,6 +150,22 @@ func (_m *MockSnapshotDao) List(orgID string, repoConfigUuid string, paginationD return r0, r1, r2 } +// WithContext provides a mock function with given fields: ctx +func (_m *MockSnapshotDao) WithContext(ctx context.Context) SnapshotDao { + ret := _m.Called(ctx) + + var r0 SnapshotDao + if rf, ok := ret.Get(0).(func(context.Context) SnapshotDao); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(SnapshotDao) + } + } + + return r0 +} + // NewMockSnapshotDao creates a new instance of MockSnapshotDao. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockSnapshotDao(t interface { diff --git a/pkg/dao/snapshots_test.go b/pkg/dao/snapshots_test.go index 7f69d201f..64fce9461 100644 --- a/pkg/dao/snapshots_test.go +++ b/pkg/dao/snapshots_test.go @@ -1,6 +1,7 @@ package dao import ( + "context" "fmt" "testing" "time" @@ -24,7 +25,7 @@ type SnapshotsSuite struct { func TestSnapshotsSuite(t *testing.T) { m := DaoSuite{} - r := SnapshotsSuite{&m} + r := SnapshotsSuite{DaoSuite: &m} suite.Run(t, &r) } @@ -109,10 +110,17 @@ func (s *SnapshotsSuite) TestCreateAndList() { tx := s.tx mockPulpClient := pulp_client.NewMockPulpClient(t) - sDao := snapshotDaoImpl{db: tx, pulpClient: mockPulpClient} - mockPulpClient.On("GetContentPath").Return(testContentPath, nil) + sDaoImpl := snapshotDaoImpl{db: tx, pulpClient: mockPulpClient} + sDao := sDaoImpl.WithContext(context.Background()) + + if config.Get().Features.Snapshots.Enabled { + mockPulpClient.WithContextMock().WithDomainMock().On("GetContentPath").Return(testContentPath, nil) + } else { + mockPulpClient.WithContextMock().On("GetContentPath").Return(testContentPath, nil) + } - repoDao := repositoryConfigDaoImpl{db: tx, yumRepo: &mockExt.YumRepositoryMock{}} + repoDaoImpl := repositoryConfigDaoImpl{db: tx, yumRepo: &mockExt.YumRepositoryMock{}, pulpClient: mockPulpClient} + repoDao := repoDaoImpl.WithContext(context.Background()) rConfig := s.createRepository() pageData := api.PaginationData{ @@ -129,7 +137,7 @@ func (s *SnapshotsSuite) TestCreateAndList() { collection, total, err := sDao.List(rConfig.OrgID, rConfig.UUID, pageData, filterData) - repository, _ := repoDao.fetchRepoConfig(rConfig.OrgID, rConfig.UUID, false) + repository, _ := repoDaoImpl.fetchRepoConfig(rConfig.OrgID, rConfig.UUID, false) repositoryList, repoCount, _ := repoDao.List(rConfig.OrgID, api.PaginationData{Limit: -1}, api.FilterData{}) assert.NoError(t, err) @@ -158,10 +166,17 @@ func (s *SnapshotsSuite) TestCreateAndListRedHatRepo() { tx := s.tx mockPulpClient := pulp_client.NewMockPulpClient(t) - sDao := snapshotDaoImpl{db: tx, pulpClient: mockPulpClient} - mockPulpClient.On("GetContentPath").Return(testContentPath, nil) + sDaoImpl := snapshotDaoImpl{db: tx, pulpClient: mockPulpClient} + sDao := sDaoImpl.WithContext(context.Background()) + + if config.Get().Features.Snapshots.Enabled { + mockPulpClient.WithContextMock().WithDomainMock().On("GetContentPath").Return(testContentPath, nil) + } else { + mockPulpClient.WithContextMock().On("GetContentPath").Return(testContentPath, nil) + } - repoDao := repositoryConfigDaoImpl{db: tx, yumRepo: &mockExt.YumRepositoryMock{}} + repoDaoImpl := repositoryConfigDaoImpl{db: tx, yumRepo: &mockExt.YumRepositoryMock{}, pulpClient: mockPulpClient} + repoDao := repoDaoImpl.WithContext(context.Background()) redhatRepositoryConfig := s.createRedhatRepository() redhatSnap := s.createSnapshot(redhatRepositoryConfig) @@ -178,7 +193,7 @@ func (s *SnapshotsSuite) TestCreateAndListRedHatRepo() { collection, total, err := sDao.List("ShouldNotMatter", redhatRepositoryConfig.UUID, pageData, filterData) - repository, _ := repoDao.fetchRepoConfig("ShouldNotMatter", redhatRepositoryConfig.UUID, true) + repository, _ := repoDaoImpl.fetchRepoConfig("ShouldNotMatter", redhatRepositoryConfig.UUID, true) repositoryList, repoCount, _ := repoDao.List("ShouldNotMatter", api.PaginationData{Limit: -1}, api.FilterData{}) assert.NoError(t, err) @@ -246,8 +261,10 @@ func (s *SnapshotsSuite) TestListPageLimit() { tx := s.tx mockPulpClient := pulp_client.NewMockPulpClient(t) - sDao := snapshotDaoImpl{db: tx, pulpClient: mockPulpClient} - mockPulpClient.On("GetContentPath").Return(testContentPath, nil).Once() + sDaoImpl := snapshotDaoImpl{db: tx, pulpClient: mockPulpClient} + sDao := sDaoImpl.WithContext(context.Background()) + + mockPulpClient.WithContextMock().On("GetContentPath").Return(testContentPath, nil) rConfig := s.createRepository() pageData := api.PaginationData{ @@ -395,15 +412,15 @@ func (s *SnapshotsSuite) TestGetRepositoryConfigurationFile() { snapshot := s.createSnapshot(repoConfig) // Test happy scenario - mockPulpClient.On("GetContentPath").Return(testContentPath, nil).Once() - repoConfigFile, err := sDao.GetRepositoryConfigurationFile(repoConfig.OrgID, snapshot.UUID, repoConfig.UUID) + mockPulpClient.WithContextMock().On("GetContentPath").Return(testContentPath, nil).Once() + repoConfigFile, err := sDao.WithContext(context.Background()).GetRepositoryConfigurationFile(repoConfig.OrgID, snapshot.UUID, repoConfig.UUID) assert.NoError(t, err) assert.Contains(t, repoConfigFile, repoConfig.Name) assert.Contains(t, repoConfigFile, testContentPath) // Test error from pulp call - mockPulpClient.On("GetContentPath").Return("", fmt.Errorf("some error")).Once() - repoConfigFile, err = sDao.GetRepositoryConfigurationFile(repoConfig.OrgID, snapshot.UUID, repoConfig.UUID) + mockPulpClient.WithContextMock().On("GetContentPath").Return("", fmt.Errorf("some error")).Once() + repoConfigFile, err = sDao.WithContext(context.Background()).GetRepositoryConfigurationFile(repoConfig.OrgID, snapshot.UUID, repoConfig.UUID) assert.Error(t, err) assert.Empty(t, repoConfigFile) } @@ -418,9 +435,12 @@ func (s *SnapshotsSuite) TestGetRepositoryConfigurationFileNotFound() { repoConfig := s.createRepository() snapshot := s.createSnapshot(repoConfig) + if config.Get().Features.Snapshots.Enabled { + mockPulpClient.WithContextMock().On("GetContentPath").Return(testContentPath, nil).Times(3) + } + // Test bad repo UUID - mockPulpClient.On("GetContentPath").Return(testContentPath, nil).Once() - repoConfigFile, err := sDao.GetRepositoryConfigurationFile(repoConfig.OrgID, snapshot.UUID, uuid2.NewString()) + repoConfigFile, err := sDao.WithContext(context.Background()).GetRepositoryConfigurationFile(repoConfig.OrgID, snapshot.UUID, uuid2.NewString()) assert.Error(t, err) if err != nil { daoError, ok := err.(*ce.DaoError) @@ -431,8 +451,7 @@ func (s *SnapshotsSuite) TestGetRepositoryConfigurationFileNotFound() { assert.Empty(t, repoConfigFile) // Test bad snapshot UUID - mockPulpClient.On("GetContentPath").Return(testContentPath, nil).Once() - repoConfigFile, err = sDao.GetRepositoryConfigurationFile(repoConfig.OrgID, uuid2.NewString(), repoConfig.UUID) + repoConfigFile, err = sDao.WithContext(context.Background()).GetRepositoryConfigurationFile(repoConfig.OrgID, uuid2.NewString(), repoConfig.UUID) assert.Error(t, err) if err != nil { daoError, ok := err.(*ce.DaoError) @@ -443,8 +462,7 @@ func (s *SnapshotsSuite) TestGetRepositoryConfigurationFileNotFound() { assert.Empty(t, repoConfigFile) // Test bad org ID - mockPulpClient.On("GetContentPath").Return(testContentPath, nil).Once() - repoConfigFile, err = sDao.GetRepositoryConfigurationFile("bad orgID", snapshot.UUID, repoConfig.UUID) + repoConfigFile, err = sDao.WithContext(context.Background()).GetRepositoryConfigurationFile("bad orgID", snapshot.UUID, repoConfig.UUID) assert.Error(t, err) if err != nil { daoError, ok := err.(*ce.DaoError) diff --git a/pkg/dao/task_info_test.go b/pkg/dao/task_info_test.go index 4c16f3d79..da7d67369 100644 --- a/pkg/dao/task_info_test.go +++ b/pkg/dao/task_info_test.go @@ -1,6 +1,7 @@ package dao import ( + "context" "encoding/json" "testing" "time" @@ -9,6 +10,7 @@ import ( "github.com/content-services/content-sources-backend/pkg/config" ce "github.com/content-services/content-sources-backend/pkg/errors" "github.com/content-services/content-sources-backend/pkg/models" + "github.com/content-services/content-sources-backend/pkg/pulp_client" "github.com/content-services/content-sources-backend/pkg/seeds" "github.com/google/uuid" "github.com/openlyinc/pointy" @@ -411,11 +413,15 @@ type CleanupTestCase struct { beDeleted bool } -func (suite *TaskInfoSuite) TestCleanup() { +func (suite *TaskInfoSuite) TestTaskCleanup() { err := seeds.SeedRepositoryConfigurations(suite.tx, 2, seeds.SeedOptions{OrgID: orgIDTest}) assert.NoError(suite.T(), err) - repoConfigDao := GetRepositoryConfigDao(suite.tx) + mockPulpClient := pulp_client.NewMockPulpClient(suite.T()) + repoConfigDao := GetRepositoryConfigDao(suite.tx, mockPulpClient).WithContext(context.Background()) + if config.Get().Features.Snapshots.Enabled { + mockPulpClient.WithContextMock().WithDomainMock().On("GetContentPath").Return(testContentPath, nil) + } results, _, _ := repoConfigDao.List(orgIDTest, api.PaginationData{Limit: 2}, api.FilterData{}) if len(results.Data) != 2 { assert.Fail(suite.T(), "Expected to create 2 repo configs") diff --git a/pkg/handler/popular_repositories.go b/pkg/handler/popular_repositories.go index e27c4932e..88b8cb32a 100644 --- a/pkg/handler/popular_repositories.go +++ b/pkg/handler/popular_repositories.go @@ -92,13 +92,8 @@ func filterPopularRepositories(configData []api.PopularRepositoryResponse, filte func (rh *PopularRepositoriesHandler) updateIfExists(c echo.Context, repo *api.PopularRepositoryResponse) error { _, orgID := getAccountIdOrgId(c) - err := rh.Dao.RepositoryConfig.InitializePulpClient(c.Request().Context(), orgID) - if err != nil { - return ce.NewErrorResponse(ce.HttpCodeForDaoError(err), "Error initializing pulp client", err.Error()) - } - // Go get the records for this URL - repos, _, err := rh.Dao.RepositoryConfig.List(orgID, api.PaginationData{Limit: 1}, api.FilterData{Search: repo.URL}) + repos, _, err := rh.Dao.RepositoryConfig.WithContext(c.Request().Context()).List(orgID, api.PaginationData{Limit: 1}, api.FilterData{Search: repo.URL}) if err != nil { return ce.NewErrorResponseFromError("Could not get repository list", err) } diff --git a/pkg/handler/popular_repositories_test.go b/pkg/handler/popular_repositories_test.go index 2138c3485..cf8b1d1f7 100644 --- a/pkg/handler/popular_repositories_test.go +++ b/pkg/handler/popular_repositories_test.go @@ -17,7 +17,6 @@ import ( "github.com/labstack/echo/v4" "github.com/redhatinsights/platform-go-middlewares/identity" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) @@ -80,10 +79,9 @@ func (s *PopularReposSuite) servePopularRepositoriesRouter(req *http.Request) (i func (s *PopularReposSuite) TestPopularRepos() { collection := createRepoCollection(0, 10, 0) paginationData := api.PaginationData{Limit: 1} - s.dao.RepositoryConfig.On("List", test_handler.MockOrgId, paginationData, api.FilterData{Search: "https://dl.fedoraproject.org/pub/epel/9/Everything/x86_64/"}).Return(collection, int64(0), nil) + s.dao.RepositoryConfig.WithContextMock().On("List", test_handler.MockOrgId, paginationData, api.FilterData{Search: "https://dl.fedoraproject.org/pub/epel/9/Everything/x86_64/"}).Return(collection, int64(0), nil) s.dao.RepositoryConfig.On("List", test_handler.MockOrgId, paginationData, api.FilterData{Search: "https://dl.fedoraproject.org/pub/epel/8/Everything/x86_64/"}).Return(collection, int64(0), nil) s.dao.RepositoryConfig.On("List", test_handler.MockOrgId, paginationData, api.FilterData{Search: "https://dl.fedoraproject.org/pub/epel/7/x86_64/"}).Return(collection, int64(0), nil) - s.dao.RepositoryConfig.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId).Return(nil).Times(3) path := fmt.Sprintf("%s/popular_repositories/?limit=%d", fullRootPath(), 10) req := httptest.NewRequest(http.MethodGet, path, nil) @@ -110,8 +108,7 @@ func (s *PopularReposSuite) TestPopularReposSearchWithExisting() { existingName := "bestNameEver" collection := api.RepositoryCollectionResponse{Data: []api.RepositoryResponse{{UUID: magicalUUID, Name: existingName, URL: popularRepository.URL, DistributionVersions: popularRepository.DistributionVersions, DistributionArch: popularRepository.DistributionArch}}} paginationData := api.PaginationData{Limit: 1} - s.dao.RepositoryConfig.On("List", test_handler.MockOrgId, paginationData, api.FilterData{Search: popularRepository.URL}).Return(collection, int64(0), nil) - s.dao.RepositoryConfig.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId).Return(nil).Once() + s.dao.RepositoryConfig.WithContextMock().On("List", test_handler.MockOrgId, paginationData, api.FilterData{Search: popularRepository.URL}).Return(collection, int64(0), nil) path := fmt.Sprintf("%s/popular_repositories/?limit=%d&search=%s", fullRootPath(), 10, popularRepository.URL) req := httptest.NewRequest(http.MethodGet, path, nil) @@ -138,8 +135,7 @@ func (s *PopularReposSuite) TestPopularReposSearchWithExisting() { func (s *PopularReposSuite) TestPopularReposSearchByURL() { collection := createRepoCollection(0, 10, 0) paginationData := api.PaginationData{Limit: 1} - s.dao.RepositoryConfig.On("List", test_handler.MockOrgId, paginationData, api.FilterData{Search: popularRepository.URL}).Return(collection, int64(0), nil) - s.dao.RepositoryConfig.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId).Return(nil).Once() + s.dao.RepositoryConfig.WithContextMock().On("List", test_handler.MockOrgId, paginationData, api.FilterData{Search: popularRepository.URL}).Return(collection, int64(0), nil) path := fmt.Sprintf("%s/popular_repositories/?limit=%d&search=%s", fullRootPath(), 10, popularRepository.URL) req := httptest.NewRequest(http.MethodGet, path, nil) req.Header.Set(api.IdentityHeader, test_handler.EncodedIdentity(s.T())) @@ -163,8 +159,7 @@ func (s *PopularReposSuite) TestPopularReposSearchByURL() { func (s *PopularReposSuite) TestPopularReposSearchByName() { collection := createRepoCollection(0, 10, 0) paginationData := api.PaginationData{Limit: 1} - s.dao.RepositoryConfig.On("List", test_handler.MockOrgId, paginationData, api.FilterData{Search: popularRepository.URL}).Return(collection, int64(0), nil) - s.dao.RepositoryConfig.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId).Return(nil).Once() + s.dao.RepositoryConfig.WithContextMock().On("List", test_handler.MockOrgId, paginationData, api.FilterData{Search: popularRepository.URL}).Return(collection, int64(0), nil) path := fmt.Sprintf("%s/popular_repositories/?limit=%d&search=%s", fullRootPath(), 10, url.QueryEscape(popularRepository.SuggestedName)) req := httptest.NewRequest(http.MethodGet, path, nil) diff --git a/pkg/handler/repositories.go b/pkg/handler/repositories.go index 370abd695..5a8756eca 100644 --- a/pkg/handler/repositories.go +++ b/pkg/handler/repositories.go @@ -109,12 +109,7 @@ func (rh *RepositoryHandler) listRepositories(c echo.Context) error { pageData := ParsePagination(c) filterData := ParseFilters(c) - err := rh.DaoRegistry.RepositoryConfig.InitializePulpClient(c.Request().Context(), orgID) - if err != nil { - return ce.NewErrorResponse(ce.HttpCodeForDaoError(err), "Error initializing pulp client", err.Error()) - } - - repos, totalRepos, err := rh.DaoRegistry.RepositoryConfig.List(orgID, pageData, filterData) + repos, totalRepos, err := rh.DaoRegistry.RepositoryConfig.WithContext(c.Request().Context()).List(orgID, pageData, filterData) if err != nil { return ce.NewErrorResponse(ce.HttpCodeForDaoError(err), "Error listing repositories", err.Error()) } @@ -243,12 +238,7 @@ func (rh *RepositoryHandler) fetch(c echo.Context) error { _, orgID := getAccountIdOrgId(c) uuid := c.Param("uuid") - err := rh.DaoRegistry.RepositoryConfig.InitializePulpClient(c.Request().Context(), orgID) - if err != nil { - return ce.NewErrorResponse(ce.HttpCodeForDaoError(err), "Error initializing pulp client", err.Error()) - } - - response, err := rh.DaoRegistry.RepositoryConfig.Fetch(orgID, uuid) + response, err := rh.DaoRegistry.RepositoryConfig.WithContext(c.Request().Context()).Fetch(orgID, uuid) if err != nil { return ce.NewErrorResponse(ce.HttpCodeForDaoError(err), "Error fetching repository", err.Error()) } @@ -310,12 +300,7 @@ func (rh *RepositoryHandler) update(c echo.Context, fillDefaults bool) error { repoParams.FillDefaults() } - err := rh.DaoRegistry.RepositoryConfig.InitializePulpClient(c.Request().Context(), orgID) - if err != nil { - return ce.NewErrorResponse(ce.HttpCodeForDaoError(err), "Error initializing pulp client", err.Error()) - } - - repoConfig, err := rh.DaoRegistry.RepositoryConfig.Fetch(orgID, uuid) + repoConfig, err := rh.DaoRegistry.RepositoryConfig.WithContext(c.Request().Context()).Fetch(orgID, uuid) if err != nil { return ce.NewErrorResponse(ce.HttpCodeForDaoError(err), "Error fetching repository", err.Error()) } @@ -338,7 +323,7 @@ func (rh *RepositoryHandler) update(c echo.Context, fillDefaults bool) error { } } - response, err := rh.DaoRegistry.RepositoryConfig.Fetch(orgID, uuid) + response, err := rh.DaoRegistry.RepositoryConfig.WithContext(c.Request().Context()).Fetch(orgID, uuid) if urlUpdated && response.Snapshot { rh.enqueueSnapshotEvent(c, &response) } @@ -367,7 +352,7 @@ func (rh *RepositoryHandler) deleteRepository(c echo.Context) error { _, orgID := getAccountIdOrgId(c) uuid := c.Param("uuid") - repoConfig, err := rh.DaoRegistry.RepositoryConfig.Fetch(orgID, uuid) + repoConfig, err := rh.DaoRegistry.RepositoryConfig.WithContext(c.Request().Context()).Fetch(orgID, uuid) if err != nil { return ce.NewErrorResponse(ce.HttpCodeForDaoError(err), "Error fetching repository", err.Error()) } @@ -425,7 +410,7 @@ func (rh *RepositoryHandler) bulkDeleteRepositories(c echo.Context) error { hasErr := false errs := make([]error, len(uuids)) for i := range uuids { - repoConfig, err := rh.DaoRegistry.RepositoryConfig.Fetch(orgID, uuids[i]) + repoConfig, err := rh.DaoRegistry.RepositoryConfig.WithContext(c.Request().Context()).Fetch(orgID, uuids[i]) responses[i] = repoConfig if err != nil { hasErr = true @@ -487,7 +472,7 @@ func (rh *RepositoryHandler) introspect(c echo.Context) error { return ce.NewErrorResponse(http.StatusBadRequest, "Error binding parameters", err.Error()) } - response, err := rh.DaoRegistry.RepositoryConfig.Fetch(orgID, uuid) + response, err := rh.DaoRegistry.RepositoryConfig.WithContext(c.Request().Context()).Fetch(orgID, uuid) if err != nil { return ce.NewErrorResponse(ce.HttpCodeForDaoError(err), "Error fetching repository", err.Error()) } diff --git a/pkg/handler/repositories_test.go b/pkg/handler/repositories_test.go index bfd53da9f..6ce50548a 100644 --- a/pkg/handler/repositories_test.go +++ b/pkg/handler/repositories_test.go @@ -27,7 +27,6 @@ import ( "github.com/openlyinc/pointy" "github.com/redhatinsights/platform-go-middlewares/identity" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) @@ -149,8 +148,7 @@ func (suite *ReposSuite) TestSimple() { collection := createRepoCollection(1, 10, 0) paginationData := api.PaginationData{Limit: 10, Offset: DefaultOffset} - suite.reg.RepositoryConfig.On("List", test_handler.MockOrgId, paginationData, api.FilterData{}).Return(collection, int64(1), nil) - suite.reg.RepositoryConfig.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId).Return(nil).Once() + suite.reg.RepositoryConfig.WithContextMock().On("List", test_handler.MockOrgId, paginationData, api.FilterData{}).Return(collection, int64(1), nil) path := fmt.Sprintf("%s/repositories/?limit=%d", fullRootPath(), 10) req := httptest.NewRequest(http.MethodGet, path, nil) @@ -187,8 +185,7 @@ func (suite *ReposSuite) TestListNoRepositories() { collection := api.RepositoryCollectionResponse{} paginationData := api.PaginationData{Limit: DefaultLimit, Offset: DefaultOffset} - suite.reg.RepositoryConfig.On("List", test_handler.MockOrgId, paginationData, api.FilterData{}).Return(collection, int64(0), nil) - suite.reg.RepositoryConfig.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId).Return(nil).Once() + suite.reg.RepositoryConfig.WithContextMock().On("List", test_handler.MockOrgId, paginationData, api.FilterData{}).Return(collection, int64(0), nil) req := httptest.NewRequest(http.MethodGet, fullRootPath()+"/repositories/", nil) req.Header.Set(api.IdentityHeader, test_handler.EncodedIdentity(t)) @@ -215,9 +212,8 @@ func (suite *ReposSuite) TestListPagedExtraRemaining() { paginationData1 := api.PaginationData{Limit: 10, Offset: 0} paginationData2 := api.PaginationData{Limit: 10, Offset: 100} - suite.reg.RepositoryConfig.On("List", test_handler.MockOrgId, paginationData1, api.FilterData{}).Return(collection, int64(102), nil).Once() + suite.reg.RepositoryConfig.WithContextMock().On("List", test_handler.MockOrgId, paginationData1, api.FilterData{}).Return(collection, int64(102), nil).Once() suite.reg.RepositoryConfig.On("List", test_handler.MockOrgId, paginationData2, api.FilterData{}).Return(collection, int64(102), nil).Once() - suite.reg.RepositoryConfig.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId).Return(nil).Twice() path := fmt.Sprintf("%s/repositories/?limit=%d", fullRootPath(), 10) req := httptest.NewRequest(http.MethodGet, path, nil) @@ -251,8 +247,7 @@ func (suite *ReposSuite) TestListWithFilters() { t := suite.T() collection := api.RepositoryCollectionResponse{} - suite.reg.RepositoryConfig.On("List", test_handler.MockOrgId, api.PaginationData{Limit: 100}, api.FilterData{ContentType: "rpm", Origin: "external"}).Return(collection, int64(100), nil) - suite.reg.RepositoryConfig.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId).Return(nil).Once() + suite.reg.RepositoryConfig.WithContextMock().On("List", test_handler.MockOrgId, api.PaginationData{Limit: 100}, api.FilterData{ContentType: "rpm", Origin: "external"}).Return(collection, int64(100), nil) path := fmt.Sprintf("%s/repositories/?origin=%v&content_type=%v", fullRootPath(), "external", "rpm") req := httptest.NewRequest(http.MethodGet, path, nil) @@ -269,9 +264,8 @@ func (suite *ReposSuite) TestListPagedNoRemaining() { paginationData2 := api.PaginationData{Limit: 10, Offset: 90} collection := api.RepositoryCollectionResponse{} - suite.reg.RepositoryConfig.On("List", test_handler.MockOrgId, paginationData1, api.FilterData{}).Return(collection, int64(100), nil) + suite.reg.RepositoryConfig.WithContextMock().On("List", test_handler.MockOrgId, paginationData1, api.FilterData{}).Return(collection, int64(100), nil) suite.reg.RepositoryConfig.On("List", test_handler.MockOrgId, paginationData2, api.FilterData{}).Return(collection, int64(100), nil) - suite.reg.RepositoryConfig.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId).Return(nil).Twice() path := fmt.Sprintf("%s/repositories/?limit=%d", fullRootPath(), 10) req := httptest.NewRequest(http.MethodGet, path, nil) @@ -309,9 +303,8 @@ func (suite *ReposSuite) TestListDaoError() { } paginationData := api.PaginationData{Limit: DefaultLimit} - suite.reg.RepositoryConfig.On("List", test_handler.MockOrgId, paginationData, api.FilterData{}). + suite.reg.RepositoryConfig.WithContextMock().On("List", test_handler.MockOrgId, paginationData, api.FilterData{}). Return(api.RepositoryCollectionResponse{}, int64(0), &daoError) - suite.reg.RepositoryConfig.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId).Return(nil).Once() path := fmt.Sprintf("%s/repositories/", fullRootPath()) req := httptest.NewRequest(http.MethodGet, path, nil) @@ -332,8 +325,7 @@ func (suite *ReposSuite) TestFetch() { UUID: uuid, } - suite.reg.RepositoryConfig.On("Fetch", test_handler.MockOrgId, uuid).Return(repo, nil) - suite.reg.RepositoryConfig.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId).Return(nil).Once() + suite.reg.RepositoryConfig.WithContextMock().On("Fetch", test_handler.MockOrgId, uuid).Return(repo, nil) body, err := json.Marshal(repo) if err != nil { @@ -369,8 +361,7 @@ func (suite *ReposSuite) TestFetchNotFound() { NotFound: true, Message: "Not found", } - suite.reg.RepositoryConfig.On("Fetch", test_handler.MockOrgId, uuid).Return(api.RepositoryResponse{}, &daoError) - suite.reg.RepositoryConfig.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId).Return(nil).Once() + suite.reg.RepositoryConfig.WithContextMock().On("Fetch", test_handler.MockOrgId, uuid).Return(api.RepositoryResponse{}, &daoError) body, err := json.Marshal(repo) if err != nil { @@ -633,7 +624,7 @@ func (suite *ReposSuite) TestDelete() { t := suite.T() uuid := "valid-uuid" - suite.reg.RepositoryConfig.On("Fetch", test_handler.MockOrgId, uuid).Return(api.RepositoryResponse{ + suite.reg.RepositoryConfig.WithContextMock().On("Fetch", test_handler.MockOrgId, uuid).Return(api.RepositoryResponse{ Name: "my repo", URL: "https://example.com", UUID: uuid, @@ -659,7 +650,7 @@ func (suite *ReposSuite) TestDeleteNotFound() { NotFound: true, } - suite.reg.RepositoryConfig.On("Fetch", test_handler.MockOrgId, uuid).Return(api.RepositoryResponse{ + suite.reg.RepositoryConfig.WithContextMock().On("Fetch", test_handler.MockOrgId, uuid).Return(api.RepositoryResponse{ Name: "my repo", URL: "https://example.com", UUID: uuid, @@ -680,7 +671,7 @@ func (suite *ReposSuite) TestSnapshotInProgress() { t := suite.T() uuid := "inprogress-uuid" - suite.reg.RepositoryConfig.On("Fetch", test_handler.MockOrgId, uuid).Return(api.RepositoryResponse{ + suite.reg.RepositoryConfig.WithContextMock().On("Fetch", test_handler.MockOrgId, uuid).Return(api.RepositoryResponse{ Name: "my repo", URL: "https://example.com", UUID: uuid, @@ -701,7 +692,7 @@ func (suite *ReposSuite) TestBulkDelete() { uuids := []string{"uuid-1", "uuid-2"} for i := range uuids { - suite.reg.RepositoryConfig.On("Fetch", test_handler.MockOrgId, uuids[i]).Return(api.RepositoryResponse{ + suite.reg.RepositoryConfig.WithContextMock().On("Fetch", test_handler.MockOrgId, uuids[i]).Return(api.RepositoryResponse{ Name: fmt.Sprintf("my repo %d", i), URL: fmt.Sprintf("https://example.com/%d", i), UUID: uuids[i], @@ -757,7 +748,7 @@ func (suite *ReposSuite) TestBulkDeleteNotFound() { NotFound: true, } - suite.reg.RepositoryConfig.On("Fetch", test_handler.MockOrgId, uuids[0]).Return(api.RepositoryResponse{ + suite.reg.RepositoryConfig.WithContextMock().On("Fetch", test_handler.MockOrgId, uuids[0]).Return(api.RepositoryResponse{ Name: "my repo", URL: "https://example.com/%d", UUID: uuids[0], @@ -791,7 +782,7 @@ func (suite *ReposSuite) TestBulkDeleteSnapshotInProgress() { uuids := []string{"inprogress-uuid", "uuid-1"} for i := range uuids { - suite.reg.RepositoryConfig.On("Fetch", test_handler.MockOrgId, uuids[i]).Return(api.RepositoryResponse{ + suite.reg.RepositoryConfig.WithContextMock().On("Fetch", test_handler.MockOrgId, uuids[i]).Return(api.RepositoryResponse{ Name: fmt.Sprintf("my repo %d", i), URL: fmt.Sprintf("https://example.com/%d", i), UUID: uuids[i], @@ -853,7 +844,7 @@ func (suite *ReposSuite) TestFullUpdate() { expected := createRepoRequest(*request.Name, *request.URL) expected.FillDefaults() - suite.reg.RepositoryConfig.On("Update", test_handler.MockOrgId, uuid, expected).Return(false, nil) + suite.reg.RepositoryConfig.WithContextMock().On("Update", test_handler.MockOrgId, uuid, expected).Return(false, nil) suite.reg.RepositoryConfig.On("Fetch", test_handler.MockOrgId, uuid).Return(api.RepositoryResponse{ Name: "my repo", URL: "https://example.com", @@ -861,7 +852,6 @@ func (suite *ReposSuite) TestFullUpdate() { RepositoryUUID: repoUuid, }, nil) suite.reg.RepositoryConfig.On("Update", test_handler.MockOrgId, uuid, expected).Return(false, nil) - suite.reg.RepositoryConfig.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId).Return(nil).Once() mockTaskClientEnqueueIntrospect(suite.tcMock, "https://example.com", repoUuid) @@ -894,10 +884,10 @@ func (suite *ReposSuite) TestPartialUpdateUrlChange() { RepositoryUUID: repoUuid, Snapshot: true, } - suite.reg.RepositoryConfig.On("Update", test_handler.MockOrgId, repoConfigUuid, expected).Return(true, nil) + + suite.reg.RepositoryConfig.WithContextMock().On("Update", test_handler.MockOrgId, repoConfigUuid, expected).Return(true, nil) suite.reg.RepositoryConfig.On("Fetch", test_handler.MockOrgId, repoConfigUuid).Return(repoConfig, nil) suite.reg.TaskInfo.On("IsSnapshotInProgress", *expected.OrgID, repoUuid).Return(false, nil) - suite.reg.RepositoryConfig.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId).Return(nil).Once() mockTaskClientEnqueueSnapshot(suite, &repoConfig) mockTaskClientEnqueueIntrospect(suite.tcMock, "https://example.com", repoUuid) @@ -924,7 +914,7 @@ func (suite *ReposSuite) TestPartialUpdate() { request := createRepoRequest("Some Name", "https://example.com") expected := createRepoRequest(*request.Name, *request.URL) - suite.reg.RepositoryConfig.On("Update", test_handler.MockOrgId, uuid, expected).Return(true, nil) + suite.reg.RepositoryConfig.WithContextMock().On("Update", test_handler.MockOrgId, uuid, expected).Return(true, nil) suite.reg.RepositoryConfig.On("Fetch", test_handler.MockOrgId, uuid).Return(api.RepositoryResponse{ Name: "my repo", URL: "https://example.com", @@ -932,7 +922,6 @@ func (suite *ReposSuite) TestPartialUpdate() { RepositoryUUID: repoUuid, Snapshot: false, }, nil) - suite.reg.RepositoryConfig.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), test_handler.MockOrgId).Return(nil).Once() mockTaskClientEnqueueIntrospect(suite.tcMock, "https://example.com", repoUuid) @@ -975,7 +964,7 @@ func (suite *ReposSuite) TestIntrospectRepository() { // Fetch will filter the request by Org ID before updating suite.reg.Repository.On("Update", repoUpdate).Return(nil).NotBefore( suite.reg.Repository.On("FetchForUrl", repoResp.URL).Return(repo, nil).NotBefore( - suite.reg.RepositoryConfig.On("Fetch", test_handler.MockOrgId, uuid).Return(repoResp, nil), + suite.reg.RepositoryConfig.WithContextMock().On("Fetch", test_handler.MockOrgId, uuid).Return(repoResp, nil), ), ) body, err := json.Marshal(intReq) @@ -1012,7 +1001,7 @@ func (suite *ReposSuite) TestIntrospectRepositoryBeforeTimeLimit() { // Fetch will filter the request by Org ID before updating suite.reg.Repository.On("FetchForUrl", repoResp.URL).Return(repo, nil).NotBefore( - suite.reg.RepositoryConfig.On("Fetch", test_handler.MockOrgId, uuid).Return(repoResp, nil), + suite.reg.RepositoryConfig.WithContextMock().On("Fetch", test_handler.MockOrgId, uuid).Return(repoResp, nil), ) body, err := json.Marshal(intReq) if err != nil { diff --git a/pkg/handler/snapshots.go b/pkg/handler/snapshots.go index fe4240e9f..df19a110e 100644 --- a/pkg/handler/snapshots.go +++ b/pkg/handler/snapshots.go @@ -41,17 +41,12 @@ func RegisterSnapshotRoutes(group *echo.Group, daoReg *dao.DaoRegistry) { // @Failure 500 {object} ce.ErrorResponse // @Router /repositories/{uuid}/snapshots/ [get] func (sh *SnapshotHandler) listSnapshots(c echo.Context) error { - _, orgID := getAccountIdOrgId(c) uuid := c.Param("uuid") pageData := ParsePagination(c) filterData := ParseFilters(c) + _, orgID := getAccountIdOrgId(c) - err := sh.DaoRegistry.Snapshot.InitializePulpClient(c.Request().Context(), orgID) - if err != nil { - return ce.NewErrorResponse(ce.HttpCodeForDaoError(err), "Error initializing pulp client", err.Error()) - } - - snapshots, totalSnaps, err := sh.DaoRegistry.Snapshot.List(orgID, uuid, pageData, filterData) + snapshots, totalSnaps, err := sh.DaoRegistry.Snapshot.WithContext(c.Request().Context()).List(orgID, uuid, pageData, filterData) if err != nil { return ce.NewErrorResponse(ce.HttpCodeForDaoError(err), "Error listing repository snapshots", err.Error()) } @@ -78,12 +73,7 @@ func (sh *SnapshotHandler) getRepoConfigurationFile(c echo.Context) error { snapshotUUID := c.Param("snapshot_uuid") var repoConfigFile string - err := sh.DaoRegistry.Snapshot.InitializePulpClient(c.Request().Context(), orgID) - if err != nil { - return ce.NewErrorResponse(ce.HttpCodeForDaoError(err), "Error initializing pulp client", err.Error()) - } - - repoConfigFile, err = sh.DaoRegistry.Snapshot.GetRepositoryConfigurationFile(orgID, snapshotUUID, uuid) + repoConfigFile, err := sh.DaoRegistry.Snapshot.WithContext(c.Request().Context()).GetRepositoryConfigurationFile(orgID, snapshotUUID, uuid) if err != nil { return ce.NewErrorResponse(ce.HttpCodeForDaoError(err), "Error getting repository configuration file", err.Error()) } diff --git a/pkg/handler/snapshots_test.go b/pkg/handler/snapshots_test.go index 1f0bcf455..5717eeb55 100644 --- a/pkg/handler/snapshots_test.go +++ b/pkg/handler/snapshots_test.go @@ -63,12 +63,13 @@ func (suite *SnapshotSuite) TestSnapshotList() { paginationData := api.PaginationData{Limit: 10, Offset: DefaultOffset} collection := createSnapshotCollection(1, 10, 0) - uuid := "abcadaba" - orgID := test_handler.MockOrgId - suite.reg.Snapshot.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), orgID).Return(nil).Once() - suite.reg.Snapshot.On("List", test_handler.MockOrgId, uuid, paginationData, api.FilterData{}).Return(collection, int64(1), nil) + repoUUID := "abcadaba" + suite.reg.Snapshot.WithContextMock().On("List", repoUUID, paginationData, api.FilterData{}).Return(collection, int64(1), nil) + + suite.reg.Snapshot.On("WithContext", mock.AnythingOfType("*context.valueCtx")).Return(&suite.reg.Snapshot).Once() + suite.reg.Snapshot.On("List", test_handler.MockOrgId, repoUUID, paginationData, api.FilterData{}).Return(collection, int64(1), nil) - path := fmt.Sprintf("%s/repositories/%s/snapshots/?limit=%d", fullRootPath(), uuid, 10) + path := fmt.Sprintf("%s/repositories/%s/snapshots/?limit=%d", fullRootPath(), repoUUID, 10) req := httptest.NewRequest(http.MethodGet, path, nil) req.Header.Set(api.IdentityHeader, test_handler.EncodedIdentity(t)) @@ -96,8 +97,7 @@ func (suite *SnapshotSuite) TestGetRepositoryConfigurationFile() { snapUUID := uuid.NewString() repoConfigFile := "file" - suite.reg.Snapshot.On("GetRepositoryConfigurationFile", orgID, snapUUID, repoUUID).Return(repoConfigFile, nil).Once() - suite.reg.Snapshot.On("InitializePulpClient", mock.AnythingOfType("*context.valueCtx"), orgID).Return(nil).Once() + suite.reg.Snapshot.WithContextMock().On("GetRepositoryConfigurationFile", orgID, snapUUID, repoUUID).Return(repoConfigFile, nil).Once() path := fmt.Sprintf("%s/repositories/%s/snapshots/%s/config.repo", fullRootPath(), repoUUID, snapUUID) req := httptest.NewRequest(http.MethodGet, path, nil) diff --git a/pkg/pulp_client/client.go b/pkg/pulp_client/client.go index b372abb7d..358decd8c 100644 --- a/pkg/pulp_client/client.go +++ b/pkg/pulp_client/client.go @@ -28,6 +28,18 @@ func GetPulpClientWithDomain(ctx context.Context, domainName string) PulpClient return &impl } +func (p *pulpDaoImpl) WithContext(ctx context.Context) PulpClient { + pulp := getPulpImpl(ctx) + pulp.domainName = p.domainName + return &pulp +} + +func (p *pulpDaoImpl) WithDomain(domainName string) PulpClient { + cpy := *p + cpy.domainName = domainName + return &cpy +} + func getPulpImpl(ctx context.Context) pulpDaoImpl { ctx2 := context.WithValue(ctx, zest.ContextServerIndex, 0) timeout := 60 * time.Second diff --git a/pkg/pulp_client/interfaces.go b/pkg/pulp_client/interfaces.go index 796e8a161..5c73e613c 100644 --- a/pkg/pulp_client/interfaces.go +++ b/pkg/pulp_client/interfaces.go @@ -1,6 +1,8 @@ package pulp_client import ( + "context" + zest "github.com/content-services/zest/release/v2023" ) @@ -60,4 +62,8 @@ type PulpClient interface { // Status Status() (*zest.StatusResponse, error) + + // Chainable + WithContext(ctx context.Context) PulpClient + WithDomain(domainName string) PulpClient } diff --git a/pkg/pulp_client/pulp_client_mock.go b/pkg/pulp_client/pulp_client_mock.go index ab4d8bef0..ead5be67b 100644 --- a/pkg/pulp_client/pulp_client_mock.go +++ b/pkg/pulp_client/pulp_client_mock.go @@ -3,6 +3,8 @@ package pulp_client import ( + context "context" + zest "github.com/content-services/zest/release/v2023" mock "github.com/stretchr/testify/mock" ) @@ -628,6 +630,38 @@ func (_m *MockPulpClient) UpdateRpmRemote(pulpHref string, url string, clientCer return r0, r1 } +// WithContext provides a mock function with given fields: ctx +func (_m *MockPulpClient) WithContext(ctx context.Context) PulpClient { + ret := _m.Called(ctx) + + var r0 PulpClient + if rf, ok := ret.Get(0).(func(context.Context) PulpClient); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(PulpClient) + } + } + + return r0 +} + +// WithDomain provides a mock function with given fields: domainName +func (_m *MockPulpClient) WithDomain(domainName string) PulpClient { + ret := _m.Called(domainName) + + var r0 PulpClient + if rf, ok := ret.Get(0).(func(string) PulpClient); ok { + r0 = rf(domainName) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(PulpClient) + } + } + + return r0 +} + // NewMockPulpClient creates a new instance of MockPulpClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockPulpClient(t interface { diff --git a/pkg/pulp_client/pulp_client_mock_helpers.go b/pkg/pulp_client/pulp_client_mock_helpers.go new file mode 100644 index 000000000..b83b3c51b --- /dev/null +++ b/pkg/pulp_client/pulp_client_mock_helpers.go @@ -0,0 +1,13 @@ +package pulp_client + +import "context" + +func (m *MockPulpClient) WithContextMock() *MockPulpClient { + m.On("WithContext", context.Background()).Return(m) + return m +} + +func (m *MockPulpClient) WithDomainMock() *MockPulpClient { + m.On("WithDomain", "").Return(m) + return m +} diff --git a/test/integration/snapshot_test.go b/test/integration/snapshot_test.go index e59d885bf..3b136373a 100644 --- a/test/integration/snapshot_test.go +++ b/test/integration/snapshot_test.go @@ -76,12 +76,6 @@ func (s *SnapshotSuite) TestSnapshot() { repoUuid, err := uuid2.Parse(repo.RepositoryUUID) assert.NoError(s.T(), err) - err = s.dao.Snapshot.InitializePulpClient(context.Background(), accountId) - assert.NoError(s.T(), err) - - err = s.dao.RepositoryConfig.InitializePulpClient(context.Background(), accountId) - assert.NoError(s.T(), err) - // Start the task taskClient := client.NewTaskClient(&s.queue) s.snapshotAndWait(taskClient, repo, repoUuid, accountId)