diff --git a/pkg/token/httpclient/httpclient_test.go b/pkg/token/httpclient/httpclient_test.go index 9cc4626..26c556f 100644 --- a/pkg/token/httpclient/httpclient_test.go +++ b/pkg/token/httpclient/httpclient_test.go @@ -122,12 +122,6 @@ func setTestCase() ([]testCaseIssuer, []testCaseIdentity) { statusCode: http.StatusNotFound, err: errors.New("Unexpected status code 404"), }, - { - name: "no context", - url: "https://hpe-greenlake-tenant.okta.com/oauth2/default", - ctx: nil, - err: errors.New("network error in post to get token"), - }, { name: "status code 400", url: "https://hpe-greenlake-tenant.okta.com/oauth2/default", @@ -166,12 +160,6 @@ func setTestCase() ([]testCaseIssuer, []testCaseIdentity) { statusCode: http.StatusNotFound, err: errors.New("Unexpected status code 404"), }, - { - name: "no context", - url: "https://client.greenlake.hpe.com/api/iam/identity", - ctx: nil, - err: errors.New("net/http: nil Context"), - }, { name: "status code 400", url: "https://client.greenlake.hpe.com/api/iam/identity", diff --git a/pkg/token/identitytoken/identitytoken.go b/pkg/token/identitytoken/identitytoken.go index bf79126..9511ff2 100644 --- a/pkg/token/identitytoken/identitytoken.go +++ b/pkg/token/identitytoken/identitytoken.go @@ -1,4 +1,4 @@ -// (C) Copyright 2021 Hewlett Packard Enterprise Development LP +// (C) Copyright 2021-2024 Hewlett Packard Enterprise Development LP package identitytoken @@ -57,14 +57,15 @@ func GenerateToken( return "", err } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(b))) - if err != nil { - return "", err - } - req.Header.Set("Content-Type", "application/json") + resp, err := tokenutil.DoRetries(ctx, func(reqCtx context.Context) (*http.Request, *http.Response, error) { + req, errReq := http.NewRequestWithContext(reqCtx, http.MethodPost, url, strings.NewReader(string(b))) + if errReq != nil { + return nil, nil, errReq + } + req.Header.Set("Content-Type", "application/json") + resp, errResp := httpClient.Do(req) - resp, err := tokenutil.DoRetries(func() (*http.Response, error) { - return httpClient.Do(req) + return req, resp, errResp }, retryLimit) if err != nil { return "", err diff --git a/pkg/token/issuertoken/issuertoken.go b/pkg/token/issuertoken/issuertoken.go index 7271860..54a77ad 100644 --- a/pkg/token/issuertoken/issuertoken.go +++ b/pkg/token/issuertoken/issuertoken.go @@ -41,19 +41,22 @@ func GenerateToken( } // Execute the request, with retries - resp, err := tokenutil.DoRetries(func() (*http.Response, error) { + resp, err := tokenutil.DoRetries(ctx, func(reqCtx context.Context) (*http.Request, *http.Response, error) { // Create the request - req, errReq := createRequest(ctx, params, clientURL) + req, errReq := createRequest(reqCtx, params, clientURL) if errReq != nil { - return nil, errReq + return nil, nil, errReq } // Close the request after use, i.e. don't reuse the TCP connection req.Close = true - return httpClient.Do(req) + // Execute the request + resp, errResp := httpClient.Do(req) + + return req, resp, errResp }, retryLimit) if err != nil { - return "", fmt.Errorf("network error in post to get token") + return "", err } defer resp.Body.Close() diff --git a/pkg/token/token-util/token-util.go b/pkg/token/token-util/token-util.go index 05d1fb6..1855721 100644 --- a/pkg/token/token-util/token-util.go +++ b/pkg/token/token-util/token-util.go @@ -3,6 +3,7 @@ package tokenutil import ( + "context" "encoding/base64" "encoding/json" "fmt" @@ -75,26 +76,60 @@ func DecodeAccessToken(rawToken string) (Token, error) { return token, nil } -func DoRetries(call func() (*http.Response, error), retries int) (*http.Response, error) { +func DoRetries(ctx context.Context, call func(ctx context.Context) (*http.Request, *http.Response, error), retries int) (*http.Response, error) { + var req *http.Request var resp *http.Response var err error for { - resp, err = call() + // If retries are exhausted, return an error + if retries == 0 { + return resp, errors.MakeErrInternalError(errors.ErrorResponse{ + ErrorCode: "ErrGenerateTokenRetryLimitExceeded", + Message: "Retry limit exceeded"}) + } + + // Create a new context with a timeout + ctxWithTimeout, cancel := createContextWithTimeout(ctx) + defer cancel() + + // Execute the request + req, resp, err = call(ctxWithTimeout) + + // If the error is due to a context timeout, retry the request + if req != nil && req.Context().Err() == context.DeadlineExceeded { + retries = sleepAndDecrementRetries(retries) + + continue + } + + // For all other errors, return the error if err != nil { - return nil, err + return resp, err } - if !isStatusRetryable(resp.StatusCode) || retries == 0 { - break + // If the status code is not retryable, return the response + if !isStatusRetryable(resp.StatusCode) { + return resp, nil } - log.Printf("Retrying request, retries left: %v", retries) - time.Sleep(3 * time.Second) - retries-- + retries = sleepAndDecrementRetries(retries) } +} + +func createContextWithTimeout(ctx context.Context) (context.Context, context.CancelFunc) { + if ctx == nil { + return context.WithTimeout(context.Background(), 3*time.Second) + } + + return context.WithTimeout(ctx, 3*time.Second) +} + +func sleepAndDecrementRetries(retries int) int { + log.Printf("Retrying request, retries left: %v", retries) + time.Sleep(5 * time.Second) - return resp, nil + return retries - 1 } func ManageHTTPErrorCodes(resp *http.Response, clientID string) error { diff --git a/pkg/token/token-util/token-util_test.go b/pkg/token/token-util/token-util_test.go index 5a56f64..e6742e3 100644 --- a/pkg/token/token-util/token-util_test.go +++ b/pkg/token/token-util/token-util_test.go @@ -3,14 +3,22 @@ package tokenutil import ( + "context" "errors" "net/http" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + hpeglErrors "github.com/hewlettpackard/hpegl-provider-lib/pkg/token/errors" ) +var errLimitExceeded = hpeglErrors.MakeErrInternalError(hpeglErrors.ErrorResponse{ + ErrorCode: "ErrGenerateTokenRetryLimitExceeded", + Message: "Retry limit exceeded"}) + //nolint:scopelint func TestDecodeAccessToken(t *testing.T) { type args struct { @@ -135,50 +143,81 @@ func TestDoRetries(t *testing.T) { totalRetries := 0 testcases := []struct { name string - call func() (*http.Response, error) + ctx context.Context + call func(ctx context.Context) (*http.Request, *http.Response, error) responseStatus int err error }{ { name: "status 500", - call: func() (*http.Response, error) { + ctx: context.Background(), + call: func(ctx context.Context) (*http.Request, *http.Response, error) { totalRetries++ - return &http.Response{StatusCode: http.StatusInternalServerError}, nil + return nil, &http.Response{StatusCode: http.StatusInternalServerError}, nil }, responseStatus: http.StatusInternalServerError, }, { name: "status 429", - call: func() (*http.Response, error) { + ctx: context.Background(), + call: func(ctx context.Context) (*http.Request, *http.Response, error) { totalRetries++ - return &http.Response{StatusCode: http.StatusTooManyRequests}, nil + return nil, &http.Response{StatusCode: http.StatusTooManyRequests}, nil }, responseStatus: http.StatusTooManyRequests, }, { name: "status 502", - call: func() (*http.Response, error) { + ctx: context.Background(), + call: func(ctx context.Context) (*http.Request, *http.Response, error) { totalRetries++ - return &http.Response{StatusCode: http.StatusBadGateway}, nil + return nil, &http.Response{StatusCode: http.StatusBadGateway}, nil }, responseStatus: http.StatusBadGateway, }, { name: "status 403 no retry", - call: func() (*http.Response, error) { + ctx: context.Background(), + call: func(ctx context.Context) (*http.Request, *http.Response, error) { + return nil, &http.Response{StatusCode: http.StatusForbidden}, nil + }, + responseStatus: http.StatusForbidden, + }, + { + name: "Deadline exceeded", + ctx: context.Background(), + call: func(ctx context.Context) (*http.Request, *http.Response, error) { totalRetries++ + req := &http.Request{} + req = req.WithContext(ctx) + select { + case <-ctx.Done(): + return req, nil, context.DeadlineExceeded + case <-time.After(5 * time.Second): // this is greater than the deadline + return req, nil, nil + } + }, + err: errLimitExceeded, + }, + { + name: "Context cancelled", + ctx: context.Background(), + call: func(ctx context.Context) (*http.Request, *http.Response, error) { + req := &http.Request{} + req = req.WithContext(ctx) - return &http.Response{StatusCode: http.StatusForbidden}, nil + return req, nil, context.Canceled }, - responseStatus: http.StatusForbidden, + err: context.Canceled, }, { name: "no url", - call: func() (*http.Response, error) { - return nil, errors.New("http: nil Request.URL") + ctx: context.Background(), + call: func(ctx context.Context) (*http.Request, *http.Response, error) { + return nil, nil, errors.New("http: nil Request.URL") }, err: errors.New("http: nil Request.URL"), }, @@ -187,17 +226,25 @@ func TestDoRetries(t *testing.T) { for _, testcase := range testcases { tc := testcase t.Run(tc.name, func(t *testing.T) { - resp, err := DoRetries(tc.call, 1) // nolint: bodyclose + resp, err := DoRetries(tc.ctx, tc.call, 1) // nolint: bodyclose if tc.err != nil { assert.EqualError(t, err, tc.err.Error()) + if tc.err == errLimitExceeded { + assert.Equal(t, 1, totalRetries) + } else { + assert.Equal(t, 0, totalRetries) + } + + totalRetries = 0 + } else { assert.Equal(t, tc.responseStatus, resp.StatusCode) // only 429, 500 and 502 status codes should retry if tc.responseStatus == http.StatusForbidden { - assert.Equal(t, 1, totalRetries) + assert.Equal(t, 0, totalRetries) } else { - assert.Equal(t, 2, totalRetries) + assert.Equal(t, 1, totalRetries) } totalRetries = 0