diff --git a/pkg/apiclient/alerts_service.go b/pkg/apiclient/alerts_service.go index 1d0a4ebd12c7..ad75dd39342b 100644 --- a/pkg/apiclient/alerts_service.go +++ b/pkg/apiclient/alerts_service.go @@ -49,15 +49,15 @@ type AlertsDeleteOpts struct { } func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest) (*models.AddAlertsResponse, *Response, error) { - var addedIds models.AddAlertsResponse - u := fmt.Sprintf("%s/alerts", s.client.URLPrefix) - req, err := s.client.NewRequest(http.MethodPost, u, &alerts) + req, err := s.client.NewRequest(http.MethodPost, u, &alerts) if err != nil { return nil, nil, err } + addedIds := models.AddAlertsResponse{} + resp, err := s.client.Do(ctx, req, &addedIds) if err != nil { return nil, resp, err @@ -68,22 +68,16 @@ func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest) // to demo query arguments func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.GetAlertsResponse, *Response, error) { - var ( - alerts models.GetAlertsResponse - URI string - ) - u := fmt.Sprintf("%s/alerts", s.client.URLPrefix) - params, err := qs.Values(opts) + params, err := qs.Values(opts) if err != nil { return nil, nil, fmt.Errorf("building query: %w", err) } + URI := u if len(params) > 0 { - URI = fmt.Sprintf("%s?%s", u, params.Encode()) - } else { - URI = u + URI = fmt.Sprintf("%s?%s", URI, params.Encode()) } req, err := s.client.NewRequest(http.MethodGet, URI, nil) @@ -91,6 +85,8 @@ func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models. return nil, nil, fmt.Errorf("building request: %w", err) } + alerts := models.GetAlertsResponse{} + resp, err := s.client.Do(ctx, req, &alerts) if err != nil { return nil, resp, fmt.Errorf("performing request: %w", err) @@ -101,8 +97,6 @@ func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models. // to demo query arguments func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*models.DeleteAlertsResponse, *Response, error) { - var alerts models.DeleteAlertsResponse - params, err := qs.Values(opts) if err != nil { return nil, nil, err @@ -115,6 +109,8 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod return nil, nil, err } + alerts := models.DeleteAlertsResponse{} + resp, err := s.client.Do(ctx, req, &alerts) if err != nil { return nil, resp, err @@ -124,8 +120,6 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod } func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models.DeleteAlertsResponse, *Response, error) { - var alerts models.DeleteAlertsResponse - u := fmt.Sprintf("%s/alerts/%s", s.client.URLPrefix, alertID) req, err := s.client.NewRequest(http.MethodDelete, u, nil) @@ -133,6 +127,8 @@ func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models. return nil, nil, err } + alerts := models.DeleteAlertsResponse{} + resp, err := s.client.Do(ctx, req, &alerts) if err != nil { return nil, resp, err @@ -142,8 +138,6 @@ func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models. } func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert, *Response, error) { - var alert models.Alert - u := fmt.Sprintf("%s/alerts/%d", s.client.URLPrefix, alertID) req, err := s.client.NewRequest(http.MethodGet, u, nil) @@ -151,6 +145,8 @@ func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert return nil, nil, err } + alert := models.Alert{} + resp, err := s.client.Do(ctx, req, &alert) if err != nil { return nil, nil, err diff --git a/pkg/apiclient/auth.go b/pkg/apiclient/auth.go index 163e96718b04..9ea2565e71eb 100644 --- a/pkg/apiclient/auth.go +++ b/pkg/apiclient/auth.go @@ -125,7 +125,7 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) select { case <-req.Context().Done(): - return resp, req.Context().Err() + return nil, req.Context().Err() case <-time.After(time.Duration(backoff) * time.Second): } } @@ -135,8 +135,8 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) } clonedReq := cloneRequest(req) - resp, err = r.next.RoundTrip(clonedReq) + resp, err = r.next.RoundTrip(clonedReq) if err != nil { if left := maxAttempts - i - 1; left > 0 { log.Errorf("error while performing request: %s; %d retries left", err, left) @@ -171,10 +171,11 @@ type JWTTransport struct { func (t *JWTTransport) refreshJwtToken() error { var err error + if t.UpdateScenario != nil { t.Scenarios, err = t.UpdateScenario() if err != nil { - return fmt.Errorf("can't update scenario list: %s", err) + return fmt.Errorf("can't update scenario list: %w", err) } log.Debugf("scenarios list updated for '%s'", *t.MachineID) @@ -186,8 +187,6 @@ func (t *JWTTransport) refreshJwtToken() error { Scenarios: t.Scenarios, } - var response models.WatcherAuthResponse - /* we don't use the main client, so let's build the body */ @@ -250,6 +249,8 @@ func (t *JWTTransport) refreshJwtToken() error { } } + var response models.WatcherAuthResponse + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { return fmt.Errorf("unable to decode response: %w", err) } @@ -300,7 +301,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { } if err != nil { - // we had an error (network error for example, or 401 because token is refused), reset the token ? + // we had an error (network error for example, or 401 because token is refused), reset the token? t.Token = "" return resp, fmt.Errorf("performing jwt auth: %w", err) @@ -324,14 +325,13 @@ func (t *JWTTransport) ResetToken() { t.refreshTokenMutex.Unlock() } +// transport() returns a round tripper that retries once when the status is unauthorized, and 5 times when infrastructure is overloaded. func (t *JWTTransport) transport() http.RoundTripper { - var transport http.RoundTripper - if t.Transport != nil { - transport = t.Transport - } else { + transport := t.Transport + if transport == nil { transport = http.DefaultTransport } - // a round tripper that retries once when the status is unauthorized and 5 times when infrastructure is overloaded + return &retryRoundTripper{ next: &retryRoundTripper{ next: transport, diff --git a/pkg/apiclient/client_http.go b/pkg/apiclient/client_http.go index 5222ad7707be..0240618f5356 100644 --- a/pkg/apiclient/client_http.go +++ b/pkg/apiclient/client_http.go @@ -94,7 +94,7 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (* if log.GetLevel() >= log.DebugLevel { for k, v := range resp.Header { - log.Debugf("[headers] %s : %s", k, v) + log.Debugf("[headers] %s: %s", k, v) } dump, err := httputil.DumpResponse(resp, true) diff --git a/pkg/apiclient/decisions_service.go b/pkg/apiclient/decisions_service.go index a3f02c0ef27e..388a870f999c 100644 --- a/pkg/apiclient/decisions_service.go +++ b/pkg/apiclient/decisions_service.go @@ -3,11 +3,11 @@ package apiclient import ( "bufio" "context" + "errors" "fmt" "net/http" qs "github.com/google/go-querystring/query" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/crowdsecurity/go-cs-lib/ptr" @@ -61,8 +61,6 @@ type DecisionsDeleteOpts struct { // to demo query arguments func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*models.GetDecisionsResponse, *Response, error) { - var decisions models.GetDecisionsResponse - params, err := qs.Values(opts) if err != nil { return nil, nil, err @@ -75,6 +73,8 @@ func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*m return nil, nil, err } + var decisions models.GetDecisionsResponse + resp, err := s.client.Do(ctx, req, &decisions) if err != nil { return nil, resp, err @@ -84,13 +84,13 @@ func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*m } func (s *DecisionsService) FetchV2Decisions(ctx context.Context, url string) (*models.DecisionsStreamResponse, *Response, error) { - var decisions models.DecisionsStreamResponse - req, err := s.client.NewRequest(http.MethodGet, url, nil) if err != nil { return nil, nil, err } + var decisions models.DecisionsStreamResponse + resp, err := s.client.Do(ctx, req, &decisions) if err != nil { return nil, resp, err @@ -100,7 +100,7 @@ func (s *DecisionsService) FetchV2Decisions(ctx context.Context, url string) (*m } func (s *DecisionsService) GetDecisionsFromGroups(decisionsGroups []*modelscapi.GetDecisionsStreamResponseNewItem) []*models.Decision { - var decisions []*models.Decision + decisions := make([]*models.Decision, 0) for _, decisionsGroup := range decisionsGroups { partialDecisions := make([]*models.Decision, len(decisionsGroup.Decisions)) @@ -122,11 +122,6 @@ func (s *DecisionsService) GetDecisionsFromGroups(decisionsGroups []*modelscapi. } func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*models.DecisionsStreamResponse, *Response, error) { - var ( - decisions modelscapi.GetDecisionsStreamResponse - v2Decisions models.DecisionsStreamResponse - ) - scenarioDeleted := "deleted" durationDeleted := "1h" @@ -135,11 +130,14 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m return nil, nil, err } + decisions := modelscapi.GetDecisionsStreamResponse{} + resp, err := s.client.Do(ctx, req, &decisions) if err != nil { return nil, resp, err } + v2Decisions := models.DecisionsStreamResponse{} v2Decisions.New = s.GetDecisionsFromGroups(decisions.New) for _, decisionsGroup := range decisions.Deleted { @@ -183,6 +181,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl req = req.WithContext(ctx) log.Debugf("[URL] %s %s", req.Method, req.URL) + // we don't use client_http Do method because we need the reader and is not provided. // We would be forced to use Pipe and goroutine, etc resp, err := client.Do(req) @@ -247,11 +246,11 @@ func (s *DecisionsService) GetStream(ctx context.Context, opts DecisionsStreamOp return nil, nil, err } - if s.client.URLPrefix == "v3" { - return s.FetchV3Decisions(ctx, u) - } else { + if s.client.URLPrefix != "v3" { return s.FetchV2Decisions(ctx, u) } + + return s.FetchV3Decisions(ctx, u) } func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStreamOpts) (*modelscapi.GetDecisionsStreamResponse, *Response, error) { @@ -260,13 +259,13 @@ func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStream return nil, nil, err } - var decisions modelscapi.GetDecisionsStreamResponse - req, err := s.client.NewRequest(http.MethodGet, u, nil) if err != nil { return nil, nil, err } + decisions := modelscapi.GetDecisionsStreamResponse{} + resp, err := s.client.Do(ctx, req, &decisions) if err != nil { return nil, resp, err @@ -292,8 +291,6 @@ func (s *DecisionsService) StopStream(ctx context.Context) (*Response, error) { } func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts) (*models.DeleteDecisionResponse, *Response, error) { - var deleteDecisionResponse models.DeleteDecisionResponse - params, err := qs.Values(opts) if err != nil { return nil, nil, err @@ -306,6 +303,8 @@ func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts) return nil, nil, err } + deleteDecisionResponse := models.DeleteDecisionResponse{} + resp, err := s.client.Do(ctx, req, &deleteDecisionResponse) if err != nil { return nil, resp, err @@ -315,8 +314,6 @@ func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts) } func (s *DecisionsService) DeleteOne(ctx context.Context, decisionID string) (*models.DeleteDecisionResponse, *Response, error) { - var deleteDecisionResponse models.DeleteDecisionResponse - u := fmt.Sprintf("%s/decisions/%s", s.client.URLPrefix, decisionID) req, err := s.client.NewRequest(http.MethodDelete, u, nil) @@ -324,6 +321,8 @@ func (s *DecisionsService) DeleteOne(ctx context.Context, decisionID string) (*m return nil, nil, err } + deleteDecisionResponse := models.DeleteDecisionResponse{} + resp, err := s.client.Do(ctx, req, &deleteDecisionResponse) if err != nil { return nil, resp, err diff --git a/pkg/apiclient/decisions_sync_service.go b/pkg/apiclient/decisions_sync_service.go index 1aee9b6ca2a7..25e33a8e29d6 100644 --- a/pkg/apiclient/decisions_sync_service.go +++ b/pkg/apiclient/decisions_sync_service.go @@ -14,8 +14,6 @@ type DecisionDeleteService service // DecisionDeleteService purposely reuses AddSignalsRequestItemDecisions model func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *models.DecisionsDeleteRequest) (interface{}, *Response, error) { - var response interface{} - u := fmt.Sprintf("%s/decisions/delete", d.client.URLPrefix) req, err := d.client.NewRequest(http.MethodPost, u, &deletedDecisions) @@ -23,15 +21,17 @@ func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *model return nil, nil, fmt.Errorf("while building request: %w", err) } + var response interface{} + resp, err := d.client.Do(ctx, req, &response) if err != nil { return nil, resp, fmt.Errorf("while performing request: %w", err) } if resp.Response.StatusCode != http.StatusOK { - log.Warnf("Decisions delete response : http %s", resp.Response.Status) + log.Warnf("Decisions delete response: http %s", resp.Response.Status) } else { - log.Debugf("Decisions delete response : http %s", resp.Response.Status) + log.Debugf("Decisions delete response: http %s", resp.Response.Status) } return &response, resp, nil diff --git a/pkg/apiclient/heartbeat.go b/pkg/apiclient/heartbeat.go index 77e0ecc2eae9..c6b3d0832baf 100644 --- a/pkg/apiclient/heartbeat.go +++ b/pkg/apiclient/heartbeat.go @@ -41,15 +41,16 @@ func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) { ok, resp, err := h.Ping(ctx) if err != nil { - log.Errorf("heartbeat error : %s", err) + log.Errorf("heartbeat error: %s", err) continue } resp.Response.Body.Close() if resp.Response.StatusCode != http.StatusOK { - log.Errorf("heartbeat unexpected return code : %d", resp.Response.StatusCode) + log.Errorf("heartbeat unexpected return code: %d", resp.Response.StatusCode) continue } + if !ok { log.Errorf("heartbeat returned false") continue diff --git a/pkg/apiclient/metrics.go b/pkg/apiclient/metrics.go index a822730070cb..7f8d095a2df7 100644 --- a/pkg/apiclient/metrics.go +++ b/pkg/apiclient/metrics.go @@ -11,8 +11,6 @@ import ( type MetricsService service func (s *MetricsService) Add(ctx context.Context, metrics *models.Metrics) (interface{}, *Response, error) { - var response interface{} - u := fmt.Sprintf("%s/metrics/", s.client.URLPrefix) req, err := s.client.NewRequest(http.MethodPost, u, &metrics) @@ -20,6 +18,8 @@ func (s *MetricsService) Add(ctx context.Context, metrics *models.Metrics) (inte return nil, nil, err } + var response interface{} + resp, err := s.client.Do(ctx, req, &response) if err != nil { return nil, resp, err diff --git a/pkg/apiclient/signal.go b/pkg/apiclient/signal.go index 94c02f080f09..613ce70bbfb5 100644 --- a/pkg/apiclient/signal.go +++ b/pkg/apiclient/signal.go @@ -13,8 +13,6 @@ import ( type SignalService service func (s *SignalService) Add(ctx context.Context, signals *models.AddSignalsRequest) (interface{}, *Response, error) { - var response interface{} - u := fmt.Sprintf("%s/signals", s.client.URLPrefix) req, err := s.client.NewRequest(http.MethodPost, u, &signals) @@ -22,6 +20,8 @@ func (s *SignalService) Add(ctx context.Context, signals *models.AddSignalsReque return nil, nil, fmt.Errorf("while building request: %w", err) } + var response interface{} + resp, err := s.client.Do(ctx, req, &response) if err != nil { return nil, resp, fmt.Errorf("while performing request: %w", err) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index dcf12929a946..c20f292ffe73 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -2,21 +2,21 @@ package apiserver import ( "context" + "errors" "fmt" "math/rand" "net" "net/http" "net/url" + "slices" "strconv" "strings" "sync" "time" "github.com/go-openapi/strfmt" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" - "slices" "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/trace" @@ -651,13 +651,15 @@ func (a *apic) PullTop(forcePull bool) error { } addCounters, deleteCounters := makeAddAndDeleteCounters() + // process deleted decisions - if nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters); err != nil { + nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters) + if err != nil { return err - } else { - log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted) } + log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted) + if len(data.New) == 0 { log.Infof("capi/community-blocklist : received 0 new entries (expected if you just installed crowdsec)") return nil @@ -895,12 +897,19 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink } func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) { - if *alert.Source.Scope == types.CAPIOrigin { + switch *alert.Source.Scope { + case types.CAPIOrigin: *alert.Source.Scope = types.CommunityBlocklistPullSourceScope - alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.CAPIOrigin]["all"], deleteCounters[types.CAPIOrigin]["all"])) - } else if *alert.Source.Scope == types.ListOrigin { + alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", + addCounters[types.CAPIOrigin]["all"], + deleteCounters[types.CAPIOrigin]["all"]), + ) + case types.ListOrigin: *alert.Source.Scope = fmt.Sprintf("%s:%s", types.ListOrigin, *alert.Scenario) - alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.ListOrigin][*alert.Scenario], deleteCounters[types.ListOrigin][*alert.Scenario])) + alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", + addCounters[types.ListOrigin][*alert.Scenario], + deleteCounters[types.ListOrigin][*alert.Scenario]), + ) } } @@ -972,11 +981,12 @@ func makeAddAndDeleteCounters() (map[string]map[string]int, map[string]map[strin } func updateCounterForDecision(counter map[string]map[string]int, origin *string, scenario *string, totalDecisions int) { - if *origin == types.CAPIOrigin { + switch *origin { + case types.CAPIOrigin: counter[*origin]["all"] += totalDecisions - } else if *origin == types.ListOrigin { + case types.ListOrigin: counter[*origin][*scenario] += totalDecisions - } else { + default: log.Warningf("Unknown origin %s", *origin) } } diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 638ac2c65088..58caeb06880b 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -2,6 +2,7 @@ package apiserver import ( "context" + "errors" "fmt" "io" "net" @@ -13,7 +14,6 @@ import ( "github.com/gin-gonic/gin" "github.com/go-co-op/gocron" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "gopkg.in/natefinch/lumberjack.v2" "gopkg.in/tomb.v2" @@ -382,7 +382,9 @@ func (s *APIServer) listenAndServeURL(apiReady chan bool) { if s.TLS.KeyFilePath == "" { serverError <- errors.New("missing TLS key file") return - } else if s.TLS.CertFilePath == "" { + } + + if s.TLS.CertFilePath == "" { serverError <- errors.New("missing TLS cert file") return } diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index 424c20af6ce1..5fcd9fa432da 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -59,8 +59,10 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { } for _, eventItem := range alert.Edges.Events { - var Metas models.Meta timestamp := eventItem.Time.String() + + var Metas models.Meta + if err := json.Unmarshal([]byte(eventItem.Serialized), &Metas); err != nil { log.Errorf("unable to unmarshall events meta '%s' : %s", eventItem.Serialized, err) } @@ -162,6 +164,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { if alert.Source.Scope != nil { *alert.Source.Scope = normalizeScope(*alert.Source.Scope) } + for _, decision := range alert.Decisions { if decision.Scope != nil { *decision.Scope = normalizeScope(*decision.Scope) @@ -183,30 +186,38 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { _, matched, err := profile.EvaluateProfile(alert) if err != nil { profile.Logger.Warningf("error while evaluating profile %s : %v", profile.Cfg.Name, err) + continue } + if !matched { continue } + c.sendAlertToPluginChannel(alert, uint(pIdx)) + if profile.Cfg.OnSuccess == "break" { break } } + decision := alert.Decisions[0] if decision.Origin != nil && *decision.Origin == types.CscliImportOrigin { stopFlush = true } + continue } for pIdx, profile := range c.Profiles { profileDecisions, matched, err := profile.EvaluateProfile(alert) forceBreak := false + if err != nil { switch profile.Cfg.OnError { case "apply": profile.Logger.Warningf("applying profile %s despite error: %s", profile.Cfg.Name, err) + matched = true case "continue": profile.Logger.Warningf("skipping %s profile due to error: %s", profile.Cfg.Name, err) @@ -219,18 +230,23 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { return } } + if !matched { continue } + for _, decision := range profileDecisions { decision.UUID = uuid.NewString() } - //generate uuid here for alert + + // generate uuid here for alert if len(alert.Decisions) == 0 { // non manual decision alert.Decisions = append(alert.Decisions, profileDecisions...) } + profileAlert := *alert c.sendAlertToPluginChannel(&profileAlert, uint(pIdx)) + if profile.Cfg.OnSuccess == "break" || forceBreak { break } @@ -275,6 +291,7 @@ func (c *Controller) FindAlerts(gctx *gin.Context) { gctx.String(http.StatusOK, "") return } + gctx.JSON(http.StatusOK, data) } @@ -282,21 +299,25 @@ func (c *Controller) FindAlerts(gctx *gin.Context) { func (c *Controller) FindAlertByID(gctx *gin.Context) { alertIDStr := gctx.Param("alert_id") alertID, err := strconv.Atoi(alertIDStr) + if err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"}) return } + result, err := c.DBClient.GetAlertByID(alertID) if err != nil { c.HandleDBErrors(gctx, err) return } + data := FormatOneAlert(result) if gctx.Request.Method == http.MethodHead { gctx.String(http.StatusOK, "") return } + gctx.JSON(http.StatusOK, data) } @@ -310,21 +331,19 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) { return } - decisionIDStr := gctx.Param("alert_id") - decisionID, err := strconv.Atoi(decisionIDStr) + decisionID, err := strconv.Atoi(gctx.Param("alert_id")) if err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"}) return } + err = c.DBClient.DeleteAlertByID(decisionID) if err != nil { c.HandleDBErrors(gctx, err) return } - deleteAlertResp := models.DeleteAlertsResponse{ - NbDeleted: "1", - } + deleteAlertResp := models.DeleteAlertsResponse{NbDeleted: "1"} gctx.JSON(http.StatusOK, deleteAlertResp) } @@ -336,15 +355,17 @@ func (c *Controller) DeleteAlerts(gctx *gin.Context) { gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) return } - var err error + nbDeleted, err := c.DBClient.DeleteAlertWithFilter(gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) return } + deleteAlertsResp := models.DeleteAlertsResponse{ NbDeleted: strconv.Itoa(nbDeleted), } + gctx.JSON(http.StatusOK, deleteAlertsResp) } @@ -355,5 +376,6 @@ func networksContainIP(networks []net.IPNet, ip string) bool { return true } } + return false } diff --git a/pkg/apiserver/controllers/v1/controller.go b/pkg/apiserver/controllers/v1/controller.go index 60da83d7dcbd..ad76ad766169 100644 --- a/pkg/apiserver/controllers/v1/controller.go +++ b/pkg/apiserver/controllers/v1/controller.go @@ -61,8 +61,10 @@ func New(cfg *ControllerV1Config) (*Controller, error) { TrustedIPs: cfg.TrustedIPs, } v1.Middlewares, err = middlewares.NewMiddlewares(cfg.DbClient) + if err != nil { return v1, err } + return v1, nil } diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index 534870484d41..155aab77752f 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -2,7 +2,6 @@ package v1 import ( "encoding/json" - "fmt" "net/http" "strconv" "time" @@ -33,23 +32,27 @@ func FormatDecisions(decisions []*ent.Decision) []*models.Decision { } results = append(results, &decision) } + return results } func (c *Controller) GetDecision(gctx *gin.Context) { - var err error - var results []*models.Decision - var data []*ent.Decision + var ( + results []*models.Decision + data []*ent.Decision + ) bouncerInfo, err := getBouncerFromContext(gctx) if err != nil { gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) + return } data, err = c.DBClient.QueryDecisionWithFilter(gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) + return } @@ -64,6 +67,7 @@ func (c *Controller) GetDecision(gctx *gin.Context) { if gctx.Request.Method == http.MethodHead { gctx.String(http.StatusOK, "") + return } @@ -77,20 +81,23 @@ func (c *Controller) GetDecision(gctx *gin.Context) { } func (c *Controller) DeleteDecisionById(gctx *gin.Context) { - var err error - decisionIDStr := gctx.Param("decision_id") + decisionID, err := strconv.Atoi(decisionIDStr) if err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": "decision_id must be valid integer"}) + return } + nbDeleted, deletedFromDB, err := c.DBClient.SoftDeleteDecisionByID(decisionID) if err != nil { c.HandleDBErrors(gctx, err) + return } - //transform deleted decisions to be sendable to capi + + // transform deleted decisions to be sendable to capi deletedDecisions := FormatDecisions(deletedFromDB) if c.DecisionDeleteChan != nil { @@ -105,13 +112,14 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) { } func (c *Controller) DeleteDecisions(gctx *gin.Context) { - var err error nbDeleted, deletedFromDB, err := c.DBClient.SoftDeleteDecisionsWithFilter(gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) + return } - //transform deleted decisions to be sendable to capi + + // transform deleted decisions to be sendable to capi deletedDecisions := FormatDecisions(deletedFromDB) if c.DecisionDeleteChan != nil { @@ -121,6 +129,7 @@ func (c *Controller) DeleteDecisions(gctx *gin.Context) { deleteDecisionResp := models.DeleteDecisionResponse{ NbDeleted: nbDeleted, } + gctx.JSON(http.StatusOK, deleteDecisionResp) } @@ -128,25 +137,29 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun // respBuffer := bytes.NewBuffer([]byte{}) limit := 30000 //FIXME : make it configurable needComma := false - lastId := 0 + lastID := 0 - limitStr := fmt.Sprintf("%d", limit) + limitStr := strconv.Itoa(limit) filters["limit"] = []string{limitStr} + for { - if lastId > 0 { - lastIdStr := fmt.Sprintf("%d", lastId) - filters["id_gt"] = []string{lastIdStr} + if lastID > 0 { + lastIDStr := strconv.Itoa(lastID) + filters["id_gt"] = []string{lastIDStr} } data, err := dbFunc(filters) if err != nil { return err } + if len(data) > 0 { - lastId = data[len(data)-1].ID + lastID = data[len(data)-1].ID + results := FormatDecisions(data) for _, decision := range results { decisionJSON, _ := json.Marshal(decision) + if needComma { //respBuffer.Write([]byte(",")) gctx.Writer.Write([]byte(",")) @@ -158,17 +171,22 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun _, err := gctx.Writer.Write(decisionJSON) if err != nil { gctx.Writer.Flush() + return err } //respBuffer.Reset() } } - log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId) + + log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastID) + if len(data) < limit { gctx.Writer.Flush() + break } } + return nil } @@ -176,25 +194,27 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul //respBuffer := bytes.NewBuffer([]byte{}) limit := 30000 //FIXME : make it configurable needComma := false - lastId := 0 + lastID := 0 + + filters["limit"] = []string{strconv.Itoa(limit)} - limitStr := fmt.Sprintf("%d", limit) - filters["limit"] = []string{limitStr} for { - if lastId > 0 { - lastIdStr := fmt.Sprintf("%d", lastId) - filters["id_gt"] = []string{lastIdStr} + if lastID > 0 { + filters["id_gt"] = []string{strconv.Itoa(lastID)} } data, err := dbFunc(lastPull, filters) if err != nil { return err } + if len(data) > 0 { - lastId = data[len(data)-1].ID + lastID = data[len(data)-1].ID + results := FormatDecisions(data) for _, decision := range results { decisionJSON, _ := json.Marshal(decision) + if needComma { //respBuffer.Write([]byte(",")) gctx.Writer.Write([]byte(",")) @@ -206,17 +226,22 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul _, err := gctx.Writer.Write(decisionJSON) if err != nil { gctx.Writer.Flush() + return err } //respBuffer.Reset() } } - log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId) + + log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastID) + if len(data) < limit { gctx.Writer.Flush() + break } } + return nil } @@ -230,14 +255,13 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B // if the blocker just started, return all decisions if val, ok := gctx.Request.URL.Query()["startup"]; ok && val[0] == "true" { - //Active decisions - + // Active decisions err := writeStartupDecisions(gctx, filters, c.DBClient.QueryAllDecisionsWithFilters) - if err != nil { log.Errorf("failed sending new decisions for startup: %v", err) gctx.Writer.Write([]byte(`], "deleted": []}`)) gctx.Writer.Flush() + return err } @@ -248,6 +272,7 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B log.Errorf("failed sending expired decisions for startup: %v", err) gctx.Writer.Write([]byte(`]}`)) gctx.Writer.Flush() + return err } @@ -259,6 +284,7 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B log.Errorf("failed sending new decisions for delta: %v", err) gctx.Writer.Write([]byte(`], "deleted": []}`)) gctx.Writer.Flush() + return err } @@ -270,18 +296,21 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B log.Errorf("failed sending expired decisions for delta: %v", err) gctx.Writer.Write([]byte(`]}`)) gctx.Writer.Flush() + return err } gctx.Writer.Write([]byte(`]}`)) gctx.Writer.Flush() } + return nil } func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *ent.Bouncer, streamStartTime time.Time, filters map[string][]string) error { var data []*ent.Decision var err error + ret := make(map[string][]*models.Decision, 0) ret["new"] = []*models.Decision{} ret["deleted"] = []*models.Decision{} @@ -292,6 +321,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en if err != nil { log.Errorf("failed querying decisions: %v", err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + return err } //data = KeepLongestDecision(data) @@ -302,11 +332,14 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + return err } + ret["deleted"] = FormatDecisions(data) gctx.JSON(http.StatusOK, ret) + return nil } } @@ -316,6 +349,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en if err != nil { log.Errorf("unable to query new decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + return err } //data = KeepLongestDecision(data) @@ -326,10 +360,13 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + return err } + ret["deleted"] = FormatDecisions(data) gctx.JSON(http.StatusOK, ret) + return nil } @@ -337,9 +374,11 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { var err error streamStartTime := time.Now().UTC() + bouncerInfo, err := getBouncerFromContext(gctx) if err != nil { gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) + return } @@ -347,6 +386,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { //For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db //We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true) gctx.String(http.StatusOK, "") + return } diff --git a/pkg/apiserver/controllers/v1/heartbeat.go b/pkg/apiserver/controllers/v1/heartbeat.go index bf6fd5781956..b19b450f0d52 100644 --- a/pkg/apiserver/controllers/v1/heartbeat.go +++ b/pkg/apiserver/controllers/v1/heartbeat.go @@ -8,7 +8,6 @@ import ( ) func (c *Controller) HeartBeat(gctx *gin.Context) { - claims := jwt.ExtractClaims(gctx) // TBD: use defined rather than hardcoded key to find back owner machineID := claims["id"].(string) @@ -17,5 +16,6 @@ func (c *Controller) HeartBeat(gctx *gin.Context) { c.HandleDBErrors(gctx, err) return } + gctx.Status(http.StatusOK) } diff --git a/pkg/apiserver/controllers/v1/machines.go b/pkg/apiserver/controllers/v1/machines.go index 55f79d0c93ff..84a6ef2583c9 100644 --- a/pkg/apiserver/controllers/v1/machines.go +++ b/pkg/apiserver/controllers/v1/machines.go @@ -11,19 +11,19 @@ import ( ) func (c *Controller) CreateMachine(gctx *gin.Context) { - var err error var input models.WatcherRegistrationRequest - if err = gctx.ShouldBindJSON(&input); err != nil { + + if err := gctx.ShouldBindJSON(&input); err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return } - if err = input.Validate(strfmt.Default); err != nil { + + if err := input.Validate(strfmt.Default); err != nil { c.HandleDBErrors(gctx, err) return } - _, err = c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), false, false, types.PasswordAuthType) - if err != nil { + if _, err := c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), false, false, types.PasswordAuthType); err != nil { c.HandleDBErrors(gctx, err) return } diff --git a/pkg/apiserver/controllers/v1/metrics.go b/pkg/apiserver/controllers/v1/metrics.go index 676cc31ea46f..13ccf9ac94f1 100644 --- a/pkg/apiserver/controllers/v1/metrics.go +++ b/pkg/apiserver/controllers/v1/metrics.go @@ -93,6 +93,7 @@ func PrometheusMachinesMiddleware() gin.HandlerFunc { "method": c.Request.Method}).Inc() } } + c.Next() } } @@ -106,6 +107,7 @@ func PrometheusBouncersMiddleware() gin.HandlerFunc { "route": c.Request.URL.Path, "method": c.Request.Method}).Inc() } + c.Next() } } @@ -113,10 +115,12 @@ func PrometheusBouncersMiddleware() gin.HandlerFunc { func PrometheusMiddleware() gin.HandlerFunc { return func(c *gin.Context) { startTime := time.Now() + LapiRouteHits.With(prometheus.Labels{ "route": c.Request.URL.Path, "method": c.Request.Method}).Inc() c.Next() + elapsed := time.Since(startTime) LapiResponseTime.With(prometheus.Labels{"method": c.Request.Method, "endpoint": c.Request.URL.Path}).Observe(elapsed.Seconds()) } diff --git a/pkg/apiserver/controllers/v1/utils.go b/pkg/apiserver/controllers/v1/utils.go index aaa17ca51b1a..6afd005132ac 100644 --- a/pkg/apiserver/controllers/v1/utils.go +++ b/pkg/apiserver/controllers/v1/utils.go @@ -9,9 +9,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent" ) -var ( - bouncerContextKey = "bouncer_info" -) +const bouncerContextKey = "bouncer_info" func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) { bouncerInterface, exist := ctx.Get(bouncerContextKey) diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index 7e4df875c116..4a79bacc1443 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "crypto/sha512" "encoding/base64" + "encoding/hex" "fmt" "net/http" "strings" @@ -34,7 +35,9 @@ func GenerateAPIKey(n int) (string, error) { if _, err := rand.Read(bytes); err != nil { return "", err } + encoded := base64.StdEncoding.EncodeToString(bytes) + // the '=' can cause issues on some bouncers return strings.TrimRight(encoded, "="), nil } @@ -51,7 +54,7 @@ func HashSHA512(str string) string { hashedKey := sha512.New() hashedKey.Write([]byte(str)) - hashStr := fmt.Sprintf("%x", hashedKey.Sum(nil)) + hashStr := hex.EncodeToString(hashedKey.Sum(nil)) return hashStr } @@ -67,6 +70,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Errorf("invalid client certificate: %s", err) return nil } + if err != nil { logger.Error(err) return nil @@ -88,7 +92,9 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Errorf("error generating mock api key: %s", err) return nil } + logger.Infof("Creating bouncer %s", bouncerName) + bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) if err != nil { logger.Errorf("while creating bouncer db entry: %s", err) @@ -103,6 +109,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Errorf("bouncer isn't allowed to auth by TLS") return nil } + return bouncer } @@ -112,6 +119,7 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Errorf("API key not found") return nil } + hashStr := HashSHA512(val[0]) bouncer, err := a.DbClient.SelectBouncer(hashStr) @@ -162,16 +170,19 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() + return } } if bouncer.IPAddress != c.ClientIP() && bouncer.IPAddress != "" { log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, c.ClientIP(), bouncer.IPAddress) + if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() + return } } @@ -187,12 +198,12 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { logger.Errorf("failed to update bouncer version and type: %s", err) c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) c.Abort() + return } } c.Set(bouncerContextKey, bouncer) - c.Next() } } diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index 8797761a4144..3b777a8671d5 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -36,14 +36,16 @@ func PayloadFunc(data interface{}) jwt.MapClaims { identityKey: &value.MachineID, } } + return jwt.MapClaims{} } func IdentityHandler(c *gin.Context) interface{} { claims := jwt.ExtractClaims(c) - machineId := claims[identityKey].(string) + machineID := claims[identityKey].(string) + return &models.WatcherAuthRequest{ - MachineID: &machineId, + MachineID: &machineID, } } @@ -67,6 +69,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { log.Error(err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() + return nil, fmt.Errorf("while trying to validate client cert: %w", err) } @@ -77,6 +80,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { } ret.machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) + ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). Where(machine.MachineId(ret.machineID)). First(j.DbClient.CTX) @@ -90,9 +94,12 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { "ip": c.ClientIP(), "cn": extractedCN, }).Errorf("error generating password: %s", err) + return nil, fmt.Errorf("error generating password") } + password := strfmt.Password(pwd) + ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType) if err != nil { return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err) @@ -111,27 +118,33 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { }{ Scenarios: []string{}, } + err = c.ShouldBindJSON(&loginInput) if err != nil { return nil, fmt.Errorf("missing scenarios list in login request for TLS auth: %w", err) } + ret.scenariosInput = loginInput.Scenarios return &ret, nil } func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { - var loginInput models.WatcherAuthRequest - var err error + var ( + loginInput models.WatcherAuthRequest + err error + ) ret := authInput{} if err = c.ShouldBindJSON(&loginInput); err != nil { return nil, fmt.Errorf("missing: %w", err) } + if err = loginInput.Validate(strfmt.Default); err != nil { return nil, err } + ret.machineID = *loginInput.MachineID password := *loginInput.Password ret.scenariosInput = loginInput.Scenarios @@ -165,8 +178,10 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { } func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { - var err error - var auth *authInput + var ( + err error + auth *authInput + ) if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { auth, err = j.authTLS(c) @@ -190,6 +205,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { scenarios += "," + scenario } } + err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err) @@ -207,6 +223,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { if auth.clientMachine.IpAddress != c.ClientIP() && auth.clientMachine.IpAddress != "" { log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, c.ClientIP(), auth.clientMachine.IpAddress) + err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err) @@ -225,10 +242,10 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { log.Errorf("bad user agent from : %s", c.ClientIP()) return nil, jwt.ErrFailedAuthentication } + return &models.WatcherAuthRequest{ MachineID: &auth.machineID, }, nil - } func Authorizator(data interface{}, c *gin.Context) bool { @@ -259,25 +276,21 @@ func randomSecret() ([]byte, error) { } func NewJWT(dbClient *database.Client) (*JWT, error) { - // Get secret from environment variable "SECRET" - var ( - secret []byte - err error - ) + var err error // Please be aware that brute force HS256 is possible. // PLEASE choose a STRONG secret secretString := os.Getenv("CS_LAPI_SECRET") - secret = []byte(secretString) + secret := []byte(secretString) switch l := len(secret); { case l == 0: secret, err = randomSecret() if err != nil { - return &JWT{}, err + return nil, err } case l < 64: - return &JWT{}, errors.New("CS_LAPI_SECRET not strong enough") + return nil, errors.New("CS_LAPI_SECRET not strong enough") } jwtMiddleware := &JWT{ @@ -301,13 +314,14 @@ func NewJWT(dbClient *database.Client) (*JWT, error) { TimeFunc: time.Now, }) if err != nil { - return &JWT{}, err + return nil, err } errInit := ret.MiddlewareInit() if errInit != nil { - return &JWT{}, fmt.Errorf("authMiddleware.MiddlewareInit() Error:" + errInit.Error()) + return nil, fmt.Errorf("authMiddleware.MiddlewareInit() Error:" + errInit.Error()) } + jwtMiddleware.Middleware = ret return jwtMiddleware, nil diff --git a/pkg/apiserver/middlewares/v1/middlewares.go b/pkg/apiserver/middlewares/v1/middlewares.go index ef2d93b9212c..a5409ea5c9e6 100644 --- a/pkg/apiserver/middlewares/v1/middlewares.go +++ b/pkg/apiserver/middlewares/v1/middlewares.go @@ -18,5 +18,6 @@ func NewMiddlewares(dbClient *database.Client) (*Middlewares, error) { } ret.APIKey = NewAPIKey(dbClient) + return ret, nil } diff --git a/pkg/apiserver/middlewares/v1/tls_auth.go b/pkg/apiserver/middlewares/v1/tls_auth.go index 87ca896a8f47..904f6cd445ab 100644 --- a/pkg/apiserver/middlewares/v1/tls_auth.go +++ b/pkg/apiserver/middlewares/v1/tls_auth.go @@ -36,32 +36,40 @@ func (ta *TLSAuth) ocspQuery(server string, cert *x509.Certificate, issuer *x509 ta.logger.Errorf("TLSAuth: error creating OCSP request: %s", err) return nil, err } + httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req)) if err != nil { ta.logger.Error("TLSAuth: cannot create HTTP request for OCSP") return nil, err } + ocspURL, err := url.Parse(server) if err != nil { ta.logger.Error("TLSAuth: cannot parse OCSP URL") return nil, err } + httpRequest.Header.Add("Content-Type", "application/ocsp-request") httpRequest.Header.Add("Accept", "application/ocsp-response") httpRequest.Header.Add("host", ocspURL.Host) + httpClient := &http.Client{} + httpResponse, err := httpClient.Do(httpRequest) if err != nil { ta.logger.Error("TLSAuth: cannot send HTTP request to OCSP") return nil, err } defer httpResponse.Body.Close() + output, err := io.ReadAll(httpResponse.Body) if err != nil { ta.logger.Error("TLSAuth: cannot read HTTP response from OCSP") return nil, err } + ocspResponse, err := ocsp.ParseResponseForCert(output, cert, issuer) + return ocspResponse, err } @@ -72,10 +80,12 @@ func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool { ta.logger.Errorf("TLSAuth: client certificate is expired (NotAfter: %s)", cert.NotAfter.UTC()) return true } + if cert.NotBefore.UTC().After(now) { ta.logger.Errorf("TLSAuth: client certificate is not yet valid (NotBefore: %s)", cert.NotBefore.UTC()) return true } + return false } @@ -84,12 +94,14 @@ func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificat ta.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification") return false, nil } + for _, server := range cert.OCSPServer { ocspResponse, err := ta.ocspQuery(server, cert, issuer) if err != nil { ta.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err) continue } + switch ocspResponse.Status { case ocsp.Good: return false, nil @@ -100,7 +112,9 @@ func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificat continue } } + log.Infof("Could not get any valid OCSP response, assuming the cert is revoked") + return true, nil } @@ -109,24 +123,29 @@ func (ta *TLSAuth) isCRLRevoked(cert *x509.Certificate) (bool, error) { ta.logger.Warn("no crl_path, skipping CRL check") return false, nil } + crlContent, err := os.ReadFile(ta.CrlPath) if err != nil { ta.logger.Warnf("could not read CRL file, skipping check: %s", err) return false, nil } + crl, err := x509.ParseCRL(crlContent) if err != nil { ta.logger.Warnf("could not parse CRL file, skipping check: %s", err) return false, nil } + if crl.HasExpired(time.Now().UTC()) { ta.logger.Warn("CRL has expired, will still validate the cert against it.") } + for _, revoked := range crl.TBSCertList.RevokedCertificates { if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 { return true, fmt.Errorf("client certificate is revoked by CRL") } } + return false, nil } @@ -143,6 +162,7 @@ func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) ( } else { ta.logger.Tracef("TLSAuth: no cached value for cert %s", sn) } + revoked, err := ta.isOCSPRevoked(cert, issuer) if err != nil { ta.revokationCache[sn] = cacheEntry{ @@ -150,22 +170,27 @@ func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) ( err: err, timestamp: time.Now().UTC(), } + return true, err } + if revoked { ta.revokationCache[sn] = cacheEntry{ revoked: revoked, err: err, timestamp: time.Now().UTC(), } + return true, nil } + revoked, err = ta.isCRLRevoked(cert) ta.revokationCache[sn] = cacheEntry{ revoked: revoked, err: err, timestamp: time.Now().UTC(), } + return revoked, err } @@ -173,6 +198,7 @@ func (ta *TLSAuth) isInvalid(cert *x509.Certificate, issuer *x509.Certificate) ( if ta.isExpired(cert) { return true, nil } + revoked, err := ta.isRevoked(cert, issuer) if err != nil { //Fail securely, if we can't check the revocation status, let's consider the cert invalid @@ -189,24 +215,30 @@ func (ta *TLSAuth) SetAllowedOu(allowedOus []string) error { if ou == "" { return fmt.Errorf("empty ou isn't allowed") } + //drop & warn on duplicate ou ok := true + for _, validOu := range ta.AllowedOUs { if validOu == ou { ta.logger.Warningf("dropping duplicate ou %s", ou) + ok = false } } + if ok { ta.AllowedOUs = append(ta.AllowedOUs, ou) } } + return nil } func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) { //Checks cert validity, Returns true + CN if client cert matches requested OU var clientCert *x509.Certificate + if c.Request.TLS == nil || len(c.Request.TLS.PeerCertificates) == 0 { //do not error if it's not TLS or there are no peer certs return false, "", nil @@ -215,6 +247,7 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) { if len(c.Request.TLS.VerifiedChains) > 0 { validOU := false clientCert = c.Request.TLS.VerifiedChains[0][0] + for _, ou := range clientCert.Subject.OrganizationalUnit { for _, allowedOu := range ta.AllowedOUs { if allowedOu == ou { @@ -223,21 +256,27 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) { } } } + if !validOU { return false, "", fmt.Errorf("client certificate OU (%v) doesn't match expected OU (%v)", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs) } + revoked, err := ta.isInvalid(clientCert, c.Request.TLS.VerifiedChains[0][1]) if err != nil { ta.logger.Errorf("TLSAuth: error checking if client certificate is revoked: %s", err) return false, "", fmt.Errorf("could not check for client certification revokation status: %w", err) } + if revoked { return false, "", fmt.Errorf("client certificate is revoked") } + ta.logger.Debugf("client OU %v is allowed vs required OU %v", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs) + return true, clientCert.Subject.CommonName, nil } + return false, "", fmt.Errorf("no verified cert in request") } @@ -248,9 +287,11 @@ func NewTLSAuth(allowedOus []string, crlPath string, cacheExpiration time.Durati CrlPath: crlPath, logger: logger, } + err := ta.SetAllowedOu(allowedOus) if err != nil { return nil, err } + return ta, nil } diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index 2cf032d26b72..7a0f58b61347 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -87,17 +87,17 @@ type PapiPermCheckSuccess struct { } func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, logLevel log.Level) (*Papi, error) { - logger := log.New() if err := types.ConfigureLogger(logger); err != nil { - return &Papi{}, fmt.Errorf("creating papi logger: %s", err) + return &Papi{}, fmt.Errorf("creating papi logger: %w", err) } + logger.SetLevel(logLevel) - papiUrl := *apic.apiClient.PapiURL - papiUrl.Path = fmt.Sprintf("%s%s", types.PAPIVersion, types.PAPIPollUrl) + papiURL := *apic.apiClient.PapiURL + papiURL.Path = fmt.Sprintf("%s%s", types.PAPIVersion, types.PAPIPollUrl) longPollClient, err := longpollclient.NewLongPollClient(longpollclient.LongPollClientConfig{ - Url: papiUrl, + Url: papiURL, Logger: logger, HttpClient: apic.apiClient.GetClient(), }) @@ -132,55 +132,64 @@ func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.Cons func (p *Papi) handleEvent(event longpollclient.Event, sync bool) error { logger := p.Logger.WithField("request-id", event.RequestId) logger.Debugf("message received: %+v", event.Data) + message := &Message{} if err := json.Unmarshal([]byte(event.Data), message); err != nil { - return fmt.Errorf("polling papi message format is not compatible: %+v: %s", event.Data, err) + return fmt.Errorf("polling papi message format is not compatible: %+v: %w", event.Data, err) } + if message.Header == nil { return fmt.Errorf("no header in message, skipping") } + if message.Header.Source == nil { return fmt.Errorf("no source user in header message, skipping") } if operationFunc, ok := operationMap[message.Header.OperationType]; ok { logger.Debugf("Calling operation '%s'", message.Header.OperationType) - err := operationFunc(message, p, sync) - if err != nil { - return fmt.Errorf("'%s %s failed: %s", message.Header.OperationType, message.Header.OperationCmd, err) + + if err := operationFunc(message, p, sync); err != nil { + return fmt.Errorf("'%s %s failed: %w", message.Header.OperationType, message.Header.OperationCmd, err) } } else { return fmt.Errorf("operation '%s' unknown, continue", message.Header.OperationType) } + return nil } func (p *Papi) GetPermissions() (PapiPermCheckSuccess, error) { httpClient := p.apiClient.GetClient() - papiCheckUrl := fmt.Sprintf("%s%s%s", p.URL, types.PAPIVersion, types.PAPIPermissionsUrl) - req, err := http.NewRequest(http.MethodGet, papiCheckUrl, nil) + papiCheckURL := fmt.Sprintf("%s%s%s", p.URL, types.PAPIVersion, types.PAPIPermissionsUrl) + + req, err := http.NewRequest(http.MethodGet, papiCheckURL, nil) if err != nil { - return PapiPermCheckSuccess{}, fmt.Errorf("failed to create request : %s", err) + return PapiPermCheckSuccess{}, fmt.Errorf("failed to create request: %w", err) } + resp, err := httpClient.Do(req) if err != nil { - log.Fatalf("failed to get response : %s", err) + // XXX: fatal? + log.Fatalf("failed to get response: %s", err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { errResp := PapiPermCheckError{} - err = json.NewDecoder(resp.Body).Decode(&errResp) - if err != nil { - return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response : %s", err) + if err = json.NewDecoder(resp.Body).Decode(&errResp); err != nil { + return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response: %w", err) } + return PapiPermCheckSuccess{}, fmt.Errorf("unable to query PAPI : %s (%d)", errResp.Error, resp.StatusCode) } + respBody := PapiPermCheckSuccess{} - err = json.NewDecoder(resp.Body).Decode(&respBody) - if err != nil { - return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response : %s", err) + if err = json.NewDecoder(resp.Body).Decode(&respBody); err != nil { + return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response: %w", err) } + return respBody, nil } @@ -205,12 +214,15 @@ func (p *Papi) PullOnce(since time.Time, sync bool) error { reversedEvents := reverse(events) //PAPI sends events in the reverse order, which is not an issue when pulling them in real time, but here we need the correct order eventsCount := len(events) p.Logger.Infof("received %d events", eventsCount) + for i, event := range reversedEvents { if err := p.handleEvent(event, sync); err != nil { p.Logger.WithField("request-id", event.RequestId).Errorf("failed to handle event: %s", err) } + p.Logger.Debugf("handled event %d/%d", i, eventsCount) } + p.Logger.Debugf("finished handling events") //Don't update the timestamp in DB, as a "real" LAPI might be running //Worst case, crowdsec will receive a few duplicated events and will discard them @@ -223,16 +235,19 @@ func (p *Papi) Pull() error { p.Logger.Infof("Starting Polling API Pull") lastTimestamp := time.Time{} + lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey) if err != nil { p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err) } + //value doesn't exist, it's first time we're pulling if lastTimestampStr == nil { binTime, err := lastTimestamp.MarshalText() if err != nil { return fmt.Errorf("failed to marshal last timestamp: %w", err) } + if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { p.Logger.Errorf("error setting papi pull last key: %s", err) } else { @@ -245,10 +260,12 @@ func (p *Papi) Pull() error { } p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp) + for event := range p.Client.Start(lastTimestamp) { logger := p.Logger.WithField("request-id", event.RequestId) //update last timestamp in database newTime := time.Now().UTC() + binTime, err := newTime.MarshalText() if err != nil { return fmt.Errorf("failed to marshal last timestamp: %w", err) @@ -262,11 +279,11 @@ func (p *Papi) Pull() error { if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { return fmt.Errorf("failed to update last timestamp: %w", err) - } else { - logger.Debugf("set last timestamp to %s", newTime) } + logger.Debugf("set last timestamp to %s", newTime) } + return nil } @@ -274,6 +291,7 @@ func (p *Papi) SyncDecisions() error { defer trace.CatchPanic("lapi/syncDecisionsToCAPI") var cache models.DecisionsDeleteRequest + ticker := time.NewTicker(p.SyncInterval) p.Logger.Infof("Start decisions sync to CrowdSec Central API (interval: %s)", p.SyncInterval) @@ -281,10 +299,13 @@ func (p *Papi) SyncDecisions() error { select { case <-p.syncTomb.Dying(): // if one apic routine is dying, do we kill the others? p.Logger.Infof("sync decisions tomb is dying, sending cache (%d elements) before exiting", len(cache)) + if len(cache) == 0 { return nil } + go p.SendDeletedDecisions(&cache) + return nil case <-ticker.C: if len(cache) > 0 { @@ -293,15 +314,19 @@ func (p *Papi) SyncDecisions() error { cache = make([]models.DecisionsDeleteRequestItem, 0) p.mu.Unlock() p.Logger.Infof("sync decisions: %d deleted decisions to push", len(cacheCopy)) + go p.SendDeletedDecisions(&cacheCopy) } case deletedDecisions := <-p.Channels.DeleteDecisionChannel: if (p.consoleConfig.ShareManualDecisions != nil && *p.consoleConfig.ShareManualDecisions) || (p.consoleConfig.ConsoleManagement != nil && *p.consoleConfig.ConsoleManagement) { var tmpDecisions []models.DecisionsDeleteRequestItem + p.Logger.Debugf("%d decisions deletion to add in cache", len(deletedDecisions)) + for _, decision := range deletedDecisions { tmpDecisions = append(tmpDecisions, models.DecisionsDeleteRequestItem(decision.UUID)) } + p.mu.Lock() cache = append(cache, tmpDecisions...) p.mu.Unlock() @@ -311,33 +336,42 @@ func (p *Papi) SyncDecisions() error { } func (p *Papi) SendDeletedDecisions(cacheOrig *models.DecisionsDeleteRequest) { - - var cache []models.DecisionsDeleteRequestItem = *cacheOrig - var send models.DecisionsDeleteRequest + var ( + cache []models.DecisionsDeleteRequestItem = *cacheOrig + send models.DecisionsDeleteRequest + ) bulkSize := 50 pageStart := 0 pageEnd := bulkSize + for { if pageEnd >= len(cache) { send = cache[pageStart:] ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, _, err := p.apiClient.DecisionDelete.Add(ctx, &send) if err != nil { p.Logger.Errorf("sending deleted decisions to central API: %s", err) return } + break } + send = cache[pageStart:pageEnd] ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, _, err := p.apiClient.DecisionDelete.Add(ctx, &send) if err != nil { //we log it here as well, because the return value of func might be discarded p.Logger.Errorf("sending deleted decisions to central API: %s", err) } + pageStart += bulkSize pageEnd += bulkSize } diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index 6ab8f37349d4..c2cdb1539d15 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -40,17 +40,18 @@ type forcePull struct { func DecisionCmd(message *Message, p *Papi, sync bool) error { switch message.Header.OperationCmd { case "delete": - data, err := json.Marshal(message.Data) if err != nil { return err } + UUIDs := make([]string, 0) deleteDecisionMsg := deleteDecisions{ Decisions: make([]string, 0), } + if err := json.Unmarshal(data, &deleteDecisionMsg); err != nil { - return fmt.Errorf("message for '%s' contains bad data format: %s", message.Header.OperationType, err) + return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) } UUIDs = append(UUIDs, deleteDecisionMsg.Decisions...) @@ -59,10 +60,13 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error { filter := make(map[string][]string) filter["uuid"] = UUIDs _, deletedDecisions, err := p.DBClient.SoftDeleteDecisionsWithFilter(filter) + if err != nil { - return fmt.Errorf("unable to delete decisions %+v : %s", UUIDs, err) + return fmt.Errorf("unable to delete decisions %+v: %w", UUIDs, err) } + decisions := make([]*models.Decision, 0) + for _, deletedDecision := range deletedDecisions { log.Infof("Decision from '%s' for '%s' (%s) has been deleted", deletedDecision.Origin, deletedDecision.Value, deletedDecision.Type) dec := &models.Decision{ @@ -92,6 +96,7 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { if err != nil { return err } + alert := &models.Alert{} if err := json.Unmarshal(data, alert); err != nil { @@ -105,10 +110,12 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { log.Warnf("Alert %d has no StartAt, setting it to now", alert.ID) alert.StartAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) } + if alert.StopAt == nil || *alert.StopAt == "" { log.Warnf("Alert %d has no StopAt, setting it to now", alert.ID) alert.StopAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) } + alert.EventsCount = ptr.Of(int32(0)) alert.Capacity = ptr.Of(int32(0)) alert.Leakspeed = ptr.Of("") @@ -128,12 +135,14 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { alert.Source.Scope = ptr.Of(types.ConsoleOrigin) alert.Source.Value = &message.Header.Source.User } + alert.Scenario = &message.Header.Message for _, decision := range alert.Decisions { if *decision.Scenario == "" { decision.Scenario = &message.Header.Message } + log.Infof("Adding decision for '%s' with UUID: %s", *decision.Value, decision.UUID) } @@ -157,6 +166,7 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { log.Infof("Ignoring management command from PAPI in sync mode") return nil } + switch message.Header.OperationCmd { case "reauth": log.Infof("Received reauth command from PAPI, resetting token") @@ -166,16 +176,17 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { if err != nil { return err } + forcePullMsg := forcePull{} if err := json.Unmarshal(data, &forcePullMsg); err != nil { - return fmt.Errorf("message for '%s' contains bad data format: %s", message.Header.OperationType, err) + return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) } if forcePullMsg.Blocklist == nil { log.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists") - err = p.apic.PullTop(true) - if err != nil { - return fmt.Errorf("failed to force pull operation: %s", err) + + if err = p.apic.PullTop(true); err != nil { + return fmt.Errorf("failed to force pull operation: %w", err) } } else { log.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name) @@ -187,12 +198,12 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { Duration: &forcePullMsg.Blocklist.Duration, }, true) if err != nil { - return fmt.Errorf("failed to force pull operation: %s", err) + return fmt.Errorf("failed to force pull operation: %w", err) } } - default: return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType) } + return nil }