diff --git a/pkg/dao/repository_configs.go b/pkg/dao/repository_configs.go index da266c50f..dc8c07063 100644 --- a/pkg/dao/repository_configs.go +++ b/pkg/dao/repository_configs.go @@ -74,8 +74,9 @@ func DBErrorToApi(e error) *ce.DaoError { } func (r *repositoryConfigDaoImpl) WithContext(ctx context.Context) RepositoryConfigDao { - r.ctx = ctx - return r + cpy := *r + cpy.ctx = ctx + return &cpy } func (r repositoryConfigDaoImpl) Create(newRepoReq api.RepositoryRequest) (api.RepositoryResponse, error) { diff --git a/pkg/dao/repository_configs_test.go b/pkg/dao/repository_configs_test.go index 81ba5875c..b48b610cb 100644 --- a/pkg/dao/repository_configs_test.go +++ b/pkg/dao/repository_configs_test.go @@ -122,7 +122,7 @@ func (suite *RepositoryConfigSuite) TestCreateRedHatRepository() { config.El9, }, } - dao := GetRepositoryConfigDao(suite.tx) + dao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient) _, err := dao.Create(toCreate) assert.ErrorContains(suite.T(), err, "Creating of Red Hat repositories is not permitted") } @@ -1446,7 +1446,7 @@ func (suite *RepositoryConfigSuite) TestBulkDeleteOneNotFound() { func (suite *RepositoryConfigSuite) TestBulkDeleteRedhatRepository() { t := suite.T() - dao := GetRepositoryConfigDao(suite.tx) + dao := GetRepositoryConfigDao(suite.tx, suite.mockPulpClient) orgID := config.RedHatOrg repoConfigCount := 5 diff --git a/pkg/dao/snapshots.go b/pkg/dao/snapshots.go index 224cb4571..be62f2f5c 100644 --- a/pkg/dao/snapshots.go +++ b/pkg/dao/snapshots.go @@ -27,8 +27,9 @@ func GetSnapshotDao(db *gorm.DB) SnapshotDao { } func (sDao *snapshotDaoImpl) WithContext(ctx context.Context) SnapshotDao { - sDao.ctx = ctx - return sDao + cpy := *sDao + cpy.ctx = ctx + return &cpy } // Create records a snapshot of a repository diff --git a/pkg/dao/snapshots_test.go b/pkg/dao/snapshots_test.go index 31bc9632f..c7ac6c682 100644 --- a/pkg/dao/snapshots_test.go +++ b/pkg/dao/snapshots_test.go @@ -110,13 +110,13 @@ func (s *SnapshotsSuite) TestCreateAndList() { tx := s.tx mockPulpClient := pulp_client.NewMockPulpClient(t) - sDao := snapshotDaoImpl{db: tx, pulpClient: mockPulpClient} - sDao.WithContext(context.Background()) + sDaoImpl := snapshotDaoImpl{db: tx, pulpClient: mockPulpClient} + sDao := sDaoImpl.WithContext(context.Background()) - mockPulpClient.WithContextMock().On("GetContentPath").Return(testContentPath, nil) + mockPulpClient.WithContextMock().WithDomainMock().On("GetContentPath").Return(testContentPath, nil) - repoDao := repositoryConfigDaoImpl{db: tx, yumRepo: &mockExt.YumRepositoryMock{}, pulpClient: mockPulpClient} - repoDao.WithContext(context.Background()) + repoDaoImpl := repositoryConfigDaoImpl{db: tx, yumRepo: &mockExt.YumRepositoryMock{}, pulpClient: mockPulpClient} + repoDao := repoDaoImpl.WithContext(context.Background()) rConfig := s.createRepository() pageData := api.PaginationData{ @@ -133,7 +133,7 @@ func (s *SnapshotsSuite) TestCreateAndList() { collection, total, err := sDao.List(rConfig.OrgID, rConfig.UUID, pageData, filterData) - repository, _ := repoDao.fetchRepoConfig(rConfig.OrgID, rConfig.UUID, false) + repository, _ := repoDaoImpl.fetchRepoConfig(rConfig.OrgID, rConfig.UUID, false) repositoryList, repoCount, _ := repoDao.List(rConfig.OrgID, api.PaginationData{Limit: -1}, api.FilterData{}) assert.NoError(t, err) @@ -162,10 +162,12 @@ func (s *SnapshotsSuite) TestCreateAndListRedHatRepo() { tx := s.tx mockPulpClient := pulp_client.NewMockPulpClient(t) - sDao := snapshotDaoImpl{db: tx, pulpClient: mockPulpClient} - mockPulpClient.On("GetContentPath").Return(testContentPath, nil) + sDaoImpl := snapshotDaoImpl{db: tx, pulpClient: mockPulpClient} + sDao := sDaoImpl.WithContext(context.Background()) + mockPulpClient.WithContextMock().WithDomainMock().On("GetContentPath").Return(testContentPath, nil) - repoDao := repositoryConfigDaoImpl{db: tx, yumRepo: &mockExt.YumRepositoryMock{}} + repoDaoImpl := repositoryConfigDaoImpl{db: tx, yumRepo: &mockExt.YumRepositoryMock{}, pulpClient: mockPulpClient} + repoDao := repoDaoImpl.WithContext(context.Background()) redhatRepositoryConfig := s.createRedhatRepository() redhatSnap := s.createSnapshot(redhatRepositoryConfig) @@ -182,7 +184,7 @@ func (s *SnapshotsSuite) TestCreateAndListRedHatRepo() { collection, total, err := sDao.List("ShouldNotMatter", redhatRepositoryConfig.UUID, pageData, filterData) - repository, _ := repoDao.fetchRepoConfig("ShouldNotMatter", redhatRepositoryConfig.UUID, true) + repository, _ := repoDaoImpl.fetchRepoConfig("ShouldNotMatter", redhatRepositoryConfig.UUID, true) repositoryList, repoCount, _ := repoDao.List("ShouldNotMatter", api.PaginationData{Limit: -1}, api.FilterData{}) assert.NoError(t, err) @@ -250,8 +252,8 @@ func (s *SnapshotsSuite) TestListPageLimit() { tx := s.tx mockPulpClient := pulp_client.NewMockPulpClient(t) - sDao := snapshotDaoImpl{db: tx, pulpClient: mockPulpClient} - sDao.WithContext(context.Background()) + sDaoImpl := snapshotDaoImpl{db: tx, pulpClient: mockPulpClient} + sDao := sDaoImpl.WithContext(context.Background()) mockPulpClient.WithContextMock().On("GetContentPath").Return(testContentPath, nil) @@ -396,21 +398,20 @@ func (s *SnapshotsSuite) TestGetRepositoryConfigurationFile() { mockPulpClient := pulp_client.NewMockPulpClient(t) sDao := snapshotDaoImpl{db: tx, pulpClient: mockPulpClient} - sDao.WithContext(context.Background()) repoConfig := s.createRepository() snapshot := s.createSnapshot(repoConfig) // Test happy scenario mockPulpClient.WithContextMock().On("GetContentPath").Return(testContentPath, nil).Once() - repoConfigFile, err := sDao.GetRepositoryConfigurationFile(repoConfig.OrgID, snapshot.UUID, repoConfig.UUID) + repoConfigFile, err := sDao.WithContext(context.Background()).GetRepositoryConfigurationFile(repoConfig.OrgID, snapshot.UUID, repoConfig.UUID) assert.NoError(t, err) assert.Contains(t, repoConfigFile, repoConfig.Name) assert.Contains(t, repoConfigFile, testContentPath) // Test error from pulp call mockPulpClient.WithContextMock().On("GetContentPath").Return("", fmt.Errorf("some error")).Once() - repoConfigFile, err = sDao.GetRepositoryConfigurationFile(repoConfig.OrgID, snapshot.UUID, repoConfig.UUID) + repoConfigFile, err = sDao.WithContext(context.Background()).GetRepositoryConfigurationFile(repoConfig.OrgID, snapshot.UUID, repoConfig.UUID) assert.Error(t, err) assert.Empty(t, repoConfigFile) } @@ -421,7 +422,6 @@ func (s *SnapshotsSuite) TestGetRepositoryConfigurationFileNotFound() { mockPulpClient := pulp_client.MockPulpClient{} sDao := snapshotDaoImpl{db: tx, pulpClient: &mockPulpClient} - sDao.WithContext(context.Background()) repoConfig := s.createRepository() snapshot := s.createSnapshot(repoConfig) @@ -431,7 +431,7 @@ func (s *SnapshotsSuite) TestGetRepositoryConfigurationFileNotFound() { } // Test bad repo UUID - repoConfigFile, err := sDao.GetRepositoryConfigurationFile(repoConfig.OrgID, snapshot.UUID, uuid2.NewString()) + repoConfigFile, err := sDao.WithContext(context.Background()).GetRepositoryConfigurationFile(repoConfig.OrgID, snapshot.UUID, uuid2.NewString()) assert.Error(t, err) if err != nil { daoError, ok := err.(*ce.DaoError) @@ -442,7 +442,7 @@ func (s *SnapshotsSuite) TestGetRepositoryConfigurationFileNotFound() { assert.Empty(t, repoConfigFile) // Test bad snapshot UUID - repoConfigFile, err = sDao.GetRepositoryConfigurationFile(repoConfig.OrgID, uuid2.NewString(), repoConfig.UUID) + repoConfigFile, err = sDao.WithContext(context.Background()).GetRepositoryConfigurationFile(repoConfig.OrgID, uuid2.NewString(), repoConfig.UUID) assert.Error(t, err) if err != nil { daoError, ok := err.(*ce.DaoError) @@ -453,7 +453,7 @@ func (s *SnapshotsSuite) TestGetRepositoryConfigurationFileNotFound() { assert.Empty(t, repoConfigFile) // Test bad org ID - repoConfigFile, err = sDao.GetRepositoryConfigurationFile("bad orgID", snapshot.UUID, repoConfig.UUID) + repoConfigFile, err = sDao.WithContext(context.Background()).GetRepositoryConfigurationFile("bad orgID", snapshot.UUID, repoConfig.UUID) assert.Error(t, err) if err != nil { daoError, ok := err.(*ce.DaoError) diff --git a/pkg/handler/snapshots_test.go b/pkg/handler/snapshots_test.go index 181416863..5717eeb55 100644 --- a/pkg/handler/snapshots_test.go +++ b/pkg/handler/snapshots_test.go @@ -3,7 +3,6 @@ package handler import ( "encoding/json" "fmt" - "github.com/stretchr/testify/mock" "io" "net/http" "net/http/httptest" @@ -20,6 +19,7 @@ import ( echo_middleware "github.com/labstack/echo/v4/middleware" "github.com/redhatinsights/platform-go-middlewares/identity" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) diff --git a/pkg/pulp_client/client.go b/pkg/pulp_client/client.go index 8a3346aaa..358decd8c 100644 --- a/pkg/pulp_client/client.go +++ b/pkg/pulp_client/client.go @@ -35,8 +35,9 @@ func (p *pulpDaoImpl) WithContext(ctx context.Context) PulpClient { } func (p *pulpDaoImpl) WithDomain(domainName string) PulpClient { - p.domainName = domainName - return p + cpy := *p + cpy.domainName = domainName + return &cpy } func getPulpImpl(ctx context.Context) pulpDaoImpl {