Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refact pkg/apiclient #2846

Merged
merged 4 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions pkg/apiclient/auth_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,27 +130,43 @@ func (t *JWTTransport) refreshJwtToken() error {
return nil
}

// RoundTrip implements the RoundTripper interface.
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI
// we use a mutex to avoid this
// We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request)
func (t *JWTTransport) needsTokenRefresh() bool {
return t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())
}

// prepareRequest returns a copy of the request with the necessary authentication headers.
func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error) {
req = cloneRequest(req)
LaurenceJJones marked this conversation as resolved.
Show resolved Hide resolved

// In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless
// and will cause overload on CAPI. We use a mutex to avoid this.
t.refreshTokenMutex.Lock()
if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) {
if err := t.refreshJwtToken(); err != nil {
t.refreshTokenMutex.Unlock()
defer t.refreshTokenMutex.Unlock()

// We bypass the refresh if we are requesting the login endpoint, as it does not require a token,
// and it leads to do 2 requests instead of one (refresh + actual login request).
if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && t.needsTokenRefresh() {
if err := t.refreshJwtToken(); err != nil {
return nil, err
}
}
t.refreshTokenMutex.Unlock()

if t.UserAgent != "" {
req.Header.Add("User-Agent", t.UserAgent)
}

req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token))

return req, nil
}

// RoundTrip implements the RoundTripper interface.
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req, err := t.prepareRequest(req)
if err != nil {
return nil, err
}

if log.GetLevel() >= log.TraceLevel {
//requestToDump := cloneRequest(req)
dump, _ := httputil.DumpRequest(req, true)
Expand All @@ -166,7 +182,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {

if err != nil {
// we had an error (network error for example, or 401 because token is refused), reset the token?
t.Token = ""
t.ResetToken()

return resp, fmt.Errorf("performing jwt auth: %w", err)
}
Expand All @@ -189,7 +205,8 @@ func (t *JWTTransport) ResetToken() {
t.refreshTokenMutex.Unlock()
}

// transport() returns a round tripper that retries once when the status is unauthorized, and 5 times when the infrastructure is overloaded.
// transport() returns a round tripper that retries once when the status is unauthorized,
// and 5 times when the infrastructure is overloaded.
func (t *JWTTransport) transport() http.RoundTripper {
transport := t.Transport
if transport == nil {
Expand Down
36 changes: 0 additions & 36 deletions pkg/apiclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"

Expand Down Expand Up @@ -167,44 +165,10 @@ type Response struct {
//...
}

type ErrorResponse struct {
models.ErrorResponse
}

func (e *ErrorResponse) Error() string {
err := fmt.Sprintf("API error: %s", *e.Message)
if len(e.Errors) > 0 {
err += fmt.Sprintf(" (%s)", e.Errors)
}

return err
}

func newResponse(r *http.Response) *Response {
return &Response{Response: r}
}

func CheckResponse(r *http.Response) error {
if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 {
return nil
}

errorResponse := &ErrorResponse{}

data, err := io.ReadAll(r.Body)
if err == nil && len(data)>0 {
err := json.Unmarshal(data, errorResponse)
if err != nil {
return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err)
}
} else {
errorResponse.Message = new(string)
*errorResponse.Message = fmt.Sprintf("http code %d, no error message", r.StatusCode)
}

return errorResponse
}

type ListOpts struct {
//Page int
//PerPage int
Expand Down
46 changes: 46 additions & 0 deletions pkg/apiclient/resperr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package apiclient

import (
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/crowdsecurity/go-cs-lib/ptr"

"github.com/crowdsecurity/crowdsec/pkg/models"
)

type ErrorResponse struct {
models.ErrorResponse
}

func (e *ErrorResponse) Error() string {
err := fmt.Sprintf("API error: %s", *e.Message)
if len(e.Errors) > 0 {
err += fmt.Sprintf(" (%s)", e.Errors)
}

Check warning on line 22 in pkg/apiclient/resperr.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/resperr.go#L21-L22

Added lines #L21 - L22 were not covered by tests

return err
}

// CheckResponse verifies the API response and builds an appropriate Go error if necessary.
func CheckResponse(r *http.Response) error {
if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 {
LaurenceJJones marked this conversation as resolved.
Show resolved Hide resolved
return nil
}

ret := &ErrorResponse{}

data, err := io.ReadAll(r.Body)
if err != nil || len(data) == 0 {
ret.Message = ptr.Of(fmt.Sprintf("http code %d, no error message", r.StatusCode))
return ret
}

Check warning on line 39 in pkg/apiclient/resperr.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/resperr.go#L37-L39

Added lines #L37 - L39 were not covered by tests

if err := json.Unmarshal(data, ret); err != nil {
return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err)
}

return ret
}
1 change: 0 additions & 1 deletion pkg/apiserver/apic.go
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,6 @@ func createAlertForDecision(decision *models.Decision) *models.Alert {
scenario = *decision.Scenario
scope = types.ListOrigin
default:
// XXX: this or nil?
scenario = ""
scope = ""

Expand Down
Loading