Skip to content

Commit

Permalink
Fixes 2268: add context handling (#24)
Browse files Browse the repository at this point in the history
* Fixes 2268: add context handling

* propagate context through method calls

* export mock interface

* fix typo
  • Loading branch information
rverdile authored May 15, 2024
1 parent cba215c commit b928b22
Show file tree
Hide file tree
Showing 8 changed files with 307 additions and 42 deletions.
17 changes: 11 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,22 @@ settings := YummySettings{

repo, err := NewRepository(settings)

ctx := context.Background()

// To get repomd metadata
repomd, statusCode, err := repo.Repomd()
repomd, statusCode, err := repo.Repomd(ctx)

// To get package metadata
packages, statusCode, err := repo.Packages()
packages, statusCode, err := repo.Packages(ctx)

// To get repository signature
signature, statusCode, err := repo.Signature()
signature, statusCode, err := repo.Signature(ctx)

// To get repository package groups
packageGroups, statusCode, err := repo.PackageGroups()
packageGroups, statusCode, err := repo.PackageGroups(ctx)

// To get repository environments
environments, statusCode, err := repo.Environments()
environments, statusCode, err := repo.Environments(ctx)
```

**To parse packages from a yum repository on disk**
Expand All @@ -56,5 +58,8 @@ environments, statusCode, err := repo.Environments()
```go
url := "https://packages.microsoft.com/keys/microsoft.asc"
client := http.Client { Timeout: time.Second*10 }
gpgKey, statusCode, err := FetchGPGKey(url, client)
gpgKey, statusCode, err := FetchGPGKey(context.Background(), url, client)
```

**Mocking**
Yum also exports a mock interface you can regenerate using the [mockery](https://github.com/vektra/mockery) tool.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ require (
github.com/cloudflare/circl v1.3.7 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.0 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/sys v0.15.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
Expand Down
9 changes: 7 additions & 2 deletions pkg/yum/gpg_key.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package yum

import (
"context"
"fmt"
"net/http"
"strings"
Expand All @@ -9,8 +10,12 @@ import (
)

// FetchGPGKey GETs GPG Key from url with request timeout maximum timeout.
func FetchGPGKey(url string, client *http.Client) (*string, int, error) {
resp, err := client.Get(url)
func FetchGPGKey(ctx context.Context, url string, client *http.Client) (*string, int, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, 0, fmt.Errorf("error creating request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, 0, err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/yum/gpg_key_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package yum

import (
"context"
_ "embed"
"net/http"
"testing"
Expand All @@ -17,7 +18,7 @@ func TestFetchGPGKey(t *testing.T) {

c := s.Client()

gpg, code, err := FetchGPGKey(s.URL+"/gpgkey.pub", c)
gpg, code, err := FetchGPGKey(context.Background(), s.URL+"/gpgkey.pub", c)
assert.NotEmpty(t, gpg)
assert.Equal(t, 200, code)
assert.Nil(t, err)
Expand Down
55 changes: 34 additions & 21 deletions pkg/yum/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"compress/gzip"
"context"
"encoding/xml"
"fmt"
"io"
Expand Down Expand Up @@ -91,14 +92,15 @@ type Comps struct {
Environments []Environment
}

//go:generate mockery --name YumRepository --filename yum_repository_mock.go --inpackage
type YumRepository interface {
Configure(settings YummySettings)
Packages() (packages []Package, statusCode int, err error)
Repomd() (repomd *Repomd, statusCode int, err error)
Signature() (repomdSignature *string, statusCode int, err error)
Comps() (comps *Comps, statusCode int, err error)
PackageGroups() (packageGroups []PackageGroup, statusCode int, err error)
Environments() (environments []Environment, statusCode int, err error)
Packages(ctx context.Context) (packages []Package, statusCode int, err error)
Repomd(ctx context.Context) (repomd *Repomd, statusCode int, err error)
Signature(ctx context.Context) (repomdSignature *string, statusCode int, err error)
Comps(ctx context.Context) (comps *Comps, statusCode int, err error)
PackageGroups(ctx context.Context) (packageGroups []PackageGroup, statusCode int, err error)
Environments(ctx context.Context) (environments []Environment, statusCode int, err error)
Clear()
}

Expand Down Expand Up @@ -146,7 +148,7 @@ func (r *Repository) Clear() {

// Repomd populates r.Repomd with repository's repomd.xml metadata. Returns Repomd, response code, and error.
// If the repomd was successfully fetched previously, will return cached repomd.
func (r *Repository) Repomd() (*Repomd, int, error) {
func (r *Repository) Repomd(ctx context.Context) (*Repomd, int, error) {
var result Repomd
var err error
var resp *http.Response
Expand All @@ -158,7 +160,13 @@ func (r *Repository) Repomd() (*Repomd, int, error) {
if repomdURL, err = r.getRepomdURL(); err != nil {
return nil, 0, fmt.Errorf("Error parsing Repomd URL: %w", err)
}
if resp, err = r.settings.Client.Get(repomdURL); err != nil {

req, err := http.NewRequestWithContext(ctx, http.MethodGet, repomdURL, nil)
if err != nil {
return nil, 0, fmt.Errorf("error creating request: %w", err)
}

if resp, err = r.settings.Client.Do(req); err != nil {
return nil, erroredStatusCode(resp), fmt.Errorf("GET error for file %v: %w", repomdURL, err)
}
defer resp.Body.Close()
Expand All @@ -182,7 +190,7 @@ func erroredStatusCode(response *http.Response) int {
}
}

func (r *Repository) Comps() (*Comps, int, error) {
func (r *Repository) Comps(ctx context.Context) (*Comps, int, error) {
var err error
var compsURL *string
var resp *http.Response
Expand All @@ -192,7 +200,7 @@ func (r *Repository) Comps() (*Comps, int, error) {
return r.comps, 200, nil
}

if _, _, err = r.Repomd(); err != nil {
if _, _, err = r.Repomd(ctx); err != nil {
return nil, 0, fmt.Errorf("error parsing repomd.xml: %w", err)
}

Expand All @@ -201,7 +209,12 @@ func (r *Repository) Comps() (*Comps, int, error) {
}

if compsURL != nil {
if resp, err = r.settings.Client.Get(*compsURL); err != nil {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, *compsURL, nil)
if err != nil {
return nil, 0, fmt.Errorf("error creating request: %w", err)
}

if resp, err = r.settings.Client.Do(req); err != nil {
return nil, erroredStatusCode(resp), fmt.Errorf("GET error for file %v: %w", compsURL, err)
}

Expand All @@ -221,7 +234,7 @@ func (r *Repository) Comps() (*Comps, int, error) {

// Packages populates r.Packages with metadata of each package in repository. Returns response code and error.
// If the packages were successfully fetched previously, will return cached packages.
func (r *Repository) Packages() ([]Package, int, error) {
func (r *Repository) Packages(ctx context.Context) ([]Package, int, error) {
var err error
var primaryURL string
var resp *http.Response
Expand All @@ -231,11 +244,11 @@ func (r *Repository) Packages() ([]Package, int, error) {
return r.packages, 0, nil
}

if _, _, err = r.Repomd(); err != nil {
if _, _, err = r.Repomd(ctx); err != nil {
return nil, 0, fmt.Errorf("error parsing repomd.xml: %w", err)
}

if primaryURL, err = r.getPrimaryURL(); err != nil {
if primaryURL, err = r.getPrimaryURL(ctx); err != nil {
return nil, 0, fmt.Errorf("Error getting primary URL: %w", err)
}

Expand All @@ -257,7 +270,7 @@ func (r *Repository) Packages() ([]Package, int, error) {
}

// PackageGroups populates r.PackageGroups with the package groups of a repository. Returns response code and error.
func (r *Repository) PackageGroups() ([]PackageGroup, int, error) {
func (r *Repository) PackageGroups(ctx context.Context) ([]PackageGroup, int, error) {
var err error
var status int
var comps *Comps
Expand All @@ -266,7 +279,7 @@ func (r *Repository) PackageGroups() ([]PackageGroup, int, error) {
return r.comps.PackageGroups, 200, nil
}

if comps, status, err = r.Comps(); err != nil {
if comps, status, err = r.Comps(ctx); err != nil {
return nil, 0, fmt.Errorf("error getting comps: %w", err)
}

Expand All @@ -279,7 +292,7 @@ func (r *Repository) PackageGroups() ([]PackageGroup, int, error) {
}

// Environments populates r.Environments with the environments of a repository. Returns response code and error.
func (r *Repository) Environments() ([]Environment, int, error) {
func (r *Repository) Environments(ctx context.Context) ([]Environment, int, error) {
var err error
var status int
var comps *Comps
Expand All @@ -288,7 +301,7 @@ func (r *Repository) Environments() ([]Environment, int, error) {
return r.comps.Environments, 200, nil
}

if comps, status, err = r.Comps(); err != nil {
if comps, status, err = r.Comps(ctx); err != nil {
return nil, 0, fmt.Errorf("error getting comps: %w", err)
}

Expand All @@ -302,7 +315,7 @@ func (r *Repository) Environments() ([]Environment, int, error) {

// Signature fetches the yum metadata signature and returns any error and HTTP code encountered.
// If the signature was successfully fetched previously, will return cached signature.
func (r *Repository) Signature() (*string, int, error) {
func (r *Repository) Signature(ctx context.Context) (*string, int, error) {
var sig *string

if r.repomdSignature != nil {
Expand Down Expand Up @@ -371,10 +384,10 @@ func (r *Repository) getSignatureURL() (string, error) {
}
}

func (r *Repository) getPrimaryURL() (string, error) {
func (r *Repository) getPrimaryURL(ctx context.Context) (string, error) {
var primaryLocation string

if _, _, err := r.Repomd(); err != nil {
if _, _, err := r.Repomd(ctx); err != nil {
return "", fmt.Errorf("error fetching Repomd: %w", err)
}

Expand Down
27 changes: 15 additions & 12 deletions pkg/yum/repository_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package yum

import (
"context"
_ "embed"
"encoding/xml"
"fmt"
Expand Down Expand Up @@ -62,10 +63,12 @@ func TestClear(t *testing.T) {
}
r, _ := NewRepository(settings)

_, _, _ = r.Repomd()
_, _, _ = r.Packages()
_, _, _ = r.Signature()
_, _, _ = r.Comps()
ctx := context.Background()

_, _, _ = r.Repomd(ctx)
_, _, _ = r.Packages(ctx)
_, _, _ = r.Signature(ctx)
_, _, _ = r.Comps(ctx)
assert.NotNil(t, r.repomd)
assert.NotNil(t, r.packages)
assert.NotNil(t, r.repomdSignature)
Expand All @@ -89,7 +92,7 @@ func TestGetPrimaryURL(t *testing.T) {
assert.Nil(t, err)
r.repomd = &repomd

primary, err := r.getPrimaryURL()
primary, err := r.getPrimaryURL(context.Background())
assert.Nil(t, err)
assert.Equal(t, "http://foo.example.com/repo/repodata/primary.xml.gz", primary)
}
Expand Down Expand Up @@ -137,7 +140,7 @@ func TestFetchRepomd(t *testing.T) {
RepomdString: &repomdStringMock,
}

repomd, code, err := r.Repomd()
repomd, code, err := r.Repomd(context.Background())
assert.Equal(t, expected, *repomd)
assert.Equal(t, *repomd, *r.repomd)
assert.Equal(t, 200, code)
Expand All @@ -155,7 +158,7 @@ func TestFetchComps(t *testing.T) {
}
r, _ := NewRepository(settings)

comps, code, err := r.Comps()
comps, code, err := r.Comps(context.Background())
assert.Equal(t, *comps, *r.comps)
assert.Equal(t, 200, code)
assert.Nil(t, err)
Expand Down Expand Up @@ -208,7 +211,7 @@ func TestFetchPackages(t *testing.T) {
}
r, _ := NewRepository(settings)

packages, code, err := r.Packages()
packages, code, err := r.Packages(context.Background())
assert.Equal(t, 2, len(packages))
assert.Equal(t, packages, r.packages)
assert.Equal(t, 200, code)
Expand All @@ -226,7 +229,7 @@ func TestFetchPackageGroups(t *testing.T) {
}
r, _ := NewRepository(settings)

packageGroups, code, err := r.PackageGroups()
packageGroups, code, err := r.PackageGroups(context.Background())
assert.Equal(t, 1, len(packageGroups))
assert.Equal(t, packageGroups, r.comps.PackageGroups)
assert.Equal(t, 200, code)
Expand All @@ -244,7 +247,7 @@ func TestFetchEnvironments(t *testing.T) {
}
r, _ := NewRepository(settings)

environments, code, err := r.Environments()
environments, code, err := r.Environments(context.Background())
assert.Equal(t, 1, len(environments))
assert.Equal(t, environments, r.comps.Environments)
assert.Equal(t, 200, code)
Expand All @@ -262,7 +265,7 @@ func TestBadUrl(t *testing.T) {
URL: &badUrl,
}
r, _ := NewRepository(settings)
_, code, err := r.Repomd()
_, code, err := r.Repomd(context.Background())
assert.Error(t, err)
assert.Equal(t, code, 0)
}
Expand All @@ -278,7 +281,7 @@ func TestFetchRepomdSignature(t *testing.T) {
}
r, _ := NewRepository(settings)

signature, code, err := r.Signature()
signature, code, err := r.Signature(context.Background())
assert.NotEmpty(t, signature)
assert.Equal(t, signature, r.repomdSignature)
assert.Equal(t, 200, code)
Expand Down
Loading

0 comments on commit b928b22

Please sign in to comment.