Skip to content

Commit

Permalink
jwt transport: fix retry on unauthorized from CAPI(#3006)
Browse files Browse the repository at this point in the history
  • Loading branch information
blotus authored May 24, 2024
1 parent 09afcbe commit f06e3e7
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 45 deletions.
103 changes: 58 additions & 45 deletions pkg/apiclient/auth_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type JWTTransport struct {
URL *url.URL
VersionPrefix string
UserAgent string
RetryConfig *RetryConfig
// Transport is the underlying HTTP transport to use when making requests.
// It will default to http.DefaultTransport if nil.
Transport http.RoundTripper
Expand Down Expand Up @@ -165,36 +166,67 @@ func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error)

// RoundTrip implements the RoundTripper interface.
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req, err := t.prepareRequest(req)
if err != nil {
return nil, err
}

if log.GetLevel() >= log.TraceLevel {
// requestToDump := cloneRequest(req)
dump, _ := httputil.DumpRequest(req, true)
log.Tracef("req-jwt: %s", string(dump))
}
var resp *http.Response
attemptsCount := make(map[int]int)

// Make the HTTP request.
resp, err := t.transport().RoundTrip(req)
if log.GetLevel() >= log.TraceLevel {
dump, _ := httputil.DumpResponse(resp, true)
log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
}
for {
if log.GetLevel() >= log.TraceLevel {
// requestToDump := cloneRequest(req)
dump, _ := httputil.DumpRequest(req, true)
log.Tracef("req-jwt: %s", string(dump))
}
// Make the HTTP request.
clonedReq := cloneRequest(req)

if err != nil {
// we had an error (network error for example, or 401 because token is refused), reset the token?
t.ResetToken()
clonedReq, err := t.prepareRequest(clonedReq)
if err != nil {
return nil, err
}

return resp, fmt.Errorf("performing jwt auth: %w", err)
}
resp, err = t.transport().RoundTrip(clonedReq)
if log.GetLevel() >= log.TraceLevel {
dump, _ := httputil.DumpResponse(resp, true)
log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
}

if resp != nil {
log.Debugf("resp-jwt: %d", resp.StatusCode)
}
if err != nil {
// we had an error (network error for example), reset the token?
t.ResetToken()
return resp, fmt.Errorf("performing jwt auth: %w", err)
}

if resp != nil {
log.Debugf("resp-jwt: %d", resp.StatusCode)
}

config, shouldRetry := t.RetryConfig.StatusCodeConfig[resp.StatusCode]
if !shouldRetry {
break
}

if attemptsCount[resp.StatusCode] >= config.MaxAttempts {
log.Infof("max attempts reached for status code %d", resp.StatusCode)
break
}

if config.InvalidateToken {
log.Debugf("invalidating token for status code %d", resp.StatusCode)
t.ResetToken()
}

log.Debugf("retrying request to %s", req.URL.String())
attemptsCount[resp.StatusCode]++
log.Infof("attempt %d out of %d", attemptsCount[resp.StatusCode], config.MaxAttempts)

if config.Backoff {
backoff := 2*attemptsCount[resp.StatusCode] + 5
log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, attemptsCount[resp.StatusCode], config.MaxAttempts)
time.Sleep(time.Duration(backoff) * time.Second)
}
}
return resp, nil

}

func (t *JWTTransport) Client() *http.Client {
Expand All @@ -211,27 +243,8 @@ func (t *JWTTransport) ResetToken() {
// transport() returns a round tripper that retries once when the status is unauthorized,
// and 5 times when the infrastructure is overloaded.
func (t *JWTTransport) transport() http.RoundTripper {
transport := t.Transport
if transport == nil {
transport = http.DefaultTransport
}

return &retryRoundTripper{
next: &retryRoundTripper{
next: transport,
maxAttempts: 5,
withBackOff: true,
retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout},
},
maxAttempts: 2,
withBackOff: false,
retryStatusCodes: []int{http.StatusUnauthorized, http.StatusForbidden},
onBeforeRequest: func(attempt int) {
// reset the token only in the second attempt as this is when we know we had a 401 or 403
// the second attempt is supposed to refresh the token
if attempt > 0 {
t.ResetToken()
}
},
if t.Transport != nil {
return t.Transport
}
return http.DefaultTransport
}
7 changes: 7 additions & 0 deletions pkg/apiclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ func NewClient(config *Config) (*ApiClient, error) {
UserAgent: config.UserAgent,
VersionPrefix: config.VersionPrefix,
UpdateScenario: config.UpdateScenario,
RetryConfig: NewRetryConfig(
WithStatusCodeConfig(http.StatusUnauthorized, 2, false, true),
WithStatusCodeConfig(http.StatusForbidden, 2, false, true),
WithStatusCodeConfig(http.StatusTooManyRequests, 5, true, false),
WithStatusCodeConfig(http.StatusServiceUnavailable, 5, true, false),
WithStatusCodeConfig(http.StatusGatewayTimeout, 5, true, false),
),
}

transport, baseURL := createTransport(config.URL)
Expand Down
33 changes: 33 additions & 0 deletions pkg/apiclient/retry_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package apiclient

type StatusCodeConfig struct {
MaxAttempts int
Backoff bool
InvalidateToken bool
}

type RetryConfig struct {
StatusCodeConfig map[int]StatusCodeConfig
}

type RetryConfigOption func(*RetryConfig)

func NewRetryConfig(options ...RetryConfigOption) *RetryConfig {
rc := &RetryConfig{
StatusCodeConfig: make(map[int]StatusCodeConfig),
}
for _, opt := range options {
opt(rc)
}
return rc
}

func WithStatusCodeConfig(statusCode int, maxAttempts int, backOff bool, invalidateToken bool) RetryConfigOption {
return func(rc *RetryConfig) {
rc.StatusCodeConfig[statusCode] = StatusCodeConfig{
MaxAttempts: maxAttempts,
Backoff: backOff,
InvalidateToken: invalidateToken,
}
}
}

0 comments on commit f06e3e7

Please sign in to comment.