From d3c8aa03fb9fc68270a463c37fca6d5918bb59e9 Mon Sep 17 00:00:00 2001 From: Eamonn O'Toole Date: Fri, 25 Oct 2024 09:34:08 +0100 Subject: [PATCH] Add context dealine when calling token APIs (#126) In this PR we extensively rework the DoRetries function in pkg/token-util: - we pass-in a context from the calling function - from this context we create a child context with a deadline of 3 seconds, which is less than the retry interval of 5 seconds - we pass-in a function signature that: - takes a context as input - returns a request, respone and error - we execute this passed-in function with the new child context - the passed-in function is expected to do the following: - create a new request with the context passed-in to the function, this is the child request with the deadline of 3 seconds - use a httpclient.Do function to execute the request - return the request, response and any error from the httpclient.Do - if the request returned is non-nil and the context has exceeded its deadline, then we retry and decrement the number of retries - if the error returned is non-nil then we return the error - if the response status can't be retried we return the response - if we reach the bottom of the loop, which means that we have a response code that can be retried, we decrement the retry counter and retry The objective of the above is to cancel requests that have stalled for 3 seconds and then retry. This is to help alleviate errors we are seeing with IAM token generation requests that stall. We've had to rework the calls to DoRetries from issuertoken and identitytoken. We've also had to rework unit-tests. We've included a new unit-test to check that a DeadlineExceeded error from the request context will be retried. Signed-off-by: Eamonn O'Toole --- pkg/token/httpclient/httpclient_test.go | 12 ---- pkg/token/identitytoken/identitytoken.go | 17 +++--- pkg/token/issuertoken/issuertoken.go | 13 ++-- pkg/token/token-util/token-util.go | 53 +++++++++++++--- pkg/token/token-util/token-util_test.go | 77 +++++++++++++++++++----- 5 files changed, 123 insertions(+), 49 deletions(-) diff --git a/pkg/token/httpclient/httpclient_test.go b/pkg/token/httpclient/httpclient_test.go index 9cc4626..26c556f 100644 --- a/pkg/token/httpclient/httpclient_test.go +++ b/pkg/token/httpclient/httpclient_test.go @@ -122,12 +122,6 @@ func setTestCase() ([]testCaseIssuer, []testCaseIdentity) { statusCode: http.StatusNotFound, err: errors.New("Unexpected status code 404"), }, - { - name: "no context", - url: "https://hpe-greenlake-tenant.okta.com/oauth2/default", - ctx: nil, - err: errors.New("network error in post to get token"), - }, { name: "status code 400", url: "https://hpe-greenlake-tenant.okta.com/oauth2/default", @@ -166,12 +160,6 @@ func setTestCase() ([]testCaseIssuer, []testCaseIdentity) { statusCode: http.StatusNotFound, err: errors.New("Unexpected status code 404"), }, - { - name: "no context", - url: "https://client.greenlake.hpe.com/api/iam/identity", - ctx: nil, - err: errors.New("net/http: nil Context"), - }, { name: "status code 400", url: "https://client.greenlake.hpe.com/api/iam/identity", diff --git a/pkg/token/identitytoken/identitytoken.go b/pkg/token/identitytoken/identitytoken.go index bf79126..9511ff2 100644 --- a/pkg/token/identitytoken/identitytoken.go +++ b/pkg/token/identitytoken/identitytoken.go @@ -1,4 +1,4 @@ -// (C) Copyright 2021 Hewlett Packard Enterprise Development LP +// (C) Copyright 2021-2024 Hewlett Packard Enterprise Development LP package identitytoken @@ -57,14 +57,15 @@ func GenerateToken( return "", err } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(b))) - if err != nil { - return "", err - } - req.Header.Set("Content-Type", "application/json") + resp, err := tokenutil.DoRetries(ctx, func(reqCtx context.Context) (*http.Request, *http.Response, error) { + req, errReq := http.NewRequestWithContext(reqCtx, http.MethodPost, url, strings.NewReader(string(b))) + if errReq != nil { + return nil, nil, errReq + } + req.Header.Set("Content-Type", "application/json") + resp, errResp := httpClient.Do(req) - resp, err := tokenutil.DoRetries(func() (*http.Response, error) { - return httpClient.Do(req) + return req, resp, errResp }, retryLimit) if err != nil { return "", err diff --git a/pkg/token/issuertoken/issuertoken.go b/pkg/token/issuertoken/issuertoken.go index 7271860..54a77ad 100644 --- a/pkg/token/issuertoken/issuertoken.go +++ b/pkg/token/issuertoken/issuertoken.go @@ -41,19 +41,22 @@ func GenerateToken( } // Execute the request, with retries - resp, err := tokenutil.DoRetries(func() (*http.Response, error) { + resp, err := tokenutil.DoRetries(ctx, func(reqCtx context.Context) (*http.Request, *http.Response, error) { // Create the request - req, errReq := createRequest(ctx, params, clientURL) + req, errReq := createRequest(reqCtx, params, clientURL) if errReq != nil { - return nil, errReq + return nil, nil, errReq } // Close the request after use, i.e. don't reuse the TCP connection req.Close = true - return httpClient.Do(req) + // Execute the request + resp, errResp := httpClient.Do(req) + + return req, resp, errResp }, retryLimit) if err != nil { - return "", fmt.Errorf("network error in post to get token") + return "", err } defer resp.Body.Close() diff --git a/pkg/token/token-util/token-util.go b/pkg/token/token-util/token-util.go index 05d1fb6..1855721 100644 --- a/pkg/token/token-util/token-util.go +++ b/pkg/token/token-util/token-util.go @@ -3,6 +3,7 @@ package tokenutil import ( + "context" "encoding/base64" "encoding/json" "fmt" @@ -75,26 +76,60 @@ func DecodeAccessToken(rawToken string) (Token, error) { return token, nil } -func DoRetries(call func() (*http.Response, error), retries int) (*http.Response, error) { +func DoRetries(ctx context.Context, call func(ctx context.Context) (*http.Request, *http.Response, error), retries int) (*http.Response, error) { + var req *http.Request var resp *http.Response var err error for { - resp, err = call() + // If retries are exhausted, return an error + if retries == 0 { + return resp, errors.MakeErrInternalError(errors.ErrorResponse{ + ErrorCode: "ErrGenerateTokenRetryLimitExceeded", + Message: "Retry limit exceeded"}) + } + + // Create a new context with a timeout + ctxWithTimeout, cancel := createContextWithTimeout(ctx) + defer cancel() + + // Execute the request + req, resp, err = call(ctxWithTimeout) + + // If the error is due to a context timeout, retry the request + if req != nil && req.Context().Err() == context.DeadlineExceeded { + retries = sleepAndDecrementRetries(retries) + + continue + } + + // For all other errors, return the error if err != nil { - return nil, err + return resp, err } - if !isStatusRetryable(resp.StatusCode) || retries == 0 { - break + // If the status code is not retryable, return the response + if !isStatusRetryable(resp.StatusCode) { + return resp, nil } - log.Printf("Retrying request, retries left: %v", retries) - time.Sleep(3 * time.Second) - retries-- + retries = sleepAndDecrementRetries(retries) } +} + +func createContextWithTimeout(ctx context.Context) (context.Context, context.CancelFunc) { + if ctx == nil { + return context.WithTimeout(context.Background(), 3*time.Second) + } + + return context.WithTimeout(ctx, 3*time.Second) +} + +func sleepAndDecrementRetries(retries int) int { + log.Printf("Retrying request, retries left: %v", retries) + time.Sleep(5 * time.Second) - return resp, nil + return retries - 1 } func ManageHTTPErrorCodes(resp *http.Response, clientID string) error { diff --git a/pkg/token/token-util/token-util_test.go b/pkg/token/token-util/token-util_test.go index 5a56f64..e6742e3 100644 --- a/pkg/token/token-util/token-util_test.go +++ b/pkg/token/token-util/token-util_test.go @@ -3,14 +3,22 @@ package tokenutil import ( + "context" "errors" "net/http" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + hpeglErrors "github.com/hewlettpackard/hpegl-provider-lib/pkg/token/errors" ) +var errLimitExceeded = hpeglErrors.MakeErrInternalError(hpeglErrors.ErrorResponse{ + ErrorCode: "ErrGenerateTokenRetryLimitExceeded", + Message: "Retry limit exceeded"}) + //nolint:scopelint func TestDecodeAccessToken(t *testing.T) { type args struct { @@ -135,50 +143,81 @@ func TestDoRetries(t *testing.T) { totalRetries := 0 testcases := []struct { name string - call func() (*http.Response, error) + ctx context.Context + call func(ctx context.Context) (*http.Request, *http.Response, error) responseStatus int err error }{ { name: "status 500", - call: func() (*http.Response, error) { + ctx: context.Background(), + call: func(ctx context.Context) (*http.Request, *http.Response, error) { totalRetries++ - return &http.Response{StatusCode: http.StatusInternalServerError}, nil + return nil, &http.Response{StatusCode: http.StatusInternalServerError}, nil }, responseStatus: http.StatusInternalServerError, }, { name: "status 429", - call: func() (*http.Response, error) { + ctx: context.Background(), + call: func(ctx context.Context) (*http.Request, *http.Response, error) { totalRetries++ - return &http.Response{StatusCode: http.StatusTooManyRequests}, nil + return nil, &http.Response{StatusCode: http.StatusTooManyRequests}, nil }, responseStatus: http.StatusTooManyRequests, }, { name: "status 502", - call: func() (*http.Response, error) { + ctx: context.Background(), + call: func(ctx context.Context) (*http.Request, *http.Response, error) { totalRetries++ - return &http.Response{StatusCode: http.StatusBadGateway}, nil + return nil, &http.Response{StatusCode: http.StatusBadGateway}, nil }, responseStatus: http.StatusBadGateway, }, { name: "status 403 no retry", - call: func() (*http.Response, error) { + ctx: context.Background(), + call: func(ctx context.Context) (*http.Request, *http.Response, error) { + return nil, &http.Response{StatusCode: http.StatusForbidden}, nil + }, + responseStatus: http.StatusForbidden, + }, + { + name: "Deadline exceeded", + ctx: context.Background(), + call: func(ctx context.Context) (*http.Request, *http.Response, error) { totalRetries++ + req := &http.Request{} + req = req.WithContext(ctx) + select { + case <-ctx.Done(): + return req, nil, context.DeadlineExceeded + case <-time.After(5 * time.Second): // this is greater than the deadline + return req, nil, nil + } + }, + err: errLimitExceeded, + }, + { + name: "Context cancelled", + ctx: context.Background(), + call: func(ctx context.Context) (*http.Request, *http.Response, error) { + req := &http.Request{} + req = req.WithContext(ctx) - return &http.Response{StatusCode: http.StatusForbidden}, nil + return req, nil, context.Canceled }, - responseStatus: http.StatusForbidden, + err: context.Canceled, }, { name: "no url", - call: func() (*http.Response, error) { - return nil, errors.New("http: nil Request.URL") + ctx: context.Background(), + call: func(ctx context.Context) (*http.Request, *http.Response, error) { + return nil, nil, errors.New("http: nil Request.URL") }, err: errors.New("http: nil Request.URL"), }, @@ -187,17 +226,25 @@ func TestDoRetries(t *testing.T) { for _, testcase := range testcases { tc := testcase t.Run(tc.name, func(t *testing.T) { - resp, err := DoRetries(tc.call, 1) // nolint: bodyclose + resp, err := DoRetries(tc.ctx, tc.call, 1) // nolint: bodyclose if tc.err != nil { assert.EqualError(t, err, tc.err.Error()) + if tc.err == errLimitExceeded { + assert.Equal(t, 1, totalRetries) + } else { + assert.Equal(t, 0, totalRetries) + } + + totalRetries = 0 + } else { assert.Equal(t, tc.responseStatus, resp.StatusCode) // only 429, 500 and 502 status codes should retry if tc.responseStatus == http.StatusForbidden { - assert.Equal(t, 1, totalRetries) + assert.Equal(t, 0, totalRetries) } else { - assert.Equal(t, 2, totalRetries) + assert.Equal(t, 1, totalRetries) } totalRetries = 0