From 7e83e9189ca8d9a38d32f04219237dae0820ddd7 Mon Sep 17 00:00:00 2001 From: marco Date: Mon, 15 Jan 2024 11:47:29 +0100 Subject: [PATCH] lint --- pkg/apiserver/apic.go | 22 +++++++---- pkg/apiserver/controllers/v1/alerts.go | 3 +- pkg/apiserver/controllers/v1/decisions.go | 39 +++++++++++-------- pkg/apiserver/middlewares/v1/api_key.go | 3 +- pkg/apiserver/middlewares/v1/jwt.go | 17 ++++---- pkg/apiserver/papi.go | 47 ++++++++++++++--------- pkg/apiserver/papi_cmd.go | 9 +++-- 7 files changed, 81 insertions(+), 59 deletions(-) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index f57ae685e45..e89b1d3d8ec 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -908,12 +908,19 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink } func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) { - if *alert.Source.Scope == types.CAPIOrigin { + switch *alert.Source.Scope { + case types.CAPIOrigin: *alert.Source.Scope = types.CommunityBlocklistPullSourceScope - alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.CAPIOrigin]["all"], deleteCounters[types.CAPIOrigin]["all"])) - } else if *alert.Source.Scope == types.ListOrigin { + alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", + addCounters[types.CAPIOrigin]["all"], + deleteCounters[types.CAPIOrigin]["all"]), + ) + case types.ListOrigin: *alert.Source.Scope = fmt.Sprintf("%s:%s", types.ListOrigin, *alert.Scenario) - alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.ListOrigin][*alert.Scenario], deleteCounters[types.ListOrigin][*alert.Scenario])) + alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", + addCounters[types.ListOrigin][*alert.Scenario], + deleteCounters[types.ListOrigin][*alert.Scenario]), + ) } } @@ -985,11 +992,12 @@ func makeAddAndDeleteCounters() (map[string]map[string]int, map[string]map[strin } func updateCounterForDecision(counter map[string]map[string]int, origin *string, scenario *string, totalDecisions int) { - if *origin == types.CAPIOrigin { + switch *origin { + case types.CAPIOrigin: counter[*origin]["all"] += totalDecisions - } else if *origin == types.ListOrigin { + case types.ListOrigin: counter[*origin][*scenario] += totalDecisions - } else { + default: log.Warningf("Unknown origin %s", *origin) } } diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index ad183e4ba80..5ac4af40b09 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -328,8 +328,7 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) { return } - decisionIDStr := gctx.Param("alert_id") - decisionID, err := strconv.Atoi(decisionIDStr) + decisionID, err := strconv.Atoi(gctx.Param("alert_id")) if err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"}) return diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index f3c6a7bba26..155aab77752 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -2,7 +2,6 @@ package v1 import ( "encoding/json" - "fmt" "net/http" "strconv" "time" @@ -138,22 +137,25 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun // respBuffer := bytes.NewBuffer([]byte{}) limit := 30000 //FIXME : make it configurable needComma := false - lastId := 0 + lastID := 0 - limitStr := fmt.Sprintf("%d", limit) + limitStr := strconv.Itoa(limit) filters["limit"] = []string{limitStr} + for { - if lastId > 0 { - lastIdStr := fmt.Sprintf("%d", lastId) - filters["id_gt"] = []string{lastIdStr} + if lastID > 0 { + lastIDStr := strconv.Itoa(lastID) + filters["id_gt"] = []string{lastIDStr} } data, err := dbFunc(filters) if err != nil { return err } + if len(data) > 0 { - lastId = data[len(data)-1].ID + lastID = data[len(data)-1].ID + results := FormatDecisions(data) for _, decision := range results { decisionJSON, _ := json.Marshal(decision) @@ -175,7 +177,9 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun //respBuffer.Reset() } } - log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId) + + log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastID) + if len(data) < limit { gctx.Writer.Flush() @@ -190,22 +194,23 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul //respBuffer := bytes.NewBuffer([]byte{}) limit := 30000 //FIXME : make it configurable needComma := false - lastId := 0 + lastID := 0 + + filters["limit"] = []string{strconv.Itoa(limit)} - limitStr := fmt.Sprintf("%d", limit) - filters["limit"] = []string{limitStr} for { - if lastId > 0 { - lastIdStr := fmt.Sprintf("%d", lastId) - filters["id_gt"] = []string{lastIdStr} + if lastID > 0 { + filters["id_gt"] = []string{strconv.Itoa(lastID)} } data, err := dbFunc(lastPull, filters) if err != nil { return err } + if len(data) > 0 { - lastId = data[len(data)-1].ID + lastID = data[len(data)-1].ID + results := FormatDecisions(data) for _, decision := range results { decisionJSON, _ := json.Marshal(decision) @@ -227,7 +232,9 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul //respBuffer.Reset() } } - log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId) + + log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastID) + if len(data) < limit { gctx.Writer.Flush() diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index 4e273371bfe..51f771efd05 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "crypto/sha512" "encoding/base64" + "encoding/hex" "fmt" "net/http" "strings" @@ -53,7 +54,7 @@ func HashSHA512(str string) string { hashedKey := sha512.New() hashedKey.Write([]byte(str)) - hashStr := fmt.Sprintf("%x", hashedKey.Sum(nil)) + hashStr := hex.EncodeToString(hashedKey.Sum(nil)) return hashStr } diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index 6fe053713bc..7181cd21663 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -276,25 +276,21 @@ func randomSecret() ([]byte, error) { } func NewJWT(dbClient *database.Client) (*JWT, error) { - // Get secret from environment variable "SECRET" - var ( - secret []byte - err error - ) + var err error // Please be aware that brute force HS256 is possible. // PLEASE choose a STRONG secret secretString := os.Getenv("CS_LAPI_SECRET") - secret = []byte(secretString) + secret := []byte(secretString) switch l := len(secret); { case l == 0: secret, err = randomSecret() if err != nil { - return &JWT{}, err + return nil, err } case l < 64: - return &JWT{}, errors.New("CS_LAPI_SECRET not strong enough") + return nil, errors.New("CS_LAPI_SECRET not strong enough") } jwtMiddleware := &JWT{ @@ -318,13 +314,14 @@ func NewJWT(dbClient *database.Client) (*JWT, error) { TimeFunc: time.Now, }) if err != nil { - return &JWT{}, err + return nil, err } errInit := ret.MiddlewareInit() if errInit != nil { - return &JWT{}, fmt.Errorf("authMiddleware.MiddlewareInit() Error:" + errInit.Error()) + return nil, fmt.Errorf("authMiddleware.MiddlewareInit() Error:" + errInit.Error()) } + jwtMiddleware.Middleware = ret return jwtMiddleware, nil diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index a3996850a2b..7a0f58b6134 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -87,17 +87,17 @@ type PapiPermCheckSuccess struct { } func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, logLevel log.Level) (*Papi, error) { - logger := log.New() if err := types.ConfigureLogger(logger); err != nil { - return &Papi{}, fmt.Errorf("creating papi logger: %s", err) + return &Papi{}, fmt.Errorf("creating papi logger: %w", err) } + logger.SetLevel(logLevel) - papiUrl := *apic.apiClient.PapiURL - papiUrl.Path = fmt.Sprintf("%s%s", types.PAPIVersion, types.PAPIPollUrl) + papiURL := *apic.apiClient.PapiURL + papiURL.Path = fmt.Sprintf("%s%s", types.PAPIVersion, types.PAPIPollUrl) longPollClient, err := longpollclient.NewLongPollClient(longpollclient.LongPollClientConfig{ - Url: papiUrl, + Url: papiURL, Logger: logger, HttpClient: apic.apiClient.GetClient(), }) @@ -132,55 +132,64 @@ func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.Cons func (p *Papi) handleEvent(event longpollclient.Event, sync bool) error { logger := p.Logger.WithField("request-id", event.RequestId) logger.Debugf("message received: %+v", event.Data) + message := &Message{} if err := json.Unmarshal([]byte(event.Data), message); err != nil { - return fmt.Errorf("polling papi message format is not compatible: %+v: %s", event.Data, err) + return fmt.Errorf("polling papi message format is not compatible: %+v: %w", event.Data, err) } + if message.Header == nil { return fmt.Errorf("no header in message, skipping") } + if message.Header.Source == nil { return fmt.Errorf("no source user in header message, skipping") } if operationFunc, ok := operationMap[message.Header.OperationType]; ok { logger.Debugf("Calling operation '%s'", message.Header.OperationType) - err := operationFunc(message, p, sync) - if err != nil { - return fmt.Errorf("'%s %s failed: %s", message.Header.OperationType, message.Header.OperationCmd, err) + + if err := operationFunc(message, p, sync); err != nil { + return fmt.Errorf("'%s %s failed: %w", message.Header.OperationType, message.Header.OperationCmd, err) } } else { return fmt.Errorf("operation '%s' unknown, continue", message.Header.OperationType) } + return nil } func (p *Papi) GetPermissions() (PapiPermCheckSuccess, error) { httpClient := p.apiClient.GetClient() - papiCheckUrl := fmt.Sprintf("%s%s%s", p.URL, types.PAPIVersion, types.PAPIPermissionsUrl) - req, err := http.NewRequest(http.MethodGet, papiCheckUrl, nil) + papiCheckURL := fmt.Sprintf("%s%s%s", p.URL, types.PAPIVersion, types.PAPIPermissionsUrl) + + req, err := http.NewRequest(http.MethodGet, papiCheckURL, nil) if err != nil { - return PapiPermCheckSuccess{}, fmt.Errorf("failed to create request : %s", err) + return PapiPermCheckSuccess{}, fmt.Errorf("failed to create request: %w", err) } + resp, err := httpClient.Do(req) if err != nil { - log.Fatalf("failed to get response : %s", err) + // XXX: fatal? + log.Fatalf("failed to get response: %s", err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { errResp := PapiPermCheckError{} - err = json.NewDecoder(resp.Body).Decode(&errResp) - if err != nil { - return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response : %s", err) + if err = json.NewDecoder(resp.Body).Decode(&errResp); err != nil { + return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response: %w", err) } + return PapiPermCheckSuccess{}, fmt.Errorf("unable to query PAPI : %s (%d)", errResp.Error, resp.StatusCode) } + respBody := PapiPermCheckSuccess{} - err = json.NewDecoder(resp.Body).Decode(&respBody) - if err != nil { - return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response : %s", err) + if err = json.NewDecoder(resp.Body).Decode(&respBody); err != nil { + return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response: %w", err) } + return respBody, nil } diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index ba02034882c..c2cdb1539d1 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -176,16 +176,17 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { if err != nil { return err } + forcePullMsg := forcePull{} if err := json.Unmarshal(data, &forcePullMsg); err != nil { - return fmt.Errorf("message for '%s' contains bad data format: %s", message.Header.OperationType, err) + return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) } if forcePullMsg.Blocklist == nil { log.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists") - err = p.apic.PullTop(true) - if err != nil { - return fmt.Errorf("failed to force pull operation: %s", err) + + if err = p.apic.PullTop(true); err != nil { + return fmt.Errorf("failed to force pull operation: %w", err) } } else { log.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name)