diff --git a/pkg/alertcontext/alertcontext.go b/pkg/alertcontext/alertcontext.go
index 7586e7cb4af..8b0648ca0eb 100644
--- a/pkg/alertcontext/alertcontext.go
+++ b/pkg/alertcontext/alertcontext.go
@@ -19,9 +19,7 @@ const (
 	maxContextValueLen = 4000
 )
 
-var (
-	alertContext = Context{}
-)
+var alertContext = Context{}
 
 type Context struct {
 	ContextToSend         map[string][]string
@@ -37,19 +35,21 @@ func ValidateContextExpr(key string, expressions []string) error {
 			return fmt.Errorf("compilation of '%s' failed: %v", expression, err)
 		}
 	}
+
 	return nil
 }
 
 func NewAlertContext(contextToSend map[string][]string, valueLength int) error {
-	var clog = log.New()
+	clog := log.New()
 	if err := types.ConfigureLogger(clog); err != nil {
-		return fmt.Errorf("couldn't create logger for alert context: %s", err)
+		return fmt.Errorf("couldn't create logger for alert context: %w", err)
 	}
 
 	if valueLength == 0 {
 		clog.Debugf("No console context value length provided, using default: %d", maxContextValueLen)
 		valueLength = maxContextValueLen
 	}
+
 	if valueLength > maxContextValueLen {
 		clog.Debugf("Provided console context value length (%d) is higher than the maximum, using default: %d", valueLength, maxContextValueLen)
 		valueLength = maxContextValueLen
@@ -76,6 +76,7 @@ func NewAlertContext(contextToSend map[string][]string, valueLength int) error {
 			if err != nil {
 				return fmt.Errorf("compilation of '%s' context value failed: %v", value, err)
 			}
+
 			alertContext.ContextToSendCompiled[key] = append(alertContext.ContextToSendCompiled[key], valueCompiled)
 			alertContext.ContextToSend[key] = append(alertContext.ContextToSend[key], value)
 		}
@@ -85,16 +86,13 @@ func NewAlertContext(contextToSend map[string][]string, valueLength int) error {
 }
 
 func truncate(values []string, contextValueLen int) (string, error) {
-	var ret string
 	valueByte, err := json.Marshal(values)
 	if err != nil {
-		return "", fmt.Errorf("unable to dump metas: %s", err)
+		return "", fmt.Errorf("unable to dump metas: %w", err)
 	}
-	ret = string(valueByte)
-	for {
-		if len(ret) <= contextValueLen {
-			break
-		}
+
+	ret := string(valueByte)
+	for len(ret) > contextValueLen {
 		// if there is only 1 value left and that the size is too big, truncate it
 		if len(values) == 1 {
 			valueToTruncate := values[0]
@@ -106,12 +104,15 @@ func truncate(values []string, contextValueLen int) (string, error) {
 			// if there is multiple value inside, just remove the last one
 			values = values[:len(values)-1]
 		}
+
 		valueByte, err = json.Marshal(values)
 		if err != nil {
-			return "", fmt.Errorf("unable to dump metas: %s", err)
+			return "", fmt.Errorf("unable to dump metas: %w", err)
 		}
+
 		ret = string(valueByte)
 	}
+
 	return ret, nil
 }
 
@@ -120,18 +121,22 @@ func EventToContext(events []types.Event) (models.Meta, []error) {
 	metas := make([]*models.MetaItems0, 0)
 	tmpContext := make(map[string][]string)
+
 	for _, evt := range events {
 		for key, values := range alertContext.ContextToSendCompiled {
 			if _, ok := tmpContext[key]; !ok {
 				tmpContext[key] = make([]string, 0)
 			}
+
 			for _, value := range values {
 				var val string
+
 				output, err := expr.Run(value, map[string]interface{}{"evt": evt})
 				if err != nil {
 					errors = append(errors, fmt.Errorf("failed to get value for %s : %v", key, err))
 					continue
 				}
+
 				switch out := output.(type) {
 				case string:
 					val = out
@@ -141,20 +146,24 @@ func EventToContext(events []types.Event) (models.Meta, []error) {
 					errors = append(errors, fmt.Errorf("unexpected return type for %s : %T", key, output))
 					continue
 				}
+
 				if val != "" && !slices.Contains(tmpContext[key], val) {
 					tmpContext[key] = append(tmpContext[key], val)
 				}
 			}
 		}
 	}
+
 	for key, values := range tmpContext {
 		if len(values) == 0 {
 			continue
 		}
+
 		valueStr, err := truncate(values, alertContext.ContextValueLen)
 		if err != nil {
 			log.Warningf(err.Error())
 		}
+
 		meta := models.MetaItems0{
 			Key:   key,
 			Value: valueStr,
@@ -163,5 +172,6 @@ func EventToContext(events []types.Event) (models.Meta, []error) {
 	}
 
 	ret := models.Meta(metas)
+
 	return ret, errors
 }
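The rewritten `truncate` collapses the old `for { if ... break }` into a single guarded loop: keep re-marshalling and shrinking `values` until the JSON encoding fits `contextValueLen`. A self-contained sketch of the same strategy follows; the policy for cutting a single oversized value lives in context lines not shown in this hunk, so the halving below is an assumption of the sketch, not the patch's code.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// truncate keeps the JSON encoding of values under contextValueLen,
// mirroring the loop shape of the patched function above.
func truncate(values []string, contextValueLen int) (string, error) {
	valueByte, err := json.Marshal(values)
	if err != nil {
		return "", fmt.Errorf("unable to dump metas: %w", err)
	}

	ret := string(valueByte)
	for len(ret) > contextValueLen {
		if len(values) == 1 {
			// one value left and still too big: cut the string itself
			// (assumed policy; the patch keeps the original logic here)
			keep := contextValueLen / 2
			if keep > len(values[0]) {
				keep = len(values[0])
			}
			values[0] = values[0][:keep]
		} else {
			// several values: drop the last one and re-measure
			values = values[:len(values)-1]
		}

		if valueByte, err = json.Marshal(values); err != nil {
			return "", fmt.Errorf("unable to dump metas: %w", err)
		}

		ret = string(valueByte)
	}

	return ret, nil
}

func main() {
	out, _ := truncate([]string{"10.0.0.1", "10.0.0.2", "10.0.0.3"}, 25)
	fmt.Println(out) // ["10.0.0.1","10.0.0.2"]
}
```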
diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go
index 2136edc8b8e..3f646071b0e 100644
--- a/pkg/apiserver/apic.go
+++ b/pkg/apiserver/apic.go
@@ -81,12 +81,12 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration {
 func (a *apic) FetchScenariosListFromDB() ([]string, error) {
 	scenarios := make([]string, 0)
-	machines, err := a.dbClient.ListMachines()
 
+	machines, err := a.dbClient.ListMachines()
 	if err != nil {
 		return nil, fmt.Errorf("while listing machines: %w", err)
 	}
-	//merge all scenarios together
+	// merge all scenarios together
 	for _, v := range machines {
 		machineScenarios := strings.Split(v.Scenarios, ",")
 		log.Debugf("%d scenarios for machine %d", len(machineScenarios), v.ID)
@@ -113,7 +113,7 @@ func decisionsToApiDecisions(decisions []*models.Decision) models.AddSignalsRequ
 			Origin:   ptr.Of(*decision.Origin),
 			Scenario: ptr.Of(*decision.Scenario),
 			Scope:    ptr.Of(*decision.Scope),
-			//Simulated: *decision.Simulated,
+			// Simulated: *decision.Simulated,
 			Type:  ptr.Of(*decision.Type),
 			Until: decision.Until,
 			Value: ptr.Of(*decision.Value),
@@ -196,8 +196,8 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con
 	}
 
 	password := strfmt.Password(config.Credentials.Password)
-	apiURL, err := url.Parse(config.Credentials.URL)
 
+	apiURL, err := url.Parse(config.Credentials.URL)
 	if err != nil {
 		return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.URL, err)
 	}
@@ -376,7 +376,6 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
 				defer cancel()
 
 				_, _, err := a.apiClient.Signal.Add(ctx, &send)
-
 				if err != nil {
 					log.Errorf("sending signal to central API: %s", err)
 					return
@@ -391,9 +390,8 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
 		defer cancel()
 
 		_, _, err := a.apiClient.Signal.Add(ctx, &send)
-
 		if err != nil {
-			//we log it here as well, because the return value of func might be discarded
+			// we log it here as well, because the return value of func might be discarded
 			log.Errorf("sending signal to central API: %s", err)
 		}
@@ -407,8 +405,8 @@ func (a *apic) CAPIPullIsOld() (bool, error) {
 	alerts := a.dbClient.Ent.Alert.Query()
 	alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID)))
 	alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert
-	count, err := alerts.Count(a.dbClient.CTX)
 
+	count, err := alerts.Count(a.dbClient.CTX)
 	if err != nil {
 		return false, fmt.Errorf("while looking for CAPI alert: %w", err)
 	}
@@ -506,6 +504,7 @@ func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert {
 			if sub.Scenario == nil {
 				log.Warningf("nil scenario in %+v", sub)
 			}
+
 			if *sub.Scenario == *decision.Scenario {
 				found = true
 				break
@@ -567,7 +566,7 @@ func createAlertForDecision(decision *models.Decision) *models.Alert {
 
 // This function takes in list of parent alerts and decisions and then pairs them up.
 func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decision, addCounters map[string]map[string]int) []*models.Alert {
 	for _, decision := range decisions {
-		//count and create separate alerts for each list
+		// count and create separate alerts for each list
 		updateCounterForDecision(addCounters, decision.Origin, decision.Scenario, 1)
 
 		/*CAPI might send lower case scopes, unify it.*/
@@ -579,7 +578,7 @@ func fillAlertsWithDecisio
 		}
 
 		found := false
-		//add the individual decisions to the right list
+		// add the individual decisions to the right list
 		for idx, alert := range alerts {
 			if *decision.Origin == types.CAPIOrigin {
 				if *alert.Source.Scope == types.CAPIOrigin {
@@ -592,6 +591,7 @@ func fillAlertsWithDecisio
 				if *alert.Source.Scope == types.ListOrigin && *alert.Scenario == *decision.Scenario {
 					alerts[idx].Decisions = append(alerts[idx].Decisions, decision)
 					found = true
+
 					break
 				}
 			} else {
@@ -613,8 +613,8 @@ func fillAlertsWithDecisio
 func (a *apic) PullTop(forcePull bool) error {
 	var err error
 
-	//A mutex with TryLock would be a bit simpler
-	//But go does not guarantee that TryLock will be able to acquire the lock even if it is available
+	// A mutex with TryLock would be a bit simpler
+	// But go does not guarantee that TryLock will be able to acquire the lock even if it is available
 	select {
 	case a.isPulling <- true:
 		defer func() {
@@ -633,6 +633,7 @@ func (a *apic) PullTop(forcePull bool) error {
 	}
 
 	log.Debug("Acquiring lock for pullCAPI")
+
 	err = a.dbClient.AcquirePullCAPILock()
 	if a.dbClient.IsLocked(err) {
 		log.Info("PullCAPI is already running, skipping")
@@ -642,6 +643,7 @@ func (a *apic) PullTop(forcePull bool) error {
 	/*defer lock release*/
 	defer func() {
 		log.Debug("Releasing lock for pullCAPI")
+
 		if err := a.dbClient.ReleasePullCAPILock(); err != nil {
 			log.Errorf("while releasing lock: %v", err)
 		}
@@ -681,7 +683,7 @@ func (a *apic) PullTop(forcePull bool) error {
 	// create one alert for community blocklist using the first decision
 	decisions := a.apiClient.Decisions.GetDecisionsFromGroups(data.New)
-	//apply APIC specific whitelists
+	// apply APIC specific whitelists
 	decisions = a.ApplyApicWhitelists(decisions)
 
 	alert := createAlertForDecision(decisions[0])
@@ -740,7 +742,7 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decis
 	if a.whitelists == nil || len(a.whitelists.Cidrs) == 0 && len(a.whitelists.Ips) == 0 {
 		return decisions
 	}
-	//deal with CAPI whitelists for fire. We want to avoid having a second list, so we shrink in place
+	// deal with CAPI whitelists for fire. We want to avoid having a second list, so we shrink in place
 	outIdx := 0
 
 	for _, decision := range decisions {
@@ -753,7 +755,7 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decis
 		decisions[outIdx] = decision
 		outIdx++
 	}
-	//shrink the list, those are deleted items
+	// shrink the list, those are deleted items
 	return decisions[:outIdx]
 }
@@ -782,8 +784,8 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo
 	alertQuery := a.dbClient.Ent.Alert.Query()
 	alertQuery.Where(alert.SourceScopeEQ(fmt.Sprintf("%s:%s", types.ListOrigin, *blocklist.Name)))
 	alertQuery.Order(ent.Desc(alert.FieldCreatedAt))
-	alertInstance, err := alertQuery.First(context.Background())
 
+	alertInstance, err := alertQuery.First(context.Background())
 	if err != nil {
 		if ent.IsNotFound(err) {
 			log.Debugf("no alert found for %s, force refresh", *blocklist.Name)
@@ -795,8 +797,8 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo
 	decisionQuery := a.dbClient.Ent.Decision.Query()
 	decisionQuery.Where(decision.HasOwnerWith(alert.IDEQ(alertInstance.ID)))
-	firstDecision, err := decisionQuery.First(context.Background())
 
+	firstDecision, err := decisionQuery.First(context.Background())
 	if err != nil {
 		if ent.IsNotFound(err) {
 			log.Debugf("no decision found for %s, force refresh", *blocklist.Name)
@@ -872,7 +874,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
 		log.Infof("blocklist %s has no decisions", *blocklist.Name)
 		return nil
 	}
-	//apply APIC specific whitelists
+	// apply APIC specific whitelists
 	decisions = a.ApplyApicWhitelists(decisions)
 	alert := createAlertForDecision(decisions[0])
 	alertsFromCapi := []*models.Alert{alert}
@@ -911,12 +913,17 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink
 }
 
 func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) {
-	if *alert.Source.Scope == types.CAPIOrigin {
+	switch *alert.Source.Scope {
+	case types.CAPIOrigin:
 		*alert.Source.Scope = types.CommunityBlocklistPullSourceScope
-		alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.CAPIOrigin]["all"], deleteCounters[types.CAPIOrigin]["all"]))
-	} else if *alert.Source.Scope == types.ListOrigin {
+		alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs",
+			addCounters[types.CAPIOrigin]["all"],
+			deleteCounters[types.CAPIOrigin]["all"]))
+	case types.ListOrigin:
 		*alert.Source.Scope = fmt.Sprintf("%s:%s", types.ListOrigin, *alert.Scenario)
-		alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.ListOrigin][*alert.Scenario], deleteCounters[types.ListOrigin][*alert.Scenario]))
+		alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs",
+			addCounters[types.ListOrigin][*alert.Scenario],
+			deleteCounters[types.ListOrigin][*alert.Scenario]))
 	}
 }
@@ -988,11 +995,12 @@ func makeAddAndDeleteCounters() (map[string]map[string]int, map[string]map[strin
 }
 
 func updateCounterForDecision(counter map[string]map[string]int, origin *string, scenario *string, totalDecisions int) {
-	if *origin == types.CAPIOrigin {
+	switch *origin {
+	case types.CAPIOrigin:
 		counter[*origin]["all"] += totalDecisions
-	} else if *origin == types.ListOrigin {
+	case types.ListOrigin:
 		counter[*origin][*scenario] += totalDecisions
-	} else {
+	default:
 		log.Warningf("Unknown origin %s", *origin)
 	}
 }
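`ApplyApicWhitelists` above filters `decisions` without allocating a second slice: survivors are compacted toward the front of the same backing array and the slice is re-cut at `outIdx`. A minimal sketch of that shrink-in-place pattern, with a hypothetical `isWhitelisted` predicate standing in for the real CIDR/IP checks:

```go
package main

import "fmt"

// filterInPlace keeps only elements that fail the predicate, overwriting
// the front of the slice and cutting it at outIdx, as in ApplyApicWhitelists.
func filterInPlace(decisions []string, isWhitelisted func(string) bool) []string {
	outIdx := 0

	for _, d := range decisions {
		if isWhitelisted(d) {
			continue // dropped: overwritten or sliced off below
		}

		decisions[outIdx] = d
		outIdx++
	}

	// shrink the list; everything past outIdx is discarded
	return decisions[:outIdx]
}

func main() {
	ds := []string{"1.2.3.4", "10.0.0.1", "5.6.7.8"}
	kept := filterInPlace(ds, func(ip string) bool { return ip == "10.0.0.1" })
	fmt.Println(kept) // [1.2.3.4 5.6.7.8]
}
```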
diff --git a/pkg/database/utils.go b/pkg/database/utils.go
index 2414e702786..f1c06565635 100644
--- a/pkg/database/utils.go
+++ b/pkg/database/utils.go
@@ -13,12 +13,14 @@ func IP2Int(ip net.IP) uint32 {
 	if len(ip) == 16 {
 		return binary.BigEndian.Uint32(ip[12:16])
 	}
+
 	return binary.BigEndian.Uint32(ip)
 }
 
 func Int2ip(nn uint32) net.IP {
 	ip := make(net.IP, 4)
 	binary.BigEndian.PutUint32(ip, nn)
+
 	return ip
 }
@@ -26,13 +28,14 @@ func IsIpv4(host string) bool {
 	return net.ParseIP(host) != nil
 }
 
-//Stolen from : https://github.com/llimllib/ipaddress/
+// Stolen from : https://github.com/llimllib/ipaddress/
 // Return the final address of a net range. Convert to IPv4 if possible,
 // otherwise return an ipv6
 func LastAddress(n *net.IPNet) net.IP {
 	ip := n.IP.To4()
 	if ip == nil {
 		ip = n.IP
+
 		return net.IP{
 			ip[0] | ^n.Mask[0], ip[1] | ^n.Mask[1], ip[2] | ^n.Mask[2],
 			ip[3] | ^n.Mask[3], ip[4] | ^n.Mask[4], ip[5] | ^n.Mask[5],
@@ -49,40 +52,44 @@ func LastAddress(n *net.IPNet) net.IP {
 		ip[3]|^n.Mask[3])
 }
 
+// GetIpsFromIpRange takes a CIDR range and returns the start and end IP
 func GetIpsFromIpRange(host string) (int64, int64, error) {
-	var ipStart int64
-	var ipEnd int64
-	var err error
-	var parsedRange *net.IPNet
-
-	if _, parsedRange, err = net.ParseCIDR(host); err != nil {
-		return ipStart, ipEnd, fmt.Errorf("'%s' is not a valid CIDR", host)
+	_, parsedRange, err := net.ParseCIDR(host)
+	if err != nil {
+		return 0, 0, fmt.Errorf("'%s' is not a valid CIDR", host)
 	}
+
 	if parsedRange == nil {
-		return ipStart, ipEnd, fmt.Errorf("unable to parse network : %s", err)
+		return 0, 0, fmt.Errorf("unable to parse network: %w", err)
 	}
-	ipStart = int64(IP2Int(parsedRange.IP))
-	ipEnd = int64(IP2Int(LastAddress(parsedRange)))
+
+	ipStart := int64(IP2Int(parsedRange.IP))
+	ipEnd := int64(IP2Int(LastAddress(parsedRange)))
 
 	return ipStart, ipEnd, nil
 }
 
 func ParseDuration(d string) (time.Duration, error) {
 	durationStr := d
+
 	if strings.HasSuffix(d, "d") {
 		days := strings.Split(d, "d")[0]
 		if len(days) == 0 {
 			return 0, fmt.Errorf("'%s' can't be parsed as duration", d)
 		}
+
 		daysInt, err := strconv.Atoi(days)
 		if err != nil {
 			return 0, err
 		}
+
 		durationStr = strconv.Itoa(daysInt*24) + "h"
 	}
+
 	duration, err := time.ParseDuration(durationStr)
 	if err != nil {
 		return 0, err
 	}
+
 	return duration, nil
 }
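`ParseDuration` exists because `time.ParseDuration` has no day unit: a trailing "d" is rewritten to hours before delegating to the standard library. A compact equivalent, runnable as-is (`strings.TrimSuffix` replaces the patch's `strings.Split`, a cosmetic substitution):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
	"time"
)

// parseDuration rewrites "Nd" to "N*24h", then lets the standard
// library handle everything else, as the patched function does.
func parseDuration(d string) (time.Duration, error) {
	durationStr := d

	if strings.HasSuffix(d, "d") {
		days := strings.TrimSuffix(d, "d")
		if days == "" {
			return 0, fmt.Errorf("'%s' can't be parsed as duration", d)
		}

		daysInt, err := strconv.Atoi(days)
		if err != nil {
			return 0, err
		}

		durationStr = strconv.Itoa(daysInt*24) + "h"
	}

	return time.ParseDuration(durationStr)
}

func main() {
	d, _ := parseDuration("7d")
	fmt.Println(d) // 168h0m0s
}
```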
diff --git a/pkg/leakybucket/manager_load.go b/pkg/leakybucket/manager_load.go
index 85eee89d933..bc259c18319 100644
--- a/pkg/leakybucket/manager_load.go
+++ b/pkg/leakybucket/manager_load.go
@@ -34,42 +34,42 @@ type BucketFactory struct {
 	Author              string                 `yaml:"author"`
 	Description         string                 `yaml:"description"`
 	References          []string               `yaml:"references"`
-	Type                string                 `yaml:"type"`                //Type can be : leaky, counter, trigger. It determines the main bucket characteristics
-	Name                string                 `yaml:"name"`                //Name of the bucket, used later in log and user-messages. Should be unique
-	Capacity            int                    `yaml:"capacity"`            //Capacity is applicable to leaky buckets and determines the "burst" capacity
-	LeakSpeed           string                 `yaml:"leakspeed"`           //Leakspeed is a float representing how many events per second leak out of the bucket
-	Duration            string                 `yaml:"duration"`            //Duration allows 'counter' buckets to have a fixed life-time
-	Filter              string                 `yaml:"filter"`              //Filter is an expr that determines if an event is elligible for said bucket. Filter is evaluated against the Event struct
-	GroupBy             string                 `yaml:"groupby,omitempty"`   //groupy is an expr that allows to determine the partitions of the bucket. A common example is the source_ip
-	Distinct            string                 `yaml:"distinct"`            //Distinct, when present, adds a `Pour()` processor that will only pour uniq items (based on distinct expr result)
-	Debug               bool                   `yaml:"debug"`               //Debug, when set to true, will enable debugging for _this_ scenario specifically
-	Labels              map[string]interface{} `yaml:"labels"`              //Labels is K:V list aiming at providing context the overflow
-	Blackhole           string                 `yaml:"blackhole,omitempty"` //Blackhole is a duration that, if present, will prevent same bucket partition to overflow more often than $duration
-	logger              *log.Entry             `yaml:"-"`                   //logger is bucket-specific logger (used by Debug as well)
-	Reprocess           bool                   `yaml:"reprocess"`           //Reprocess, if true, will for the bucket to be re-injected into processing chain
-	CacheSize           int                    `yaml:"cache_size"`          //CacheSize, if > 0, limits the size of in-memory cache of the bucket
-	Profiling           bool                   `yaml:"profiling"`           //Profiling, if true, will make the bucket record pours/overflows/etc.
-	OverflowFilter      string                 `yaml:"overflow_filter"`     //OverflowFilter if present, is a filter that must return true for the overflow to go through
-	ConditionalOverflow string                 `yaml:"condition"`           //condition if present, is an expression that must return true for the bucket to overflow
+	Type                string                 `yaml:"type"`                // Type can be : leaky, counter, trigger. It determines the main bucket characteristics
+	Name                string                 `yaml:"name"`                // Name of the bucket, used later in log and user-messages. Should be unique
+	Capacity            int                    `yaml:"capacity"`            // Capacity is applicable to leaky buckets and determines the "burst" capacity
+	LeakSpeed           string                 `yaml:"leakspeed"`           // Leakspeed is a float representing how many events per second leak out of the bucket
+	Duration            string                 `yaml:"duration"`            // Duration allows 'counter' buckets to have a fixed life-time
+	Filter              string                 `yaml:"filter"`              // Filter is an expr that determines if an event is eligible for said bucket. Filter is evaluated against the Event struct
+	GroupBy             string                 `yaml:"groupby,omitempty"`   // groupby is an expr that allows to determine the partitions of the bucket. A common example is the source_ip
+	Distinct            string                 `yaml:"distinct"`            // Distinct, when present, adds a `Pour()` processor that will only pour uniq items (based on distinct expr result)
+	Debug               bool                   `yaml:"debug"`               // Debug, when set to true, will enable debugging for _this_ scenario specifically
+	Labels              map[string]interface{} `yaml:"labels"`              // Labels is a K:V list aiming at providing context for the overflow
+	Blackhole           string                 `yaml:"blackhole,omitempty"` // Blackhole is a duration that, if present, will prevent same bucket partition to overflow more often than $duration
+	logger              *log.Entry             `yaml:"-"`                   // logger is bucket-specific logger (used by Debug as well)
+	Reprocess           bool                   `yaml:"reprocess"`           // Reprocess, if true, will force the bucket to be re-injected into processing chain
+	CacheSize           int                    `yaml:"cache_size"`          // CacheSize, if > 0, limits the size of in-memory cache of the bucket
+	Profiling           bool                   `yaml:"profiling"`           // Profiling, if true, will make the bucket record pours/overflows/etc.
+	OverflowFilter      string                 `yaml:"overflow_filter"`     // OverflowFilter if present, is a filter that must return true for the overflow to go through
+	ConditionalOverflow string                 `yaml:"condition"`           // condition if present, is an expression that must return true for the bucket to overflow
 	BayesianPrior       float32                `yaml:"bayesian_prior"`
 	BayesianThreshold   float32                `yaml:"bayesian_threshold"`
-	BayesianConditions  []RawBayesianCondition `yaml:"bayesian_conditions"` //conditions for the bayesian bucket
-	ScopeType           types.ScopeType        `yaml:"scope,omitempty"`     //to enforce a different remediation than blocking an IP. Will default this to IP
+	BayesianConditions  []RawBayesianCondition `yaml:"bayesian_conditions"` // conditions for the bayesian bucket
+	ScopeType           types.ScopeType        `yaml:"scope,omitempty"`     // to enforce a different remediation than blocking an IP. Will default this to IP
 	BucketName          string                 `yaml:"-"`
 	Filename            string                 `yaml:"-"`
 	RunTimeFilter       *vm.Program            `json:"-"`
 	RunTimeGroupBy      *vm.Program            `json:"-"`
 	Data                []*types.DataSource    `yaml:"data,omitempty"`
 	DataDir             string                 `yaml:"-"`
-	CancelOnFilter      string                 `yaml:"cancel_on,omitempty"` //a filter that, if matched, kills the bucket
-	leakspeed           time.Duration          //internal representation of `Leakspeed`
-	duration            time.Duration          //internal representation of `Duration`
-	ret                 chan types.Event       //the bucket-specific output chan for overflows
-	processors          []Processor            //processors is the list of hooks for pour/overflow/create (cf. uniq, blackhole etc.)
-	output              bool                   //??
+	CancelOnFilter      string                 `yaml:"cancel_on,omitempty"` // a filter that, if matched, kills the bucket
+	leakspeed           time.Duration          // internal representation of `Leakspeed`
+	duration            time.Duration          // internal representation of `Duration`
+	ret                 chan types.Event       // the bucket-specific output chan for overflows
+	processors          []Processor            // processors is the list of hooks for pour/overflow/create (cf. uniq, blackhole etc.)
+	output              bool                   // ??
 	ScenarioVersion     string                 `yaml:"version,omitempty"`
 	hash                string                 `yaml:"-"`
-	Simulated           bool                   `yaml:"simulated"` //Set to true if the scenario instantiating the bucket was in the exclusion list
+	Simulated           bool                   `yaml:"simulated"` // Set to true if the scenario instantiating the bucket was in the exclusion list
 	tomb                *tomb.Tomb             `yaml:"-"`
 	wgPour              *sync.WaitGroup        `yaml:"-"`
 	wgDumpState         *sync.WaitGroup        `yaml:"-"`
@@ -81,66 +81,80 @@ var seed namegenerator.Generator = namegenerator.NewNameGenerator(time.Now().UTC
 
 func ValidateFactory(bucketFactory *BucketFactory) error {
 	if bucketFactory.Name == "" {
-		return fmt.Errorf("bucket must have name")
+		return errors.New("bucket must have name")
 	}
+
 	if bucketFactory.Description == "" {
-		return fmt.Errorf("description is mandatory")
+		return errors.New("description is mandatory")
 	}
+
 	if bucketFactory.Type == "leaky" {
-		if bucketFactory.Capacity <= 0 { //capacity must be a positive int
+		if bucketFactory.Capacity <= 0 { // capacity must be a positive int
 			return fmt.Errorf("bad capacity for leaky '%d'", bucketFactory.Capacity)
 		}
+
 		if bucketFactory.LeakSpeed == "" {
-			return fmt.Errorf("leakspeed can't be empty for leaky")
+			return errors.New("leakspeed can't be empty for leaky")
 		}
+
 		if bucketFactory.leakspeed == 0 {
 			return fmt.Errorf("bad leakspeed for leaky '%s'", bucketFactory.LeakSpeed)
 		}
 	} else if bucketFactory.Type == "counter" {
 		if bucketFactory.Duration == "" {
-			return fmt.Errorf("duration can't be empty for counter")
+			return errors.New("duration can't be empty for counter")
 		}
+
 		if bucketFactory.duration == 0 {
 			return fmt.Errorf("bad duration for counter bucket '%d'", bucketFactory.duration)
 		}
+
 		if bucketFactory.Capacity != -1 {
-			return fmt.Errorf("counter bucket must have -1 capacity")
+			return errors.New("counter bucket must have -1 capacity")
 		}
 	} else if bucketFactory.Type == "trigger" {
 		if bucketFactory.Capacity != 0 {
-			return fmt.Errorf("trigger bucket must have 0 capacity")
+			return errors.New("trigger bucket must have 0 capacity")
 		}
 	} else if bucketFactory.Type == "conditional" {
 		if bucketFactory.ConditionalOverflow == "" {
-			return fmt.Errorf("conditional bucket must have a condition")
+			return errors.New("conditional bucket must have a condition")
 		}
+
 		if bucketFactory.Capacity != -1 {
 			bucketFactory.logger.Warnf("Using a value different than -1 as capacity for conditional bucket, this may lead to unexpected overflows")
 		}
+
 		if bucketFactory.LeakSpeed == "" {
-			return fmt.Errorf("leakspeed can't be empty for conditional bucket")
+			return errors.New("leakspeed can't be empty for conditional bucket")
 		}
+
 		if bucketFactory.leakspeed == 0 {
 			return fmt.Errorf("bad leakspeed for conditional bucket '%s'", bucketFactory.LeakSpeed)
 		}
 	} else if bucketFactory.Type == "bayesian" {
 		if bucketFactory.BayesianConditions == nil {
-			return fmt.Errorf("bayesian bucket must have bayesian conditions")
+			return errors.New("bayesian bucket must have bayesian conditions")
 		}
+
 		if bucketFactory.BayesianPrior == 0 {
-			return fmt.Errorf("bayesian bucket must have a valid, non-zero prior")
+			return errors.New("bayesian bucket must have a valid, non-zero prior")
 		}
+
 		if bucketFactory.BayesianThreshold == 0 {
-			return fmt.Errorf("bayesian bucket must have a valid, non-zero threshold")
+			return errors.New("bayesian bucket must have a valid, non-zero threshold")
 		}
+
 		if bucketFactory.BayesianPrior > 1 {
-			return fmt.Errorf("bayesian bucket must have a valid, non-zero prior")
+			return errors.New("bayesian bucket must have a valid, non-zero prior")
 		}
+
 		if bucketFactory.BayesianThreshold > 1 {
-			return fmt.Errorf("bayesian bucket must have a valid, non-zero threshold")
+			return errors.New("bayesian bucket must have a valid, non-zero threshold")
 		}
+
 		if bucketFactory.Capacity != -1 {
-			return fmt.Errorf("bayesian bucket must have capacity -1")
+			return errors.New("bayesian bucket must have capacity -1")
 		}
 	} else {
 		return fmt.Errorf("unknown bucket type '%s'", bucketFactory.Type)
@@ -155,26 +169,31 @@ func ValidateFactory(bucketFactory *BucketFactory) error {
 			runTimeFilter *vm.Program
 			err           error
 		)
+
 		if bucketFactory.ScopeType.Filter != "" {
 			if runTimeFilter, err = expr.Compile(bucketFactory.ScopeType.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...); err != nil {
-				return fmt.Errorf("Error compiling the scope filter: %s", err)
+				return fmt.Errorf("error compiling the scope filter: %w", err)
 			}
+
 			bucketFactory.ScopeType.RunTimeFilter = runTimeFilter
 		}
 	default:
-		//Compile the scope filter
+		// Compile the scope filter
 		var (
 			runTimeFilter *vm.Program
 			err           error
 		)
+
 		if bucketFactory.ScopeType.Filter != "" {
 			if runTimeFilter, err = expr.Compile(bucketFactory.ScopeType.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...); err != nil {
-				return fmt.Errorf("Error compiling the scope filter: %s", err)
+				return fmt.Errorf("error compiling the scope filter: %w", err)
 			}
+
 			bucketFactory.ScopeType.RunTimeFilter = runTimeFilter
 		}
 	}
+
 	return nil
 }
@@ -185,48 +204,58 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str
 	)
 
 	response = make(chan types.Event, 1)
+
 	for _, f := range files {
 		log.Debugf("Loading '%s'", f)
+
 		if !strings.HasSuffix(f, ".yaml") && !strings.HasSuffix(f, ".yml") {
 			log.Debugf("Skipping %s : not a yaml file", f)
 			continue
 		}
-		//process the yaml
+		// process the yaml
 		bucketConfigurationFile, err := os.Open(f)
 		if err != nil {
 			log.Errorf("Can't access leaky configuration file %s", f)
 			return nil, nil, err
 		}
+
 		defer bucketConfigurationFile.Close()
 		dec := yaml.NewDecoder(bucketConfigurationFile)
 		dec.SetStrict(true)
+
 		for {
 			bucketFactory := BucketFactory{}
+
 			err = dec.Decode(&bucketFactory)
 			if err != nil {
 				if !errors.Is(err, io.EOF) {
 					log.Errorf("Bad yaml in %s : %v", f, err)
 					return nil, nil, fmt.Errorf("bad yaml in %s : %v", f, err)
 				}
+
 				log.Tracef("End of yaml file")
+
 				break
 			}
+
 			bucketFactory.DataDir = hub.GetDataDir()
-			//check empty
+			// check empty
 			if bucketFactory.Name == "" {
 				log.Errorf("Won't load nameless bucket")
-				return nil, nil, fmt.Errorf("nameless bucket")
+				return nil, nil, errors.New("nameless bucket")
 			}
-			//check compat
+			// check compat
 			if bucketFactory.FormatVersion == "" {
 				log.Tracef("no version in %s : %s, assuming '1.0'", bucketFactory.Name, f)
 				bucketFactory.FormatVersion = "1.0"
 			}
+
 			ok, err := cwversion.Satisfies(bucketFactory.FormatVersion, cwversion.Constraint_scenario)
 			if err != nil {
 				return nil, nil, fmt.Errorf("failed to check version : %s", err)
 			}
+
 			if !ok {
 				log.Errorf("can't load %s : %s doesn't satisfy scenario format %s, skip", bucketFactory.Name, bucketFactory.FormatVersion, cwversion.Constraint_scenario)
 				continue
@@ -235,6 +264,7 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str
 			bucketFactory.Filename = filepath.Clean(f)
 			bucketFactory.BucketName = seed.Generate()
 			bucketFactory.ret = response
+
 			hubItem, err := hub.GetItemByPath(cwhub.SCENARIOS, bucketFactory.Filename)
 			if err != nil {
 				log.Errorf("scenario %s (%s) couldn't be find in hub (ignore if in unit tests)", bucketFactory.Name, bucketFactory.Filename)
@@ -242,6 +272,7 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str
 			if cscfg.SimulationConfig != nil {
 				bucketFactory.Simulated = cscfg.SimulationConfig.IsSimulated(hubItem.Name)
 			}
+
 			if hubItem != nil {
 				bucketFactory.ScenarioVersion = hubItem.State.LocalVersion
 				bucketFactory.hash = hubItem.State.LocalHash
@@ -252,6 +283,7 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str
 			bucketFactory.wgDumpState = buckets.wgDumpState
 			bucketFactory.wgPour = buckets.wgPour
+
 			err = LoadBucket(&bucketFactory, tomb)
 			if err != nil {
 				log.Errorf("Failed to load bucket %s : %v", bucketFactory.Name, err)
@@ -265,21 +297,24 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str
 	}
 
 	if err := alertcontext.NewAlertContext(cscfg.ContextToSend, cscfg.ConsoleContextValueLength); err != nil {
-		return nil, nil, fmt.Errorf("unable to load alert context: %s", err)
+		return nil, nil, fmt.Errorf("unable to load alert context: %w", err)
 	}
 
 	log.Infof("Loaded %d scenarios", len(ret))
+
 	return ret, response, nil
 }
 
 /* Init recursively process yaml files from a directory and loads them as BucketFactory */
 func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error {
 	var err error
+
 	if bucketFactory.Debug {
-		var clog = log.New()
+		clog := log.New()
 		if err := types.ConfigureLogger(clog); err != nil {
 			log.Fatalf("While creating bucket-specific logger : %s", err)
 		}
+
 		clog.SetLevel(log.DebugLevel)
 		bucketFactory.logger = clog.WithFields(log.Fields{
 			"cfg": bucketFactory.BucketName,
@@ -300,6 +335,7 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error {
 	} else {
 		bucketFactory.leakspeed = time.Duration(0)
 	}
+
 	if bucketFactory.Duration != "" {
 		if bucketFactory.duration, err = time.ParseDuration(bucketFactory.Duration); err != nil {
 			return fmt.Errorf("invalid Duration '%s' in %s : %v", bucketFactory.Duration, bucketFactory.Filename, err)
@@ -308,8 +344,9 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error {
 
 	if bucketFactory.Filter == "" {
 		bucketFactory.logger.Warning("Bucket without filter, abort.")
-		return fmt.Errorf("bucket without filter directive")
+		return errors.New("bucket without filter directive")
 	}
+
 	bucketFactory.RunTimeFilter, err = expr.Compile(bucketFactory.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...)
 	if err != nil {
 		return fmt.Errorf("invalid filter '%s' in %s : %v", bucketFactory.Filter, bucketFactory.Filename, err)
@@ -323,7 +360,7 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error {
 	}
 
 	bucketFactory.logger.Infof("Adding %s bucket", bucketFactory.Type)
-	//return the Holder corresponding to the type of bucket
+	// return the Holder corresponding to the type of bucket
 	bucketFactory.processors = []Processor{}
 	switch bucketFactory.Type {
 	case "leaky":
@@ -352,21 +389,25 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error {
 
 	if bucketFactory.OverflowFilter != "" {
 		bucketFactory.logger.Tracef("Adding an overflow filter")
+
 		filovflw, err := NewOverflowFilter(bucketFactory)
 		if err != nil {
 			bucketFactory.logger.Errorf("Error creating overflow_filter : %s", err)
-			return fmt.Errorf("error creating overflow_filter : %s", err)
+			return fmt.Errorf("error creating overflow_filter: %w", err)
 		}
+
 		bucketFactory.processors = append(bucketFactory.processors, filovflw)
 	}
 
 	if bucketFactory.Blackhole != "" {
 		bucketFactory.logger.Tracef("Adding blackhole.")
+
 		blackhole, err := NewBlackhole(bucketFactory)
 		if err != nil {
 			bucketFactory.logger.Errorf("Error creating blackhole : %s", err)
-			return fmt.Errorf("error creating blackhole : %s", err)
+			return fmt.Errorf("error creating blackhole : %w", err)
 		}
+
 		bucketFactory.processors = append(bucketFactory.processors, blackhole)
 	}
@@ -380,19 +421,19 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error {
 		bucketFactory.processors = append(bucketFactory.processors, &BayesianBucket{})
 	}
 
-	if len(bucketFactory.Data) > 0 {
-		for _, data := range bucketFactory.Data {
-			if data.DestPath == "" {
-				bucketFactory.logger.Errorf("no dest_file provided for '%s'", bucketFactory.Name)
-				continue
-			}
-			err = exprhelpers.FileInit(bucketFactory.DataDir, data.DestPath, data.Type)
-			if err != nil {
-				bucketFactory.logger.Errorf("unable to init data for file '%s': %s", data.DestPath, err)
-			}
-			if data.Type == "regexp" { //cache only makes sense for regexp
-				exprhelpers.RegexpCacheInit(data.DestPath, *data)
-			}
+	for _, data := range bucketFactory.Data {
+		if data.DestPath == "" {
+			bucketFactory.logger.Errorf("no dest_file provided for '%s'", bucketFactory.Name)
+			continue
+		}
+
+		err = exprhelpers.FileInit(bucketFactory.DataDir, data.DestPath, data.Type)
+		if err != nil {
+			bucketFactory.logger.Errorf("unable to init data for file '%s': %s", data.DestPath, err)
+		}
+
+		if data.Type == "regexp" { // cache only makes sense for regexp
+			exprhelpers.RegexpCacheInit(data.DestPath, *data)
 		}
 	}
@@ -400,34 +441,40 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error {
 	if err := ValidateFactory(bucketFactory); err != nil {
 		return fmt.Errorf("invalid bucket from %s : %v", bucketFactory.Filename, err)
 	}
+
 	bucketFactory.tomb = tomb
 
 	return nil
-
 }
 
 func LoadBucketsState(file string, buckets *Buckets, bucketFactories []BucketFactory) error {
 	var state map[string]Leaky
+
 	body, err := os.ReadFile(file)
 	if err != nil {
-		return fmt.Errorf("can't state file %s : %s", file, err)
+		return fmt.Errorf("can't read state file %s: %w", file, err)
 	}
+
 	if err := json.Unmarshal(body, &state); err != nil {
-		return fmt.Errorf("can't unmarshal state file %s : %s", file, err)
+		return fmt.Errorf("can't unmarshal state file %s: %w", file, err)
 	}
+
 	for k, v := range state {
 		var tbucket *Leaky
+
 		log.Debugf("Reloading bucket %s", k)
+
 		val, ok := buckets.Bucket_map.Load(k)
 		if ok {
 			log.Fatalf("key %s already exists : %+v", k, val)
 		}
-		//find back our holder
+		// find back our holder
 		found := false
+
 		for _, h := range bucketFactories {
 			if h.Name == v.Name {
 				log.Debugf("found factory %s/%s -> %s", h.Author, h.Name, h.Description)
-				//check in which mode the bucket was
+				// check in which mode the bucket was
 				if v.Mode == types.TIMEMACHINE {
 					tbucket = NewTimeMachine(h)
 				} else if v.Mode == types.LIVE {
@@ -451,16 +498,19 @@ func LoadBucketsState(file string, buckets *Buckets, bucketFac
 					return LeakRoutine(tbucket)
 				})
 				<-tbucket.Signal
+
 				found = true
+
 				break
 			}
 		}
+
 		if !found {
 			log.Fatalf("Unable to find holder for bucket %s : %s", k, spew.Sdump(v))
 		}
 	}
 
 	log.Infof("Restored %d buckets from dump", len(state))
-	return nil
+
+	return nil
 }
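The decode loop in `LoadBuckets` treats each scenario file as a multi-document YAML stream: one `Decode` call per document, `SetStrict(true)` to reject unknown fields, and `io.EOF` as the normal end-of-stream signal rather than an error. A runnable miniature of that loop (the two scenario documents and the trimmed-down `factory` struct are invented for the example; real factories come from hub files):

```go
package main

import (
	"errors"
	"fmt"
	"io"
	"strings"

	"gopkg.in/yaml.v2" // SetStrict, as used in the patch, is the yaml.v2 API
)

type factory struct {
	Name string `yaml:"name"`
	Type string `yaml:"type"`
}

func main() {
	src := "name: ssh-bf\ntype: leaky\n---\nname: http-crawl\ntype: trigger\n"

	dec := yaml.NewDecoder(strings.NewReader(src))
	dec.SetStrict(true) // unknown fields become decode errors

	for {
		f := factory{}

		err := dec.Decode(&f)
		if err != nil {
			if !errors.Is(err, io.EOF) {
				fmt.Println("bad yaml:", err)
			}

			break // EOF is the normal way out, as in LoadBuckets
		}

		fmt.Printf("loaded %s (%s)\n", f.Name, f.Type)
	}
}
```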
diff --git a/pkg/types/ip.go b/pkg/types/ip.go
index 5e4d7734f2d..9d08afd8809 100644
--- a/pkg/types/ip.go
+++ b/pkg/types/ip.go
@@ -2,6 +2,7 @@ package types
 
 import (
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"math"
 	"net"
@@ -15,6 +16,7 @@ func LastAddress(n net.IPNet) net.IP {
 	if ip == nil {
 		// IPv6
 		ip = n.IP
+
 		return net.IP{
 			ip[0] | ^n.Mask[0], ip[1] | ^n.Mask[1], ip[2] | ^n.Mask[2],
 			ip[3] | ^n.Mask[3], ip[4] | ^n.Mask[4], ip[5] | ^n.Mask[5],
@@ -38,12 +40,13 @@ func Addr2Ints(anyIP string) (int, int64, int64, int64, int64, error) {
 		if err != nil {
 			return -1, 0, 0, 0, 0, fmt.Errorf("while parsing range %s: %w", anyIP, err)
 		}
+
 		return Range2Ints(*net)
 	}
 
 	ip := net.ParseIP(anyIP)
 	if ip == nil {
-		return -1, 0, 0, 0, 0, fmt.Errorf("invalid address")
+		return -1, 0, 0, 0, 0, errors.New("invalid address")
 	}
 
 	sz, start, end, err := IP2Ints(ip)
@@ -56,19 +59,22 @@ func Addr2Ints(anyIP string) (int, int64, int64, int64, int64, error) {
 
 /*size (16|4), nw_start, suffix_start, nw_end, suffix_end, error*/
 func Range2Ints(network net.IPNet) (int, int64, int64, int64, int64, error) {
-
 	szStart, nwStart, sfxStart, err := IP2Ints(network.IP)
 	if err != nil {
 		return -1, 0, 0, 0, 0, fmt.Errorf("converting first ip in range: %w", err)
 	}
+
 	lastAddr := LastAddress(network)
+
 	szEnd, nwEnd, sfxEnd, err := IP2Ints(lastAddr)
 	if err != nil {
 		return -1, 0, 0, 0, 0, fmt.Errorf("transforming last address of range: %w", err)
 	}
+
 	if szEnd != szStart {
 		return -1, 0, 0, 0, 0, fmt.Errorf("inconsistent size for range first(%d) and last(%d) ip", szStart, szEnd)
 	}
+
 	return szStart, nwStart, sfxStart, nwEnd, sfxEnd, nil
 }
@@ -85,6 +91,7 @@ func uint2int(u uint64) int64 {
 		ret = int64(u)
 		ret -= math.MaxInt64
 	}
+
 	return ret
 }
@@ -97,13 +104,15 @@ func IP2Ints(pip net.IP) (int, int64, int64, error) {
 	if pip4 != nil {
 		ip_nw32 := binary.BigEndian.Uint32(pip4)
-		return 4, uint2int(uint64(ip_nw32)), uint2int(ip_sfx), nil
-	} else if pip16 != nil {
+
+		return 4, uint2int(uint64(ip_nw32)), uint2int(ip_sfx), nil
+	}
+
+	if pip16 != nil {
 		ip_nw = binary.BigEndian.Uint64(pip16[0:8])
 		ip_sfx = binary.BigEndian.Uint64(pip16[8:16])
+
 		return 16, uint2int(ip_nw), uint2int(ip_sfx), nil
-	} else {
-		return -1, 0, 0, fmt.Errorf("unexpected len %d for %s", len(pip), pip)
 	}
+
+	return -1, 0, 0, fmt.Errorf("unexpected len %d for %s", len(pip), pip)
 }
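The only part of `uint2int` visible in the hunk above is `ret = int64(u); ret -= math.MaxInt64`: unsigned values are recentered so the whole `uint64` range fits signed storage while preserving order. A sketch of that mapping for the lower half of the range only (the full function has further branches that this diff does not touch):

```go
package main

import (
	"fmt"
	"math"
)

// recenter shifts a uint64 down by math.MaxInt64, preserving ordering in
// int64. Valid only while u <= math.MaxInt64; the other branches of the
// real uint2int are not shown in this diff.
func recenter(u uint64) int64 {
	ret := int64(u)
	ret -= math.MaxInt64

	return ret
}

func main() {
	fmt.Println(recenter(0))                     // -9223372036854775807
	fmt.Println(recenter(1) > recenter(0))       // true: order preserved
	fmt.Println(recenter(uint64(math.MaxInt64))) // 0
}
```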