From 8da490f5930406180bef6f4b0b99e0b0dc86dff8 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Thu, 22 Feb 2024 11:42:33 +0100 Subject: [PATCH] refact pkg/apiclient (#2846) * extract resperr.go * extract method prepareRequest() * reset token inside mutex --- pkg/apiclient/auth_jwt.go | 37 +++++++++++++++++++++---------- pkg/apiclient/client.go | 36 ------------------------------ pkg/apiclient/resperr.go | 46 +++++++++++++++++++++++++++++++++++++++ pkg/apiserver/apic.go | 1 - 4 files changed, 72 insertions(+), 48 deletions(-) create mode 100644 pkg/apiclient/resperr.go diff --git a/pkg/apiclient/auth_jwt.go b/pkg/apiclient/auth_jwt.go index 71b0e273105..2ead10cf6da 100644 --- a/pkg/apiclient/auth_jwt.go +++ b/pkg/apiclient/auth_jwt.go @@ -130,20 +130,24 @@ func (t *JWTTransport) refreshJwtToken() error { return nil } -// RoundTrip implements the RoundTripper interface. -func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI - // we use a mutex to avoid this - // We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request) +func (t *JWTTransport) needsTokenRefresh() bool { + return t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) +} + +// prepareRequest returns a copy of the request with the necessary authentication headers. +func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error) { + // In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless + // and will cause overload on CAPI. We use a mutex to avoid this. t.refreshTokenMutex.Lock() - if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) { - if err := t.refreshJwtToken(); err != nil { - t.refreshTokenMutex.Unlock() + defer t.refreshTokenMutex.Unlock() + // We bypass the refresh if we are requesting the login endpoint, as it does not require a token, + // and it leads to do 2 requests instead of one (refresh + actual login request). + if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && t.needsTokenRefresh() { + if err := t.refreshJwtToken(); err != nil { return nil, err } } - t.refreshTokenMutex.Unlock() if t.UserAgent != "" { req.Header.Add("User-Agent", t.UserAgent) @@ -151,6 +155,16 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token)) + return req, nil +} + +// RoundTrip implements the RoundTripper interface. +func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req, err := t.prepareRequest(req) + if err != nil { + return nil, err + } + if log.GetLevel() >= log.TraceLevel { //requestToDump := cloneRequest(req) dump, _ := httputil.DumpRequest(req, true) @@ -166,7 +180,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { if err != nil { // we had an error (network error for example, or 401 because token is refused), reset the token? - t.Token = "" + t.ResetToken() return resp, fmt.Errorf("performing jwt auth: %w", err) } @@ -189,7 +203,8 @@ func (t *JWTTransport) ResetToken() { t.refreshTokenMutex.Unlock() } -// transport() returns a round tripper that retries once when the status is unauthorized, and 5 times when the infrastructure is overloaded. +// transport() returns a round tripper that retries once when the status is unauthorized, +// and 5 times when the infrastructure is overloaded. func (t *JWTTransport) transport() http.RoundTripper { transport := t.Transport if transport == nil { diff --git a/pkg/apiclient/client.go b/pkg/apiclient/client.go index b183a8c7909..b487f68a698 100644 --- a/pkg/apiclient/client.go +++ b/pkg/apiclient/client.go @@ -4,9 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" - "encoding/json" "fmt" - "io" "net/http" "net/url" @@ -167,44 +165,10 @@ type Response struct { //... } -type ErrorResponse struct { - models.ErrorResponse -} - -func (e *ErrorResponse) Error() string { - err := fmt.Sprintf("API error: %s", *e.Message) - if len(e.Errors) > 0 { - err += fmt.Sprintf(" (%s)", e.Errors) - } - - return err -} - func newResponse(r *http.Response) *Response { return &Response{Response: r} } -func CheckResponse(r *http.Response) error { - if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 { - return nil - } - - errorResponse := &ErrorResponse{} - - data, err := io.ReadAll(r.Body) - if err == nil && len(data)>0 { - err := json.Unmarshal(data, errorResponse) - if err != nil { - return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err) - } - } else { - errorResponse.Message = new(string) - *errorResponse.Message = fmt.Sprintf("http code %d, no error message", r.StatusCode) - } - - return errorResponse -} - type ListOpts struct { //Page int //PerPage int diff --git a/pkg/apiclient/resperr.go b/pkg/apiclient/resperr.go new file mode 100644 index 00000000000..ff954a73609 --- /dev/null +++ b/pkg/apiclient/resperr.go @@ -0,0 +1,46 @@ +package apiclient + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +type ErrorResponse struct { + models.ErrorResponse +} + +func (e *ErrorResponse) Error() string { + err := fmt.Sprintf("API error: %s", *e.Message) + if len(e.Errors) > 0 { + err += fmt.Sprintf(" (%s)", e.Errors) + } + + return err +} + +// CheckResponse verifies the API response and builds an appropriate Go error if necessary. +func CheckResponse(r *http.Response) error { + if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 { + return nil + } + + ret := &ErrorResponse{} + + data, err := io.ReadAll(r.Body) + if err != nil || len(data) == 0 { + ret.Message = ptr.Of(fmt.Sprintf("http code %d, no error message", r.StatusCode)) + return ret + } + + if err := json.Unmarshal(data, ret); err != nil { + return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err) + } + + return ret +} diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 2fdb01144a0..f57ae685e45 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -539,7 +539,6 @@ func createAlertForDecision(decision *models.Decision) *models.Alert { scenario = *decision.Scenario scope = types.ListOrigin default: - // XXX: this or nil? scenario = "" scope = ""