diff --git a/README.md b/README.md index e8a7561..fcfa40c 100644 --- a/README.md +++ b/README.md @@ -25,20 +25,22 @@ settings := YummySettings{ repo, err := NewRepository(settings) +ctx := context.Background() + // To get repomd metadata -repomd, statusCode, err := repo.Repomd() +repomd, statusCode, err := repo.Repomd(ctx) // To get package metadata -packages, statusCode, err := repo.Packages() +packages, statusCode, err := repo.Packages(ctx) // To get repository signature -signature, statusCode, err := repo.Signature() +signature, statusCode, err := repo.Signature(ctx) // To get repository package groups -packageGroups, statusCode, err := repo.PackageGroups() +packageGroups, statusCode, err := repo.PackageGroups(ctx) // To get repository environments -environments, statusCode, err := repo.Environments() +environments, statusCode, err := repo.Environments(ctx) ``` **To parse packages from a yum repository on disk** @@ -56,5 +58,8 @@ environments, statusCode, err := repo.Environments() ```go url := "https://packages.microsoft.com/keys/microsoft.asc" client := http.Client { Timeout: time.Second*10 } -gpgKey, statusCode, err := FetchGPGKey(url, client) +gpgKey, statusCode, err := FetchGPGKey(context.Background(), url, client) ``` + +**Mocking** +Yum also exports a mock interface you can regenerate using the [mockery](https://github.com/vektra/mockery) tool. \ No newline at end of file diff --git a/go.mod b/go.mod index ff064a6..4024b6d 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/cloudflare/circl v1.3.7 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect golang.org/x/crypto v0.17.0 // indirect golang.org/x/sys v0.15.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 0d040db..7cc3683 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= diff --git a/pkg/yum/gpg_key.go b/pkg/yum/gpg_key.go index 0b9f743..b580f05 100644 --- a/pkg/yum/gpg_key.go +++ b/pkg/yum/gpg_key.go @@ -1,6 +1,7 @@ package yum import ( + "context" "fmt" "net/http" "strings" @@ -9,8 +10,12 @@ import ( ) // FetchGPGKey GETs GPG Key from url with request timeout maximum timeout. -func FetchGPGKey(url string, client *http.Client) (*string, int, error) { - resp, err := client.Get(url) +func FetchGPGKey(ctx context.Context, url string, client *http.Client) (*string, int, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, 0, fmt.Errorf("error creating request: %w", err) + } + resp, err := client.Do(req) if err != nil { return nil, 0, err } diff --git a/pkg/yum/gpg_key_test.go b/pkg/yum/gpg_key_test.go index 6cb7e53..37c8a48 100644 --- a/pkg/yum/gpg_key_test.go +++ b/pkg/yum/gpg_key_test.go @@ -1,6 +1,7 @@ package yum import ( + "context" _ "embed" "net/http" "testing" @@ -17,7 +18,7 @@ func TestFetchGPGKey(t *testing.T) { c := s.Client() - gpg, code, err := FetchGPGKey(s.URL+"/gpgkey.pub", c) + gpg, code, err := FetchGPGKey(context.Background(), s.URL+"/gpgkey.pub", c) assert.NotEmpty(t, gpg) assert.Equal(t, 200, code) assert.Nil(t, err) diff --git a/pkg/yum/repository.go b/pkg/yum/repository.go index 2e92e91..22f6748 100644 --- a/pkg/yum/repository.go +++ b/pkg/yum/repository.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "compress/gzip" + "context" "encoding/xml" "fmt" "io" @@ -91,14 +92,15 @@ type Comps struct { Environments []Environment } +//go:generate mockery --name YumRepository --filename yum_repository_mock.go --inpackage type YumRepository interface { Configure(settings YummySettings) - Packages() (packages []Package, statusCode int, err error) - Repomd() (repomd *Repomd, statusCode int, err error) - Signature() (repomdSignature *string, statusCode int, err error) - Comps() (comps *Comps, statusCode int, err error) - PackageGroups() (packageGroups []PackageGroup, statusCode int, err error) - Environments() (environments []Environment, statusCode int, err error) + Packages(ctx context.Context) (packages []Package, statusCode int, err error) + Repomd(ctx context.Context) (repomd *Repomd, statusCode int, err error) + Signature(ctx context.Context) (repomdSignature *string, statusCode int, err error) + Comps(ctx context.Context) (comps *Comps, statusCode int, err error) + PackageGroups(ctx context.Context) (packageGroups []PackageGroup, statusCode int, err error) + Environments(ctx context.Context) (environments []Environment, statusCode int, err error) Clear() } @@ -146,7 +148,7 @@ func (r *Repository) Clear() { // Repomd populates r.Repomd with repository's repomd.xml metadata. Returns Repomd, response code, and error. // If the repomd was successfully fetched previously, will return cached repomd. -func (r *Repository) Repomd() (*Repomd, int, error) { +func (r *Repository) Repomd(ctx context.Context) (*Repomd, int, error) { var result Repomd var err error var resp *http.Response @@ -158,7 +160,13 @@ func (r *Repository) Repomd() (*Repomd, int, error) { if repomdURL, err = r.getRepomdURL(); err != nil { return nil, 0, fmt.Errorf("Error parsing Repomd URL: %w", err) } - if resp, err = r.settings.Client.Get(repomdURL); err != nil { + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, repomdURL, nil) + if err != nil { + return nil, 0, fmt.Errorf("error creating request: %w", err) + } + + if resp, err = r.settings.Client.Do(req); err != nil { return nil, erroredStatusCode(resp), fmt.Errorf("GET error for file %v: %w", repomdURL, err) } defer resp.Body.Close() @@ -182,7 +190,7 @@ func erroredStatusCode(response *http.Response) int { } } -func (r *Repository) Comps() (*Comps, int, error) { +func (r *Repository) Comps(ctx context.Context) (*Comps, int, error) { var err error var compsURL *string var resp *http.Response @@ -192,7 +200,7 @@ func (r *Repository) Comps() (*Comps, int, error) { return r.comps, 200, nil } - if _, _, err = r.Repomd(); err != nil { + if _, _, err = r.Repomd(ctx); err != nil { return nil, 0, fmt.Errorf("error parsing repomd.xml: %w", err) } @@ -201,7 +209,12 @@ func (r *Repository) Comps() (*Comps, int, error) { } if compsURL != nil { - if resp, err = r.settings.Client.Get(*compsURL); err != nil { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, *compsURL, nil) + if err != nil { + return nil, 0, fmt.Errorf("error creating request: %w", err) + } + + if resp, err = r.settings.Client.Do(req); err != nil { return nil, erroredStatusCode(resp), fmt.Errorf("GET error for file %v: %w", compsURL, err) } @@ -221,7 +234,7 @@ func (r *Repository) Comps() (*Comps, int, error) { // Packages populates r.Packages with metadata of each package in repository. Returns response code and error. // If the packages were successfully fetched previously, will return cached packages. -func (r *Repository) Packages() ([]Package, int, error) { +func (r *Repository) Packages(ctx context.Context) ([]Package, int, error) { var err error var primaryURL string var resp *http.Response @@ -231,11 +244,11 @@ func (r *Repository) Packages() ([]Package, int, error) { return r.packages, 0, nil } - if _, _, err = r.Repomd(); err != nil { + if _, _, err = r.Repomd(ctx); err != nil { return nil, 0, fmt.Errorf("error parsing repomd.xml: %w", err) } - if primaryURL, err = r.getPrimaryURL(); err != nil { + if primaryURL, err = r.getPrimaryURL(ctx); err != nil { return nil, 0, fmt.Errorf("Error getting primary URL: %w", err) } @@ -257,7 +270,7 @@ func (r *Repository) Packages() ([]Package, int, error) { } // PackageGroups populates r.PackageGroups with the package groups of a repository. Returns response code and error. -func (r *Repository) PackageGroups() ([]PackageGroup, int, error) { +func (r *Repository) PackageGroups(ctx context.Context) ([]PackageGroup, int, error) { var err error var status int var comps *Comps @@ -266,7 +279,7 @@ func (r *Repository) PackageGroups() ([]PackageGroup, int, error) { return r.comps.PackageGroups, 200, nil } - if comps, status, err = r.Comps(); err != nil { + if comps, status, err = r.Comps(ctx); err != nil { return nil, 0, fmt.Errorf("error getting comps: %w", err) } @@ -279,7 +292,7 @@ func (r *Repository) PackageGroups() ([]PackageGroup, int, error) { } // Environments populates r.Environments with the environments of a repository. Returns response code and error. -func (r *Repository) Environments() ([]Environment, int, error) { +func (r *Repository) Environments(ctx context.Context) ([]Environment, int, error) { var err error var status int var comps *Comps @@ -288,7 +301,7 @@ func (r *Repository) Environments() ([]Environment, int, error) { return r.comps.Environments, 200, nil } - if comps, status, err = r.Comps(); err != nil { + if comps, status, err = r.Comps(ctx); err != nil { return nil, 0, fmt.Errorf("error getting comps: %w", err) } @@ -302,7 +315,7 @@ func (r *Repository) Environments() ([]Environment, int, error) { // Signature fetches the yum metadata signature and returns any error and HTTP code encountered. // If the signature was successfully fetched previously, will return cached signature. -func (r *Repository) Signature() (*string, int, error) { +func (r *Repository) Signature(ctx context.Context) (*string, int, error) { var sig *string if r.repomdSignature != nil { @@ -371,10 +384,10 @@ func (r *Repository) getSignatureURL() (string, error) { } } -func (r *Repository) getPrimaryURL() (string, error) { +func (r *Repository) getPrimaryURL(ctx context.Context) (string, error) { var primaryLocation string - if _, _, err := r.Repomd(); err != nil { + if _, _, err := r.Repomd(ctx); err != nil { return "", fmt.Errorf("error fetching Repomd: %w", err) } diff --git a/pkg/yum/repository_test.go b/pkg/yum/repository_test.go index 75d9159..ce31f5a 100644 --- a/pkg/yum/repository_test.go +++ b/pkg/yum/repository_test.go @@ -1,6 +1,7 @@ package yum import ( + "context" _ "embed" "encoding/xml" "fmt" @@ -62,10 +63,12 @@ func TestClear(t *testing.T) { } r, _ := NewRepository(settings) - _, _, _ = r.Repomd() - _, _, _ = r.Packages() - _, _, _ = r.Signature() - _, _, _ = r.Comps() + ctx := context.Background() + + _, _, _ = r.Repomd(ctx) + _, _, _ = r.Packages(ctx) + _, _, _ = r.Signature(ctx) + _, _, _ = r.Comps(ctx) assert.NotNil(t, r.repomd) assert.NotNil(t, r.packages) assert.NotNil(t, r.repomdSignature) @@ -89,7 +92,7 @@ func TestGetPrimaryURL(t *testing.T) { assert.Nil(t, err) r.repomd = &repomd - primary, err := r.getPrimaryURL() + primary, err := r.getPrimaryURL(context.Background()) assert.Nil(t, err) assert.Equal(t, "http://foo.example.com/repo/repodata/primary.xml.gz", primary) } @@ -137,7 +140,7 @@ func TestFetchRepomd(t *testing.T) { RepomdString: &repomdStringMock, } - repomd, code, err := r.Repomd() + repomd, code, err := r.Repomd(context.Background()) assert.Equal(t, expected, *repomd) assert.Equal(t, *repomd, *r.repomd) assert.Equal(t, 200, code) @@ -155,7 +158,7 @@ func TestFetchComps(t *testing.T) { } r, _ := NewRepository(settings) - comps, code, err := r.Comps() + comps, code, err := r.Comps(context.Background()) assert.Equal(t, *comps, *r.comps) assert.Equal(t, 200, code) assert.Nil(t, err) @@ -208,7 +211,7 @@ func TestFetchPackages(t *testing.T) { } r, _ := NewRepository(settings) - packages, code, err := r.Packages() + packages, code, err := r.Packages(context.Background()) assert.Equal(t, 2, len(packages)) assert.Equal(t, packages, r.packages) assert.Equal(t, 200, code) @@ -226,7 +229,7 @@ func TestFetchPackageGroups(t *testing.T) { } r, _ := NewRepository(settings) - packageGroups, code, err := r.PackageGroups() + packageGroups, code, err := r.PackageGroups(context.Background()) assert.Equal(t, 1, len(packageGroups)) assert.Equal(t, packageGroups, r.comps.PackageGroups) assert.Equal(t, 200, code) @@ -244,7 +247,7 @@ func TestFetchEnvironments(t *testing.T) { } r, _ := NewRepository(settings) - environments, code, err := r.Environments() + environments, code, err := r.Environments(context.Background()) assert.Equal(t, 1, len(environments)) assert.Equal(t, environments, r.comps.Environments) assert.Equal(t, 200, code) @@ -262,7 +265,7 @@ func TestBadUrl(t *testing.T) { URL: &badUrl, } r, _ := NewRepository(settings) - _, code, err := r.Repomd() + _, code, err := r.Repomd(context.Background()) assert.Error(t, err) assert.Equal(t, code, 0) } @@ -278,7 +281,7 @@ func TestFetchRepomdSignature(t *testing.T) { } r, _ := NewRepository(settings) - signature, code, err := r.Signature() + signature, code, err := r.Signature(context.Background()) assert.NotEmpty(t, signature) assert.Equal(t, signature, r.repomdSignature) assert.Equal(t, 200, code) diff --git a/pkg/yum/yum_repository_mock.go b/pkg/yum/yum_repository_mock.go new file mode 100644 index 0000000..9a35a79 --- /dev/null +++ b/pkg/yum/yum_repository_mock.go @@ -0,0 +1,236 @@ +// Code generated by mockery v2.36.1. DO NOT EDIT. + +package yum + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// MockYumRepository is an autogenerated mock type for the YumRepository type +type MockYumRepository struct { + mock.Mock +} + +// Clear provides a mock function with given fields: +func (_m *MockYumRepository) Clear() { + _m.Called() +} + +// Comps provides a mock function with given fields: ctx +func (_m *MockYumRepository) Comps(ctx context.Context) (*Comps, int, error) { + ret := _m.Called(ctx) + + var r0 *Comps + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context) (*Comps, int, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) *Comps); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*Comps) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) int); ok { + r1 = rf(ctx) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context) error); ok { + r2 = rf(ctx) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// Configure provides a mock function with given fields: settings +func (_m *MockYumRepository) Configure(settings YummySettings) { + _m.Called(settings) +} + +// Environments provides a mock function with given fields: ctx +func (_m *MockYumRepository) Environments(ctx context.Context) ([]Environment, int, error) { + ret := _m.Called(ctx) + + var r0 []Environment + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context) ([]Environment, int, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []Environment); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]Environment) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) int); ok { + r1 = rf(ctx) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context) error); ok { + r2 = rf(ctx) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// PackageGroups provides a mock function with given fields: ctx +func (_m *MockYumRepository) PackageGroups(ctx context.Context) ([]PackageGroup, int, error) { + ret := _m.Called(ctx) + + var r0 []PackageGroup + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context) ([]PackageGroup, int, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []PackageGroup); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]PackageGroup) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) int); ok { + r1 = rf(ctx) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context) error); ok { + r2 = rf(ctx) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// Packages provides a mock function with given fields: ctx +func (_m *MockYumRepository) Packages(ctx context.Context) ([]Package, int, error) { + ret := _m.Called(ctx) + + var r0 []Package + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context) ([]Package, int, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []Package); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]Package) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) int); ok { + r1 = rf(ctx) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context) error); ok { + r2 = rf(ctx) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// Repomd provides a mock function with given fields: ctx +func (_m *MockYumRepository) Repomd(ctx context.Context) (*Repomd, int, error) { + ret := _m.Called(ctx) + + var r0 *Repomd + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context) (*Repomd, int, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) *Repomd); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*Repomd) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) int); ok { + r1 = rf(ctx) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context) error); ok { + r2 = rf(ctx) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// Signature provides a mock function with given fields: ctx +func (_m *MockYumRepository) Signature(ctx context.Context) (*string, int, error) { + ret := _m.Called(ctx) + + var r0 *string + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context) (*string, int, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) *string); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) int); ok { + r1 = rf(ctx) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context) error); ok { + r2 = rf(ctx) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// NewMockYumRepository creates a new instance of MockYumRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockYumRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *MockYumRepository { + mock := &MockYumRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}