diff --git a/retriever/githubretriever/retriever.go b/retriever/githubretriever/retriever.go index cbbe67ad84b..daa4bbff6ed 100644 --- a/retriever/githubretriever/retriever.go +++ b/retriever/githubretriever/retriever.go @@ -3,11 +3,12 @@ package githubretriever import ( "context" "fmt" + "github.com/thomaspoignant/go-feature-flag/retriever/shared" + "io" "net/http" + "strings" "time" - httpretriever "github.com/thomaspoignant/go-feature-flag/retriever/httpretriever" - "github.com/thomaspoignant/go-feature-flag/internal" ) @@ -48,18 +49,28 @@ func (r *Retriever) Retrieve(ctx context.Context) ([]byte, error) { r.FilePath, branch) - httpRetriever := httpretriever.Retriever{ - URL: URL, - Method: http.MethodGet, - Header: header, - Timeout: r.Timeout, + resp, err := shared.CallHTTPAPI(ctx, URL, http.MethodGet, "", r.Timeout, header, r.httpClient) + if err != nil { + return nil, err } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode > 399 { + // Collect the headers to add in the error message + ghHeaders := map[string]string{} + for name := range resp.Header { + if strings.HasPrefix(name, "X-") { + ghHeaders[name] = resp.Header.Get(name) + } + } - if r.httpClient != nil { - httpRetriever.SetHTTPClient(r.httpClient) + return nil, fmt.Errorf("request to %s failed with code %d."+ + " GitHub Headers: %v", URL, resp.StatusCode, ghHeaders) } - - return httpRetriever.Retrieve(ctx) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + return body, nil } // SetHTTPClient is here if you want to override the default http.Client we are using. diff --git a/retriever/githubretriever/retriever_test.go b/retriever/githubretriever/retriever_test.go index d9782454674..bb77929e388 100644 --- a/retriever/githubretriever/retriever_test.go +++ b/retriever/githubretriever/retriever_test.go @@ -24,6 +24,7 @@ func Test_github_Retrieve(t *testing.T) { fields fields want []byte wantErr bool + errMsg string }{ { name: "Success", @@ -116,6 +117,16 @@ func Test_github_Retrieve(t *testing.T) { }, wantErr: true, }, + { + name: "Ratelimiting", + fields: fields{ + httpClient: mock.HTTP{RateLimit: true}, + repositorySlug: "thomaspoignant/go-feature-flag", + filePath: "testdata/flag-config.yaml", + }, + errMsg: "request to https://api.github.com/repos/thomaspoignant/go-feature-flag/contents/testdata/flag-config.yaml?ref=main failed with code 429. GitHub Headers: map[X-Content-Type-Options:nosniff X-Frame-Options:deny X-Github-Media-Type:github.v3; format=json X-Github-Request-Id:F82D:37B98C:232EF263:235C93BD:6650BDC6 X-Ratelimit-Limit:60 X-Ratelimit-Remaining:0 X-Ratelimit-Reset:1716568424 X-Ratelimit-Resource:core X-Ratelimit-Used:60 X-Xss-Protection:1; mode=block]", + wantErr: true, + }, { name: "Use GitHub token", fields: fields{ @@ -149,6 +160,9 @@ func Test_github_Retrieve(t *testing.T) { h.SetHTTPClient(&tt.fields.httpClient) got, err := h.Retrieve(tt.fields.context) + if tt.errMsg != "" { + assert.EqualError(t, err, tt.errMsg) + } assert.Equal(t, tt.wantErr, err != nil, "Retrieve() error = %v, wantErr %v", err, tt.wantErr) if !tt.wantErr { assert.Equal(t, http.MethodGet, tt.fields.httpClient.Req.Method) diff --git a/retriever/httpretriever/retriever.go b/retriever/httpretriever/retriever.go index 5ef40951ee9..c5898fd99fa 100644 --- a/retriever/httpretriever/retriever.go +++ b/retriever/httpretriever/retriever.go @@ -2,11 +2,10 @@ package httpretriever import ( "context" - "errors" "fmt" + "github.com/thomaspoignant/go-feature-flag/retriever/shared" "io" "net/http" - "strings" "time" "github.com/thomaspoignant/go-feature-flag/internal" @@ -39,51 +38,14 @@ func (r *Retriever) SetHTTPClient(client internal.HTTPClient) { } func (r *Retriever) Retrieve(ctx context.Context) ([]byte, error) { - timeout := r.Timeout - if timeout <= 0 { - timeout = 10 * time.Second - } - - if r.URL == "" { - return nil, errors.New("URL is a mandatory parameter when using httpretriever.Retriever") - } - - method := r.Method - if method == "" { - method = http.MethodGet - } - - if ctx == nil { - ctx = context.Background() - } - - req, err := http.NewRequestWithContext(ctx, method, r.URL, strings.NewReader(r.Body)) - if err != nil { - return nil, err - } - - // Add header if some are passed - if len(r.Header) > 0 { - req.Header = r.Header - } - - if r.httpClient == nil { - r.httpClient = internal.HTTPClientWithTimeout(timeout) - } - - // API call - resp, err := r.httpClient.Do(req) + resp, err := shared.CallHTTPAPI(ctx, r.URL, r.Method, r.Body, r.Timeout, r.Header, r.httpClient) if err != nil { return nil, err } - defer resp.Body.Close() - - // Error if http code is more that 399 + defer func() { _ = resp.Body.Close() }() if resp.StatusCode > 399 { return nil, fmt.Errorf("request to %s failed with code %d", r.URL, resp.StatusCode) } - - // read content of the URL. body, err := io.ReadAll(resp.Body) if err != nil { return nil, err diff --git a/retriever/shared/http.go b/retriever/shared/http.go new file mode 100644 index 00000000000..463247ea43b --- /dev/null +++ b/retriever/shared/http.go @@ -0,0 +1,51 @@ +package shared + +import ( + "errors" + "github.com/thomaspoignant/go-feature-flag/internal" + "golang.org/x/net/context" + "net/http" + "strings" + "time" +) + +func CallHTTPAPI( + ctx context.Context, + url string, method string, + body string, + timeout time.Duration, + header http.Header, + httpClient internal.HTTPClient) (*http.Response, error) { + if timeout <= 0 { + timeout = 10 * time.Second + } + + if url == "" { + return nil, errors.New("URL is a mandatory parameter when using httpretriever.Retriever") + } + + if method == "" { + method = http.MethodGet + } + + if ctx == nil { + ctx = context.Background() + } + + req, err := http.NewRequestWithContext(ctx, method, url, strings.NewReader(body)) + if err != nil { + return nil, err + } + + // Add header if some are passed + if len(header) > 0 { + req.Header = header + } + + if httpClient == nil { + httpClient = internal.HTTPClientWithTimeout(timeout) + } + + // API call + return httpClient.Do(req) +} diff --git a/testutils/mock/http_mock.go b/testutils/mock/http_mock.go index 8cab582da94..37a670537a6 100644 --- a/testutils/mock/http_mock.go +++ b/testutils/mock/http_mock.go @@ -9,7 +9,8 @@ import ( ) type HTTP struct { - Req http.Request + Req http.Request + RateLimit bool } func (m *HTTP) Do(req *http.Request) (*http.Response, error) { @@ -37,6 +38,27 @@ func (m *HTTP) Do(req *http.Request) (*http.Response, error) { Body: io.NopCloser(bytes.NewReader([]byte(""))), } + rateLimit := &http.Response{ + Status: "Rate Limit", + StatusCode: http.StatusTooManyRequests, + Body: io.NopCloser(bytes.NewReader([]byte(""))), + Header: map[string][]string{ + "X-Content-Type-Options": {"nosniff"}, + "X-Frame-Options": {"deny"}, + "X-Github-Media-Type": {"github.v3; format=json"}, + "X-Github-Request-Id": {"F82D:37B98C:232EF263:235C93BD:6650BDC6"}, + "X-Ratelimit-Limit": {"60"}, + "X-Ratelimit-Remaining": {"0"}, + "X-Ratelimit-Reset": {"1716568424"}, + "X-Ratelimit-Resource": {"core"}, + "X-Ratelimit-Used": {"60"}, + "X-Xss-Protection": {"1; mode=block"}, + }, + } + if m.RateLimit { + return rateLimit, nil + } + if strings.Contains(req.URL.String(), "error") { return nil, errors.New("http error") } else if strings.HasSuffix(req.URL.String(), "httpError") {