Skip to content

Commit

Permalink
Add context dealine when calling token APIs (#126)
Browse files Browse the repository at this point in the history
In this PR we extensively rework the DoRetries function in
pkg/token-util:
- we pass-in a context from the calling function
- from this context we create a child context with a deadline of 3
  seconds, which is less than the retry interval of 5 seconds
- we pass-in a function signature that:
  - takes a context as input
  - returns a request, respone and error
- we execute this passed-in function with the new child context
- the passed-in function is expected to do the following:
  - create a new request with the context passed-in to the function,
    this is the child request with the deadline of 3 seconds
  - use a httpclient.Do function to execute the request
  - return the request, response and any error from the httpclient.Do
- if the request returned is non-nil and the context has exceeded its
  deadline, then we retry and decrement the number of retries
- if the error returned is non-nil then we return the error
- if the response status can't be retried we return the response
- if we reach the bottom of the loop, which means that we have a
  response code that can be retried, we decrement the retry counter and
  retry

The objective of the above is to cancel requests that have stalled for 3
seconds and then retry.  This is to help alleviate errors we are seeing
with IAM token generation requests that stall.

We've had to rework the calls to DoRetries from issuertoken and
identitytoken.

We've also had to rework unit-tests.  We've included a new unit-test to
check that a DeadlineExceeded error from the request context will be
retried.

Signed-off-by: Eamonn O'Toole <[email protected]>
  • Loading branch information
eamonnotoole authored Oct 25, 2024
1 parent acabdfc commit d3c8aa0
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 49 deletions.
12 changes: 0 additions & 12 deletions pkg/token/httpclient/httpclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,6 @@ func setTestCase() ([]testCaseIssuer, []testCaseIdentity) {
statusCode: http.StatusNotFound,
err: errors.New("Unexpected status code 404"),
},
{
name: "no context",
url: "https://hpe-greenlake-tenant.okta.com/oauth2/default",
ctx: nil,
err: errors.New("network error in post to get token"),
},
{
name: "status code 400",
url: "https://hpe-greenlake-tenant.okta.com/oauth2/default",
Expand Down Expand Up @@ -166,12 +160,6 @@ func setTestCase() ([]testCaseIssuer, []testCaseIdentity) {
statusCode: http.StatusNotFound,
err: errors.New("Unexpected status code 404"),
},
{
name: "no context",
url: "https://client.greenlake.hpe.com/api/iam/identity",
ctx: nil,
err: errors.New("net/http: nil Context"),
},
{
name: "status code 400",
url: "https://client.greenlake.hpe.com/api/iam/identity",
Expand Down
17 changes: 9 additions & 8 deletions pkg/token/identitytoken/identitytoken.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// (C) Copyright 2021 Hewlett Packard Enterprise Development LP
// (C) Copyright 2021-2024 Hewlett Packard Enterprise Development LP

package identitytoken

Expand Down Expand Up @@ -57,14 +57,15 @@ func GenerateToken(
return "", err
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(b)))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := tokenutil.DoRetries(ctx, func(reqCtx context.Context) (*http.Request, *http.Response, error) {
req, errReq := http.NewRequestWithContext(reqCtx, http.MethodPost, url, strings.NewReader(string(b)))
if errReq != nil {
return nil, nil, errReq
}
req.Header.Set("Content-Type", "application/json")
resp, errResp := httpClient.Do(req)

resp, err := tokenutil.DoRetries(func() (*http.Response, error) {
return httpClient.Do(req)
return req, resp, errResp
}, retryLimit)
if err != nil {
return "", err
Expand Down
13 changes: 8 additions & 5 deletions pkg/token/issuertoken/issuertoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,22 @@ func GenerateToken(
}

// Execute the request, with retries
resp, err := tokenutil.DoRetries(func() (*http.Response, error) {
resp, err := tokenutil.DoRetries(ctx, func(reqCtx context.Context) (*http.Request, *http.Response, error) {
// Create the request
req, errReq := createRequest(ctx, params, clientURL)
req, errReq := createRequest(reqCtx, params, clientURL)
if errReq != nil {
return nil, errReq
return nil, nil, errReq
}
// Close the request after use, i.e. don't reuse the TCP connection
req.Close = true

return httpClient.Do(req)
// Execute the request
resp, errResp := httpClient.Do(req)

return req, resp, errResp
}, retryLimit)
if err != nil {
return "", fmt.Errorf("network error in post to get token")
return "", err
}
defer resp.Body.Close()

Expand Down
53 changes: 44 additions & 9 deletions pkg/token/token-util/token-util.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package tokenutil

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -75,26 +76,60 @@ func DecodeAccessToken(rawToken string) (Token, error) {
return token, nil
}

func DoRetries(call func() (*http.Response, error), retries int) (*http.Response, error) {
func DoRetries(ctx context.Context, call func(ctx context.Context) (*http.Request, *http.Response, error), retries int) (*http.Response, error) {
var req *http.Request
var resp *http.Response
var err error

for {
resp, err = call()
// If retries are exhausted, return an error
if retries == 0 {
return resp, errors.MakeErrInternalError(errors.ErrorResponse{
ErrorCode: "ErrGenerateTokenRetryLimitExceeded",
Message: "Retry limit exceeded"})
}

// Create a new context with a timeout
ctxWithTimeout, cancel := createContextWithTimeout(ctx)
defer cancel()

// Execute the request
req, resp, err = call(ctxWithTimeout)

// If the error is due to a context timeout, retry the request
if req != nil && req.Context().Err() == context.DeadlineExceeded {
retries = sleepAndDecrementRetries(retries)

continue
}

// For all other errors, return the error
if err != nil {
return nil, err
return resp, err
}

if !isStatusRetryable(resp.StatusCode) || retries == 0 {
break
// If the status code is not retryable, return the response
if !isStatusRetryable(resp.StatusCode) {
return resp, nil
}

log.Printf("Retrying request, retries left: %v", retries)
time.Sleep(3 * time.Second)
retries--
retries = sleepAndDecrementRetries(retries)
}
}

func createContextWithTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
if ctx == nil {
return context.WithTimeout(context.Background(), 3*time.Second)
}

return context.WithTimeout(ctx, 3*time.Second)
}

func sleepAndDecrementRetries(retries int) int {
log.Printf("Retrying request, retries left: %v", retries)
time.Sleep(5 * time.Second)

return resp, nil
return retries - 1
}

func ManageHTTPErrorCodes(resp *http.Response, clientID string) error {
Expand Down
77 changes: 62 additions & 15 deletions pkg/token/token-util/token-util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,22 @@
package tokenutil

import (
"context"
"errors"
"net/http"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

hpeglErrors "github.com/hewlettpackard/hpegl-provider-lib/pkg/token/errors"
)

var errLimitExceeded = hpeglErrors.MakeErrInternalError(hpeglErrors.ErrorResponse{
ErrorCode: "ErrGenerateTokenRetryLimitExceeded",
Message: "Retry limit exceeded"})

//nolint:scopelint
func TestDecodeAccessToken(t *testing.T) {
type args struct {
Expand Down Expand Up @@ -135,50 +143,81 @@ func TestDoRetries(t *testing.T) {
totalRetries := 0
testcases := []struct {
name string
call func() (*http.Response, error)
ctx context.Context
call func(ctx context.Context) (*http.Request, *http.Response, error)
responseStatus int
err error
}{
{
name: "status 500",
call: func() (*http.Response, error) {
ctx: context.Background(),
call: func(ctx context.Context) (*http.Request, *http.Response, error) {
totalRetries++

return &http.Response{StatusCode: http.StatusInternalServerError}, nil
return nil, &http.Response{StatusCode: http.StatusInternalServerError}, nil
},
responseStatus: http.StatusInternalServerError,
},
{
name: "status 429",
call: func() (*http.Response, error) {
ctx: context.Background(),
call: func(ctx context.Context) (*http.Request, *http.Response, error) {
totalRetries++

return &http.Response{StatusCode: http.StatusTooManyRequests}, nil
return nil, &http.Response{StatusCode: http.StatusTooManyRequests}, nil
},
responseStatus: http.StatusTooManyRequests,
},
{
name: "status 502",
call: func() (*http.Response, error) {
ctx: context.Background(),
call: func(ctx context.Context) (*http.Request, *http.Response, error) {
totalRetries++

return &http.Response{StatusCode: http.StatusBadGateway}, nil
return nil, &http.Response{StatusCode: http.StatusBadGateway}, nil
},
responseStatus: http.StatusBadGateway,
},
{
name: "status 403 no retry",
call: func() (*http.Response, error) {
ctx: context.Background(),
call: func(ctx context.Context) (*http.Request, *http.Response, error) {
return nil, &http.Response{StatusCode: http.StatusForbidden}, nil
},
responseStatus: http.StatusForbidden,
},
{
name: "Deadline exceeded",
ctx: context.Background(),
call: func(ctx context.Context) (*http.Request, *http.Response, error) {
totalRetries++
req := &http.Request{}
req = req.WithContext(ctx)
select {
case <-ctx.Done():
return req, nil, context.DeadlineExceeded
case <-time.After(5 * time.Second): // this is greater than the deadline
return req, nil, nil
}
},
err: errLimitExceeded,
},
{
name: "Context cancelled",
ctx: context.Background(),
call: func(ctx context.Context) (*http.Request, *http.Response, error) {
req := &http.Request{}
req = req.WithContext(ctx)

return &http.Response{StatusCode: http.StatusForbidden}, nil
return req, nil, context.Canceled
},
responseStatus: http.StatusForbidden,
err: context.Canceled,
},
{
name: "no url",
call: func() (*http.Response, error) {
return nil, errors.New("http: nil Request.URL")
ctx: context.Background(),
call: func(ctx context.Context) (*http.Request, *http.Response, error) {
return nil, nil, errors.New("http: nil Request.URL")
},
err: errors.New("http: nil Request.URL"),
},
Expand All @@ -187,17 +226,25 @@ func TestDoRetries(t *testing.T) {
for _, testcase := range testcases {
tc := testcase
t.Run(tc.name, func(t *testing.T) {
resp, err := DoRetries(tc.call, 1) // nolint: bodyclose
resp, err := DoRetries(tc.ctx, tc.call, 1) // nolint: bodyclose
if tc.err != nil {
assert.EqualError(t, err, tc.err.Error())
if tc.err == errLimitExceeded {
assert.Equal(t, 1, totalRetries)
} else {
assert.Equal(t, 0, totalRetries)
}

totalRetries = 0

} else {
assert.Equal(t, tc.responseStatus, resp.StatusCode)

// only 429, 500 and 502 status codes should retry
if tc.responseStatus == http.StatusForbidden {
assert.Equal(t, 1, totalRetries)
assert.Equal(t, 0, totalRetries)
} else {
assert.Equal(t, 2, totalRetries)
assert.Equal(t, 1, totalRetries)
}

totalRetries = 0
Expand Down

0 comments on commit d3c8aa0

Please sign in to comment.