From 62299cb575830dccdafbf2a81944d8f1102ba8bf Mon Sep 17 00:00:00 2001 From: Eamonn O'Toole Date: Wed, 13 Nov 2024 13:10:48 +0000 Subject: [PATCH] Fix request context cancel race when reading response (#130) --- .github/workflows/trivy.yaml | 4 +-- go.mod | 3 +- pkg/token/httpclient/httpclient.go | 2 +- pkg/token/identitytoken/identitytoken.go | 38 ++++++++++++++------ pkg/token/issuertoken/issuertoken.go | 46 ++++++++++++++++-------- pkg/token/token-util/token-util.go | 11 ++++-- pkg/token/token-util/token-util_test.go | 24 +++++-------- 7 files changed, 83 insertions(+), 45 deletions(-) diff --git a/.github/workflows/trivy.yaml b/.github/workflows/trivy.yaml index 227ee8f..3eedf59 100644 --- a/.github/workflows/trivy.yaml +++ b/.github/workflows/trivy.yaml @@ -1,4 +1,4 @@ -# Copyright 2022 Hewlett Packard Enterprise Development LP +# Copyright 2022-2024 Hewlett Packard Enterprise Development LP name: Trivy on: pull_request: @@ -11,7 +11,7 @@ jobs: uses: actions/checkout@v2 - name: Run Trivy vulnerability scanner (go.mod) - uses: aquasecurity/trivy-action@master + uses: aquasecurity/trivy-action@0.28.0 with: scan-type: 'fs' hide-progress: false diff --git a/go.mod b/go.mod index 489f623..09571b4 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/hewlettpackard/hpegl-provider-lib -go 1.21 +go 1.22.1 + toolchain go1.22.5 require ( diff --git a/pkg/token/httpclient/httpclient.go b/pkg/token/httpclient/httpclient.go index 534f9f0..07df9ef 100644 --- a/pkg/token/httpclient/httpclient.go +++ b/pkg/token/httpclient/httpclient.go @@ -22,7 +22,7 @@ type Client struct { // New creates a new identity Client object func New(identityServiceURL string, vendedServiceClient bool, passedInToken string) *Client { - client := &http.Client{Timeout: 10 * time.Second} + client := &http.Client{Timeout: 120 * time.Second} identityServiceURL = strings.TrimRight(identityServiceURL, "/") return &Client{ diff --git a/pkg/token/identitytoken/identitytoken.go b/pkg/token/identitytoken/identitytoken.go index 9511ff2..45907e9 100644 --- a/pkg/token/identitytoken/identitytoken.go +++ b/pkg/token/identitytoken/identitytoken.go @@ -57,16 +57,27 @@ func GenerateToken( return "", err } - resp, err := tokenutil.DoRetries(ctx, func(reqCtx context.Context) (*http.Request, *http.Response, error) { - req, errReq := http.NewRequestWithContext(reqCtx, http.MethodPost, url, strings.NewReader(string(b))) - if errReq != nil { - return nil, nil, errReq - } - req.Header.Set("Content-Type", "application/json") - resp, errResp := httpClient.Do(req) - - return req, resp, errResp - }, retryLimit) + // Create a slice of cancel functions to be returned by the retries + cancelFuncs := make([]context.CancelFunc, 0) + + resp, err := tokenutil.DoRetries( + ctx, + &cancelFuncs, + func(reqCtx context.Context) (*http.Request, *http.Response, error) { + req, errReq := http.NewRequestWithContext(reqCtx, http.MethodPost, url, strings.NewReader(string(b))) + if errReq != nil { + return nil, nil, errReq + } + req.Header.Set("Content-Type", "application/json") + respFromDo, errResp := httpClient.Do(req) + + return req, respFromDo, errResp + }, + retryLimit, + ) + // Defer execution of cancelFuncs + defer executeCancelFuncs(&cancelFuncs) + if err != nil { return "", err } @@ -91,3 +102,10 @@ func GenerateToken( return token.AccessToken, nil } + +// executeCancelFuncs executes all cancel functions in the slice +func executeCancelFuncs(cancelFuncs *[]context.CancelFunc) { + for _, cancel := range *cancelFuncs { + cancel() + } +} diff --git a/pkg/token/issuertoken/issuertoken.go b/pkg/token/issuertoken/issuertoken.go index 54a77ad..3516e90 100644 --- a/pkg/token/issuertoken/issuertoken.go +++ b/pkg/token/issuertoken/issuertoken.go @@ -40,21 +40,32 @@ func GenerateToken( return "", err } + // Create a slice of cancel functions to be returned by the retries + cancelFuncs := make([]context.CancelFunc, 0) + // Execute the request, with retries - resp, err := tokenutil.DoRetries(ctx, func(reqCtx context.Context) (*http.Request, *http.Response, error) { - // Create the request - req, errReq := createRequest(reqCtx, params, clientURL) - if errReq != nil { - return nil, nil, errReq - } - // Close the request after use, i.e. don't reuse the TCP connection - req.Close = true - - // Execute the request - resp, errResp := httpClient.Do(req) - - return req, resp, errResp - }, retryLimit) + resp, err := tokenutil.DoRetries( + ctx, + &cancelFuncs, + func(reqCtx context.Context) (*http.Request, *http.Response, error) { + // Create the request + req, errReq := createRequest(reqCtx, params, clientURL) + if errReq != nil { + return nil, nil, errReq + } + // Close the request after use, i.e. don't reuse the TCP connection + req.Close = true + + // Execute the request + respFromDo, errResp := httpClient.Do(req) + + return req, respFromDo, errResp + }, + retryLimit, + ) + // Defer execution of cancel functions + defer executeCancelFuncs(&cancelFuncs) + if err != nil { return "", err } @@ -80,6 +91,13 @@ func GenerateToken( return token.AccessToken, nil } +// executeCancelFuncs executes all cancel functions in the slice +func executeCancelFuncs(cancelFuncs *[]context.CancelFunc) { + for _, cancel := range *cancelFuncs { + cancel() + } +} + // createRequest creates a new http request func createRequest(ctx context.Context, params url.Values, clientURL string) (*http.Request, error) { req, err := http.NewRequestWithContext(ctx, http.MethodPost, clientURL, strings.NewReader(params.Encode())) diff --git a/pkg/token/token-util/token-util.go b/pkg/token/token-util/token-util.go index 1855721..7868c71 100644 --- a/pkg/token/token-util/token-util.go +++ b/pkg/token/token-util/token-util.go @@ -76,7 +76,12 @@ func DecodeAccessToken(rawToken string) (Token, error) { return token, nil } -func DoRetries(ctx context.Context, call func(ctx context.Context) (*http.Request, *http.Response, error), retries int) (*http.Response, error) { +func DoRetries( + ctx context.Context, + cancelFuncs *[]context.CancelFunc, + call func(ctx context.Context) (*http.Request, *http.Response, error), + retries int, +) (*http.Response, error) { var req *http.Request var resp *http.Response var err error @@ -91,7 +96,9 @@ func DoRetries(ctx context.Context, call func(ctx context.Context) (*http.Reques // Create a new context with a timeout ctxWithTimeout, cancel := createContextWithTimeout(ctx) - defer cancel() + + // Add the cancel function to the list of cancel functions + *cancelFuncs = append(*cancelFuncs, cancel) // Execute the request req, resp, err = call(ctxWithTimeout) diff --git a/pkg/token/token-util/token-util_test.go b/pkg/token/token-util/token-util_test.go index e6742e3..1ceda44 100644 --- a/pkg/token/token-util/token-util_test.go +++ b/pkg/token/token-util/token-util_test.go @@ -1,4 +1,4 @@ -// (C) Copyright 2021 Hewlett Packard Enterprise Development LP +// (C) Copyright 2021-2024 Hewlett Packard Enterprise Development LP package tokenutil @@ -202,17 +202,6 @@ func TestDoRetries(t *testing.T) { }, err: errLimitExceeded, }, - { - name: "Context cancelled", - ctx: context.Background(), - call: func(ctx context.Context) (*http.Request, *http.Response, error) { - req := &http.Request{} - req = req.WithContext(ctx) - - return req, nil, context.Canceled - }, - err: context.Canceled, - }, { name: "no url", ctx: context.Background(), @@ -226,13 +215,16 @@ func TestDoRetries(t *testing.T) { for _, testcase := range testcases { tc := testcase t.Run(tc.name, func(t *testing.T) { - resp, err := DoRetries(tc.ctx, tc.call, 1) // nolint: bodyclose + cancelFuncs := make([]context.CancelFunc, 0) + resp, err := DoRetries(tc.ctx, &cancelFuncs, tc.call, 2) // nolint: bodyclose if tc.err != nil { assert.EqualError(t, err, tc.err.Error()) if tc.err == errLimitExceeded { - assert.Equal(t, 1, totalRetries) + assert.Equal(t, 2, totalRetries) + assert.Equal(t, 2, len(cancelFuncs)) } else { assert.Equal(t, 0, totalRetries) + assert.Equal(t, 1, len(cancelFuncs)) } totalRetries = 0 @@ -243,8 +235,10 @@ func TestDoRetries(t *testing.T) { // only 429, 500 and 502 status codes should retry if tc.responseStatus == http.StatusForbidden { assert.Equal(t, 0, totalRetries) + assert.Equal(t, 1, len(cancelFuncs)) } else { - assert.Equal(t, 1, totalRetries) + assert.Equal(t, 2, totalRetries) + assert.Equal(t, 2, len(cancelFuncs)) } totalRetries = 0