diff --git a/.github/workflows/docker-tests.yml b/.github/workflows/docker-tests.yml index 7bc63de0178..d3ae4f90d79 100644 --- a/.github/workflows/docker-tests.yml +++ b/.github/workflows/docker-tests.yml @@ -50,7 +50,7 @@ jobs: cache-to: type=gha,mode=min - name: "Setup Python" - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.x" @@ -61,7 +61,7 @@ jobs: - name: "Cache virtualenvs" id: cache-pipenv - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/.local/share/virtualenvs key: ${{ runner.os }}-pipenv-${{ hashFiles('**/Pipfile.lock') }} diff --git a/.golangci.yml b/.golangci.yml index e1f2fc09a84..3161b2c0aaf 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -11,7 +11,7 @@ run: linters-settings: cyclop: # lower this after refactoring - max-complexity: 66 + max-complexity: 70 gci: sections: @@ -26,7 +26,7 @@ linters-settings: gocyclo: # lower this after refactoring - min-complexity: 64 + min-complexity: 70 funlen: # Checks the number of lines in a function. @@ -53,7 +53,7 @@ linters-settings: nestif: # lower this after refactoring - min-complexity: 27 + min-complexity: 28 nlreturn: block-size: 4 @@ -310,10 +310,6 @@ issues: # Will fix, might be trickier # - - linters: - - staticcheck - text: "x509.ParseCRL has been deprecated since Go 1.19: Use ParseRevocationList instead" - # https://github.com/pkg/errors/issues/245 - linters: - depguard diff --git a/cmd/crowdsec-cli/bouncers.go b/cmd/crowdsec-cli/bouncers.go index d2685901ebb..717e9aef5fe 100644 --- a/cmd/crowdsec-cli/bouncers.go +++ b/cmd/crowdsec-cli/bouncers.go @@ -36,13 +36,13 @@ func askYesNo(message string, defaultAnswer bool) (bool, error) { } type cliBouncers struct { - db *database.Client + db *database.Client cfg configGetter } -func NewCLIBouncers(getconfig configGetter) *cliBouncers { +func NewCLIBouncers(cfg configGetter) *cliBouncers { return &cliBouncers{ - cfg: getconfig, + cfg: cfg, } } @@ -197,13 +197,13 @@ cscli bouncers add MyBouncerName --key `, return cmd } -func (cli *cliBouncers) deleteValid(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { +func (cli *cliBouncers) deleteValid(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { bouncers, err := cli.db.ListBouncers() if err != nil { cobra.CompError("unable to list bouncers " + err.Error()) } - ret :=[]string{} + ret := []string{} for _, bouncer := range bouncers { if strings.Contains(bouncer.Name, toComplete) && !slices.Contains(args, bouncer.Name) { diff --git a/cmd/crowdsec-cli/config_restore.go b/cmd/crowdsec-cli/config_restore.go index e9c2fa9aa23..17d7494c60f 100644 --- a/cmd/crowdsec-cli/config_restore.go +++ b/cmd/crowdsec-cli/config_restore.go @@ -146,7 +146,12 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { // Now we have config.yaml, we should regenerate config struct to have rights paths etc ConfigFilePath = fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir) - initConfig() + log.Debug("Reloading configuration") + + csConfig, _, err = loadConfigFor("config") + if err != nil { + return fmt.Errorf("failed to reload configuration: %s", err) + } backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath) if _, err = os.Stat(backupCAPICreds); err == nil { @@ -227,7 +232,7 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { } } - // if there is files in the acquis backup dir, restore them + // if there are files in the acquis backup dir, restore them acquisBackupDir := filepath.Join(dirPath, "acquis", "*.yaml") if acquisFiles, err := filepath.Glob(acquisBackupDir); err == nil { for _, acquisFile := range acquisFiles { diff --git a/cmd/crowdsec-cli/dashboard.go b/cmd/crowdsec-cli/dashboard.go index 64cb7577e89..59b9e67cd94 100644 --- a/cmd/crowdsec-cli/dashboard.go +++ b/cmd/crowdsec-cli/dashboard.go @@ -19,15 +19,14 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "github.com/crowdsecurity/crowdsec/pkg/metabase" - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/metabase" ) var ( metabaseUser = "crowdsec@crowdsec.net" metabasePassword string - metabaseDbPath string + metabaseDBPath string metabaseConfigPath string metabaseConfigFolder = "metabase/" metabaseConfigFile = "metabase.yaml" @@ -43,13 +42,13 @@ var ( // information needed to set up a random password on user's behalf ) -type cliDashboard struct{ +type cliDashboard struct { cfg configGetter } -func NewCLIDashboard(getconfig configGetter) *cliDashboard { +func NewCLIDashboard(cfg configGetter) *cliDashboard { return &cliDashboard{ - cfg: getconfig, + cfg: cfg, } } @@ -99,6 +98,7 @@ cscli dashboard remove metabaseContainerID = oldContainerID } } + return nil }, } @@ -127,8 +127,8 @@ cscli dashboard setup --listen 0.0.0.0 cscli dashboard setup -l 0.0.0.0 -p 443 --password `, RunE: func(_ *cobra.Command, _ []string) error { - if metabaseDbPath == "" { - metabaseDbPath = cli.cfg().ConfigPaths.DataDir + if metabaseDBPath == "" { + metabaseDBPath = cli.cfg().ConfigPaths.DataDir } if metabasePassword == "" { @@ -152,7 +152,7 @@ cscli dashboard setup -l 0.0.0.0 -p 443 --password if err = cli.chownDatabase(dockerGroup.Gid); err != nil { return err } - mb, err := metabase.SetupMetabase(cli.cfg().API.Server.DbConfig, metabaseListenAddress, metabaseListenPort, metabaseUser, metabasePassword, metabaseDbPath, dockerGroup.Gid, metabaseContainerID, metabaseImage) + mb, err := metabase.SetupMetabase(cli.cfg().API.Server.DbConfig, metabaseListenAddress, metabaseListenPort, metabaseUser, metabasePassword, metabaseDBPath, dockerGroup.Gid, metabaseContainerID, metabaseImage) if err != nil { return err } @@ -165,13 +165,14 @@ cscli dashboard setup -l 0.0.0.0 -p 443 --password fmt.Printf("\tURL : '%s'\n", mb.Config.ListenURL) fmt.Printf("\tusername : '%s'\n", mb.Config.Username) fmt.Printf("\tpassword : '%s'\n", mb.Config.Password) + return nil }, } flags := cmd.Flags() flags.BoolVarP(&force, "force", "f", false, "Force setup : override existing files") - flags.StringVarP(&metabaseDbPath, "dir", "d", "", "Shared directory with metabase container") + flags.StringVarP(&metabaseDBPath, "dir", "d", "", "Shared directory with metabase container") flags.StringVarP(&metabaseListenAddress, "listen", "l", metabaseListenAddress, "Listen address of container") flags.StringVar(&metabaseImage, "metabase-image", metabaseImage, "Metabase image to use") flags.StringVarP(&metabaseListenPort, "port", "p", metabaseListenPort, "Listen port of container") @@ -203,6 +204,7 @@ func (cli *cliDashboard) newStartCmd() *cobra.Command { } log.Infof("Started metabase") log.Infof("url : http://%s:%s", mb.Config.ListenAddr, mb.Config.ListenPort) + return nil }, } @@ -241,6 +243,7 @@ func (cli *cliDashboard) newShowPasswordCmd() *cobra.Command { return err } log.Printf("'%s'", m.Config.Password) + return nil }, } @@ -313,6 +316,7 @@ cscli dashboard remove --force } } } + return nil }, } diff --git a/cmd/crowdsec-cli/dashboard_unsupported.go b/cmd/crowdsec-cli/dashboard_unsupported.go index 4cf8e18b503..cc80abd2528 100644 --- a/cmd/crowdsec-cli/dashboard_unsupported.go +++ b/cmd/crowdsec-cli/dashboard_unsupported.go @@ -13,9 +13,9 @@ type cliDashboard struct{ cfg configGetter } -func NewCLIDashboard(getconfig configGetter) *cliDashboard { +func NewCLIDashboard(cfg configGetter) *cliDashboard { return &cliDashboard{ - cfg: getconfig, + cfg: cfg, } } diff --git a/cmd/crowdsec-cli/decisions.go b/cmd/crowdsec-cli/decisions.go index c5839ae0079..d7165367898 100644 --- a/cmd/crowdsec-cli/decisions.go +++ b/cmd/crowdsec-cli/decisions.go @@ -116,14 +116,13 @@ func (cli *cliDecisions) decisionsToTable(alerts *models.GetAlertsResponse, prin return nil } - type cliDecisions struct { cfg configGetter } -func NewCLIDecisions(getconfig configGetter) *cliDecisions { +func NewCLIDecisions(cfg configGetter) *cliDecisions { return &cliDecisions{ - cfg: getconfig, + cfg: cfg, } } @@ -157,6 +156,7 @@ func (cli *cliDecisions) NewCommand() *cobra.Command { if err != nil { return fmt.Errorf("creating api client: %w", err) } + return nil }, } @@ -393,6 +393,7 @@ cscli decisions add --scope username --value foobar } log.Info("Decision successfully added") + return nil }, } @@ -499,6 +500,7 @@ cscli decisions delete --type captcha } } log.Infof("%s decision(s) deleted", decisions.NbDeleted) + return nil }, } diff --git a/cmd/crowdsec-cli/hub.go b/cmd/crowdsec-cli/hub.go index d3ce380bb6f..600e56889f7 100644 --- a/cmd/crowdsec-cli/hub.go +++ b/cmd/crowdsec-cli/hub.go @@ -13,13 +13,13 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -type cliHub struct{ +type cliHub struct { cfg configGetter } -func NewCLIHub(getconfig configGetter) *cliHub { +func NewCLIHub(cfg configGetter) *cliHub { return &cliHub{ - cfg: getconfig, + cfg: cfg, } } diff --git a/cmd/crowdsec-cli/machines.go b/cmd/crowdsec-cli/machines.go index 1819bdcf5fb..7c9b9708c92 100644 --- a/cmd/crowdsec-cli/machines.go +++ b/cmd/crowdsec-cli/machines.go @@ -7,6 +7,7 @@ import ( "fmt" "math/big" "os" + "slices" "strings" "time" @@ -17,7 +18,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "gopkg.in/yaml.v3" - "slices" "github.com/crowdsecurity/machineid" @@ -106,14 +106,14 @@ func getLastHeartbeat(m *ent.Machine) (string, bool) { return hb, true } -type cliMachines struct{ - db *database.Client +type cliMachines struct { + db *database.Client cfg configGetter } -func NewCLIMachines(getconfig configGetter) *cliMachines { +func NewCLIMachines(cfg configGetter) *cliMachines { return &cliMachines{ - cfg: getconfig, + cfg: cfg, } } @@ -136,6 +136,7 @@ Note: This command requires database direct access, so is intended to be run on if err != nil { return fmt.Errorf("unable to create new database client: %s", err) } + return nil }, } @@ -249,7 +250,7 @@ cscli machines add -f- --auto > /tmp/mycreds.yaml`, func (cli *cliMachines) add(args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error { var ( - err error + err error machineID string ) @@ -347,7 +348,7 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri return nil } -func (cli *cliMachines) deleteValid(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { +func (cli *cliMachines) deleteValid(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { machines, err := cli.db.ListMachines() if err != nil { cobra.CompError("unable to list machines " + err.Error()) @@ -447,9 +448,9 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b func (cli *cliMachines) newPruneCmd() *cobra.Command { var ( - duration time.Duration - notValidOnly bool - force bool + duration time.Duration + notValidOnly bool + force bool ) const defaultDuration = 10 * time.Minute diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index db3a164af90..62b85e63047 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -15,45 +15,88 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/fflag" ) -var trace_lvl, dbg_lvl, nfo_lvl, wrn_lvl, err_lvl bool - var ConfigFilePath string var csConfig *csconfig.Config var dbClient *database.Client -var outputFormat string -var OutputColor string +type configGetter func() *csconfig.Config var mergedConfig string -// flagBranch overrides the value in csConfig.Cscli.HubBranch -var flagBranch = "" +type cliRoot struct { + logTrace bool + logDebug bool + logInfo bool + logWarn bool + logErr bool + outputColor string + outputFormat string + // flagBranch overrides the value in csConfig.Cscli.HubBranch + flagBranch string +} -type configGetter func() *csconfig.Config +func newCliRoot() *cliRoot { + return &cliRoot{} +} -func initConfig() { - var err error +// cfg() is a helper function to get the configuration loaded from config.yaml, +// we pass it to subcommands because the file is not read until the Execute() call +func (cli *cliRoot) cfg() *csconfig.Config { + return csConfig +} - if trace_lvl { - log.SetLevel(log.TraceLevel) - } else if dbg_lvl { - log.SetLevel(log.DebugLevel) - } else if nfo_lvl { - log.SetLevel(log.InfoLevel) - } else if wrn_lvl { - log.SetLevel(log.WarnLevel) - } else if err_lvl { - log.SetLevel(log.ErrorLevel) +// wantedLogLevel returns the log level requested in the command line flags. +func (cli *cliRoot) wantedLogLevel() log.Level { + switch { + case cli.logTrace: + return log.TraceLevel + case cli.logDebug: + return log.DebugLevel + case cli.logInfo: + return log.InfoLevel + case cli.logWarn: + return log.WarnLevel + case cli.logErr: + return log.ErrorLevel + default: + return log.InfoLevel + } +} + +// loadConfigFor loads the configuration file for the given sub-command. +// If the sub-command does not need it, it returns a default configuration. +func loadConfigFor(command string) (*csconfig.Config, string, error) { + noNeedConfig := []string{ + "doc", + "help", + "completion", + "version", + "hubtest", } - if !slices.Contains(NoNeedConfig, os.Args[1]) { + if !slices.Contains(noNeedConfig, command) { log.Debugf("Using %s as configuration file", ConfigFilePath) - csConfig, mergedConfig, err = csconfig.NewConfig(ConfigFilePath, false, false, true) + + config, merged, err := csconfig.NewConfig(ConfigFilePath, false, false, true) if err != nil { - log.Fatal(err) + return nil, "", err } - } else { - csConfig = csconfig.NewDefaultConfig() + + return config, merged, nil + } + + return csconfig.NewDefaultConfig(), "", nil +} + +// initialize is called before the subcommand is executed. +func (cli *cliRoot) initialize() { + var err error + + log.SetLevel(cli.wantedLogLevel()) + + csConfig, mergedConfig, err = loadConfigFor(os.Args[1]) + if err != nil { + log.Fatal(err) } // recap of the enabled feature flags, because logging @@ -62,12 +105,12 @@ func initConfig() { log.Debugf("Enabled feature flags: %s", fflist) } - if flagBranch != "" { - csConfig.Cscli.HubBranch = flagBranch + if cli.flagBranch != "" { + csConfig.Cscli.HubBranch = cli.flagBranch } - if outputFormat != "" { - csConfig.Cscli.Output = outputFormat + if cli.outputFormat != "" { + csConfig.Cscli.Output = cli.outputFormat } if csConfig.Cscli.Output == "" { @@ -85,11 +128,11 @@ func initConfig() { log.SetLevel(log.ErrorLevel) } - if OutputColor != "" { - csConfig.Cscli.Color = OutputColor + if cli.outputColor != "" { + csConfig.Cscli.Color = cli.outputColor - if OutputColor != "yes" && OutputColor != "no" && OutputColor != "auto" { - log.Fatalf("output color %s unknown", OutputColor) + if cli.outputColor != "yes" && cli.outputColor != "no" && cli.outputColor != "auto" { + log.Fatalf("output color %s unknown", cli.outputColor) } } } @@ -102,15 +145,25 @@ var validArgs = []string{ "postoverflows", "scenarios", "simulation", "support", "version", } -var NoNeedConfig = []string{ - "doc", - "help", - "completion", - "version", - "hubtest", +func (cli *cliRoot) colorize(cmd *cobra.Command) { + cc.Init(&cc.Config{ + RootCmd: cmd, + Headings: cc.Yellow, + Commands: cc.Green + cc.Bold, + CmdShortDescr: cc.Cyan, + Example: cc.Italic, + ExecName: cc.Bold, + Aliases: cc.Bold + cc.Italic, + FlagsDataType: cc.White, + Flags: cc.Green, + FlagsDescr: cc.Cyan, + NoExtraNewlines: true, + NoBottomNewline: true, + }) + cmd.SetOut(color.Output) } -func main() { +func (cli *cliRoot) NewCommand() *cobra.Command { // set the formatter asap and worry about level later logFormatter := &log.TextFormatter{TimestampFormat: time.RFC3339, FullTimestamp: true} log.SetFormatter(logFormatter) @@ -135,31 +188,25 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall /*TBD examples*/ } - cc.Init(&cc.Config{ - RootCmd: cmd, - Headings: cc.Yellow, - Commands: cc.Green + cc.Bold, - CmdShortDescr: cc.Cyan, - Example: cc.Italic, - ExecName: cc.Bold, - Aliases: cc.Bold + cc.Italic, - FlagsDataType: cc.White, - Flags: cc.Green, - FlagsDescr: cc.Cyan, - }) - cmd.SetOut(color.Output) + cli.colorize(cmd) + + /*don't sort flags so we can enforce order*/ + cmd.Flags().SortFlags = false - cmd.PersistentFlags().StringVarP(&ConfigFilePath, "config", "c", csconfig.DefaultConfigPath("config.yaml"), "path to crowdsec config file") - cmd.PersistentFlags().StringVarP(&outputFormat, "output", "o", "", "Output format: human, json, raw") - cmd.PersistentFlags().StringVarP(&OutputColor, "color", "", "auto", "Output color: yes, no, auto") - cmd.PersistentFlags().BoolVar(&dbg_lvl, "debug", false, "Set logging to debug") - cmd.PersistentFlags().BoolVar(&nfo_lvl, "info", false, "Set logging to info") - cmd.PersistentFlags().BoolVar(&wrn_lvl, "warning", false, "Set logging to warning") - cmd.PersistentFlags().BoolVar(&err_lvl, "error", false, "Set logging to error") - cmd.PersistentFlags().BoolVar(&trace_lvl, "trace", false, "Set logging to trace") - cmd.PersistentFlags().StringVar(&flagBranch, "branch", "", "Override hub branch on github") - - if err := cmd.PersistentFlags().MarkHidden("branch"); err != nil { + pflags := cmd.PersistentFlags() + pflags.SortFlags = false + + pflags.StringVarP(&ConfigFilePath, "config", "c", csconfig.DefaultConfigPath("config.yaml"), "path to crowdsec config file") + pflags.StringVarP(&cli.outputFormat, "output", "o", "", "Output format: human, json, raw") + pflags.StringVarP(&cli.outputColor, "color", "", "auto", "Output color: yes, no, auto") + pflags.BoolVar(&cli.logDebug, "debug", false, "Set logging to debug") + pflags.BoolVar(&cli.logInfo, "info", false, "Set logging to info") + pflags.BoolVar(&cli.logWarn, "warning", false, "Set logging to warning") + pflags.BoolVar(&cli.logErr, "error", false, "Set logging to error") + pflags.BoolVar(&cli.logTrace, "trace", false, "Set logging to trace") + pflags.StringVar(&cli.flagBranch, "branch", "", "Override hub branch on github") + + if err := pflags.MarkHidden("branch"); err != nil { log.Fatalf("failed to hide flag: %s", err) } @@ -179,29 +226,20 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall } if len(os.Args) > 1 { - cobra.OnInitialize(initConfig) - } - - /*don't sort flags so we can enforce order*/ - cmd.Flags().SortFlags = false - cmd.PersistentFlags().SortFlags = false - - // we use a getter because the config is not initialized until the Execute() call - getconfig := func() *csconfig.Config { - return csConfig + cobra.OnInitialize(cli.initialize) } cmd.AddCommand(NewCLIDoc().NewCommand(cmd)) cmd.AddCommand(NewCLIVersion().NewCommand()) cmd.AddCommand(NewConfigCmd()) - cmd.AddCommand(NewCLIHub(getconfig).NewCommand()) - cmd.AddCommand(NewCLIMetrics(getconfig).NewCommand()) - cmd.AddCommand(NewCLIDashboard(getconfig).NewCommand()) - cmd.AddCommand(NewCLIDecisions(getconfig).NewCommand()) + cmd.AddCommand(NewCLIHub(cli.cfg).NewCommand()) + cmd.AddCommand(NewCLIMetrics(cli.cfg).NewCommand()) + cmd.AddCommand(NewCLIDashboard(cli.cfg).NewCommand()) + cmd.AddCommand(NewCLIDecisions(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIAlerts().NewCommand()) - cmd.AddCommand(NewCLISimulation(getconfig).NewCommand()) - cmd.AddCommand(NewCLIBouncers(getconfig).NewCommand()) - cmd.AddCommand(NewCLIMachines(getconfig).NewCommand()) + cmd.AddCommand(NewCLISimulation(cli.cfg).NewCommand()) + cmd.AddCommand(NewCLIBouncers(cli.cfg).NewCommand()) + cmd.AddCommand(NewCLIMachines(cli.cfg).NewCommand()) cmd.AddCommand(NewCLICapi().NewCommand()) cmd.AddCommand(NewLapiCmd()) cmd.AddCommand(NewCompletionCmd()) @@ -210,7 +248,7 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall cmd.AddCommand(NewCLIHubTest().NewCommand()) cmd.AddCommand(NewCLINotifications().NewCommand()) cmd.AddCommand(NewCLISupport().NewCommand()) - cmd.AddCommand(NewCLIPapi(getconfig).NewCommand()) + cmd.AddCommand(NewCLIPapi(cli.cfg).NewCommand()) cmd.AddCommand(NewCLICollection().NewCommand()) cmd.AddCommand(NewCLIParser().NewCommand()) cmd.AddCommand(NewCLIScenario().NewCommand()) @@ -223,6 +261,11 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall cmd.AddCommand(NewSetupCmd()) } + return cmd +} + +func main() { + cmd := newCliRoot().NewCommand() if err := cmd.Execute(); err != nil { log.Fatal(err) } diff --git a/cmd/crowdsec-cli/metrics.go b/cmd/crowdsec-cli/metrics.go index ad255e847db..6e23bcf12e4 100644 --- a/cmd/crowdsec-cli/metrics.go +++ b/cmd/crowdsec-cli/metrics.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "errors" "fmt" "io" "net/http" @@ -16,42 +17,64 @@ import ( "github.com/spf13/cobra" "gopkg.in/yaml.v3" + "github.com/crowdsecurity/go-cs-lib/maptools" "github.com/crowdsecurity/go-cs-lib/trace" ) type ( - statAcquis map[string]map[string]int - statParser map[string]map[string]int - statBucket map[string]map[string]int - statLapi map[string]map[string]int - statLapiMachine map[string]map[string]map[string]int - statLapiBouncer map[string]map[string]map[string]int + statAcquis map[string]map[string]int + statParser map[string]map[string]int + statBucket map[string]map[string]int + statWhitelist map[string]map[string]map[string]int + statLapi map[string]map[string]int + statLapiMachine map[string]map[string]map[string]int + statLapiBouncer map[string]map[string]map[string]int statLapiDecision map[string]struct { NonEmpty int Empty int } - statDecision map[string]map[string]map[string]int + statDecision map[string]map[string]map[string]int statAppsecEngine map[string]map[string]int - statAppsecRule map[string]map[string]map[string]int - statAlert map[string]int - statStash map[string]struct { + statAppsecRule map[string]map[string]map[string]int + statAlert map[string]int + statStash map[string]struct { Type string Count int } ) -type cliMetrics struct { - cfg configGetter +var ( + ErrMissingConfig = errors.New("prometheus section missing, can't show metrics") + ErrMetricsDisabled = errors.New("prometheus is not enabled, can't show metrics") + +) + +type metricSection interface { + Table(out io.Writer, noUnit bool, showEmpty bool) + Description() (string, string) } -func NewCLIMetrics(getconfig configGetter) *cliMetrics { - return &cliMetrics{ - cfg: getconfig, +type metricStore map[string]metricSection + +func NewMetricStore() metricStore { + return metricStore{ + "acquisition": statAcquis{}, + "buckets": statBucket{}, + "parsers": statParser{}, + "lapi": statLapi{}, + "lapi-machine": statLapiMachine{}, + "lapi-bouncer": statLapiBouncer{}, + "lapi-decisions": statLapiDecision{}, + "decisions": statDecision{}, + "alerts": statAlert{}, + "stash": statStash{}, + "appsec-engine": statAppsecEngine{}, + "appsec-rule": statAppsecRule{}, + "whitelists": statWhitelist{}, } } -// FormatPrometheusMetrics is a complete rip from prom2json -func FormatPrometheusMetrics(out io.Writer, url string, formatType string, noUnit bool) error { +func (ms metricStore) Fetch(url string) error { mfChan := make(chan *dto.MetricFamily, 1024) errChan := make(chan error, 1) @@ -64,9 +87,10 @@ func FormatPrometheusMetrics(out io.Writer, url string, formatType string, noUni transport.ResponseHeaderTimeout = time.Minute go func() { defer trace.CatchPanic("crowdsec/ShowPrometheus") + err := prom2json.FetchMetricFamilies(url, mfChan, transport) if err != nil { - errChan <- fmt.Errorf("failed to fetch prometheus metrics: %w", err) + errChan <- fmt.Errorf("failed to fetch metrics: %w", err) return } errChan <- nil @@ -81,37 +105,42 @@ func FormatPrometheusMetrics(out io.Writer, url string, formatType string, noUni return err } - log.Debugf("Finished reading prometheus output, %d entries", len(result)) + log.Debugf("Finished reading metrics output, %d entries", len(result)) /*walk*/ - mAcquis := statAcquis{} - mParser := statParser{} - mBucket := statBucket{} - mLapi := statLapi{} - mLapiMachine := statLapiMachine{} - mLapiBouncer := statLapiBouncer{} - mLapiDecision := statLapiDecision{} - mDecision := statDecision{} - mAppsecEngine := statAppsecEngine{} - mAppsecRule := statAppsecRule{} - mAlert := statAlert{} - mStash := statStash{} + mAcquis := ms["acquisition"].(statAcquis) + mParser := ms["parsers"].(statParser) + mBucket := ms["buckets"].(statBucket) + mLapi := ms["lapi"].(statLapi) + mLapiMachine := ms["lapi-machine"].(statLapiMachine) + mLapiBouncer := ms["lapi-bouncer"].(statLapiBouncer) + mLapiDecision := ms["lapi-decisions"].(statLapiDecision) + mDecision := ms["decisions"].(statDecision) + mAppsecEngine := ms["appsec-engine"].(statAppsecEngine) + mAppsecRule := ms["appsec-rule"].(statAppsecRule) + mAlert := ms["alerts"].(statAlert) + mStash := ms["stash"].(statStash) + mWhitelist := ms["whitelists"].(statWhitelist) for idx, fam := range result { if !strings.HasPrefix(fam.Name, "cs_") { continue } + log.Tracef("round %d", idx) + for _, m := range fam.Metrics { metric, ok := m.(prom2json.Metric) if !ok { log.Debugf("failed to convert metric to prom2json.Metric") continue } + name, ok := metric.Labels["name"] if !ok { log.Debugf("no name in Metric %v", metric.Labels) } + source, ok := metric.Labels["source"] if !ok { log.Debugf("no source in Metric %v for %s", metric.Labels, fam.Name) @@ -132,148 +161,89 @@ func FormatPrometheusMetrics(out io.Writer, url string, formatType string, noUni origin := metric.Labels["origin"] action := metric.Labels["action"] + appsecEngine := metric.Labels["appsec_engine"] + appsecRule := metric.Labels["rule_name"] + mtype := metric.Labels["type"] fval, err := strconv.ParseFloat(value, 32) if err != nil { log.Errorf("Unexpected int value %s : %s", value, err) } + ival := int(fval) + switch fam.Name { - /*buckets*/ + // + // buckets + // case "cs_bucket_created_total": - if _, ok := mBucket[name]; !ok { - mBucket[name] = make(map[string]int) - } - mBucket[name]["instantiation"] += ival + mBucket.Process(name, "instantiation", ival) case "cs_buckets": - if _, ok := mBucket[name]; !ok { - mBucket[name] = make(map[string]int) - } - mBucket[name]["curr_count"] += ival + mBucket.Process(name, "curr_count", ival) case "cs_bucket_overflowed_total": - if _, ok := mBucket[name]; !ok { - mBucket[name] = make(map[string]int) - } - mBucket[name]["overflow"] += ival + mBucket.Process(name, "overflow", ival) case "cs_bucket_poured_total": - if _, ok := mBucket[name]; !ok { - mBucket[name] = make(map[string]int) - } - if _, ok := mAcquis[source]; !ok { - mAcquis[source] = make(map[string]int) - } - mBucket[name]["pour"] += ival - mAcquis[source]["pour"] += ival + mBucket.Process(name, "pour", ival) + mAcquis.Process(source, "pour", ival) case "cs_bucket_underflowed_total": - if _, ok := mBucket[name]; !ok { - mBucket[name] = make(map[string]int) - } - mBucket[name]["underflow"] += ival - /*acquis*/ + mBucket.Process(name, "underflow", ival) + // + // parsers + // case "cs_parser_hits_total": - if _, ok := mAcquis[source]; !ok { - mAcquis[source] = make(map[string]int) - } - mAcquis[source]["reads"] += ival + mAcquis.Process(source, "reads", ival) case "cs_parser_hits_ok_total": - if _, ok := mAcquis[source]; !ok { - mAcquis[source] = make(map[string]int) - } - mAcquis[source]["parsed"] += ival + mAcquis.Process(source, "parsed", ival) case "cs_parser_hits_ko_total": - if _, ok := mAcquis[source]; !ok { - mAcquis[source] = make(map[string]int) - } - mAcquis[source]["unparsed"] += ival + mAcquis.Process(source, "unparsed", ival) case "cs_node_hits_total": - if _, ok := mParser[name]; !ok { - mParser[name] = make(map[string]int) - } - mParser[name]["hits"] += ival + mParser.Process(name, "hits", ival) case "cs_node_hits_ok_total": - if _, ok := mParser[name]; !ok { - mParser[name] = make(map[string]int) - } - mParser[name]["parsed"] += ival + mParser.Process(name, "parsed", ival) case "cs_node_hits_ko_total": - if _, ok := mParser[name]; !ok { - mParser[name] = make(map[string]int) - } - mParser[name]["unparsed"] += ival + mParser.Process(name, "unparsed", ival) + // + // whitelists + // + case "cs_node_wl_hits_total": + mWhitelist.Process(name, reason, "hits", ival) + case "cs_node_wl_hits_ok_total": + mWhitelist.Process(name, reason, "whitelisted", ival) + // track as well whitelisted lines at acquis level + mAcquis.Process(source, "whitelisted", ival) + // + // lapi + // case "cs_lapi_route_requests_total": - if _, ok := mLapi[route]; !ok { - mLapi[route] = make(map[string]int) - } - mLapi[route][method] += ival + mLapi.Process(route, method, ival) case "cs_lapi_machine_requests_total": - if _, ok := mLapiMachine[machine]; !ok { - mLapiMachine[machine] = make(map[string]map[string]int) - } - if _, ok := mLapiMachine[machine][route]; !ok { - mLapiMachine[machine][route] = make(map[string]int) - } - mLapiMachine[machine][route][method] += ival + mLapiMachine.Process(machine, route, method, ival) case "cs_lapi_bouncer_requests_total": - if _, ok := mLapiBouncer[bouncer]; !ok { - mLapiBouncer[bouncer] = make(map[string]map[string]int) - } - if _, ok := mLapiBouncer[bouncer][route]; !ok { - mLapiBouncer[bouncer][route] = make(map[string]int) - } - mLapiBouncer[bouncer][route][method] += ival + mLapiBouncer.Process(bouncer, route, method, ival) case "cs_lapi_decisions_ko_total", "cs_lapi_decisions_ok_total": - if _, ok := mLapiDecision[bouncer]; !ok { - mLapiDecision[bouncer] = struct { - NonEmpty int - Empty int - }{} - } - x := mLapiDecision[bouncer] - if fam.Name == "cs_lapi_decisions_ko_total" { - x.Empty += ival - } else if fam.Name == "cs_lapi_decisions_ok_total" { - x.NonEmpty += ival - } - mLapiDecision[bouncer] = x + mLapiDecision.Process(bouncer, fam.Name, ival) + // + // decisions + // case "cs_active_decisions": - if _, ok := mDecision[reason]; !ok { - mDecision[reason] = make(map[string]map[string]int) - } - if _, ok := mDecision[reason][origin]; !ok { - mDecision[reason][origin] = make(map[string]int) - } - mDecision[reason][origin][action] += ival + mDecision.Process(reason, origin, action, ival) case "cs_alerts": - /*if _, ok := mAlert[scenario]; !ok { - mAlert[scenario] = make(map[string]int) - }*/ - mAlert[reason] += ival + mAlert.Process(reason, ival) + // + // stash + // case "cs_cache_size": - mStash[name] = struct { - Type string - Count int - }{Type: mtype, Count: ival} + mStash.Process(name, mtype, ival) + // + // appsec + // case "cs_appsec_reqs_total": - if _, ok := mAppsecEngine[metric.Labels["appsec_engine"]]; !ok { - mAppsecEngine[metric.Labels["appsec_engine"]] = make(map[string]int, 0) - } - mAppsecEngine[metric.Labels["appsec_engine"]]["processed"] = ival + mAppsecEngine.Process(appsecEngine, "processed", ival) case "cs_appsec_block_total": - if _, ok := mAppsecEngine[metric.Labels["appsec_engine"]]; !ok { - mAppsecEngine[metric.Labels["appsec_engine"]] = make(map[string]int, 0) - } - mAppsecEngine[metric.Labels["appsec_engine"]]["blocked"] = ival + mAppsecEngine.Process(appsecEngine, "blocked", ival) case "cs_appsec_rule_hits": - appsecEngine := metric.Labels["appsec_engine"] - ruleID := metric.Labels["rule_name"] - if _, ok := mAppsecRule[appsecEngine]; !ok { - mAppsecRule[appsecEngine] = make(map[string]map[string]int, 0) - } - if _, ok := mAppsecRule[appsecEngine][ruleID]; !ok { - mAppsecRule[appsecEngine][ruleID] = make(map[string]int, 0) - } - mAppsecRule[appsecEngine][ruleID]["triggered"] = ival + mAppsecRule.Process(appsecEngine, appsecRule, "triggered", ival) default: log.Debugf("unknown: %+v", fam.Name) continue @@ -281,46 +251,52 @@ func FormatPrometheusMetrics(out io.Writer, url string, formatType string, noUni } } - if formatType == "human" { - mAcquis.table(out, noUnit) - mBucket.table(out, noUnit) - mParser.table(out, noUnit) - mLapi.table(out) - mLapiMachine.table(out) - mLapiBouncer.table(out) - mLapiDecision.table(out) - mDecision.table(out) - mAlert.table(out) - mStash.table(out) - mAppsecEngine.table(out, noUnit) - mAppsecRule.table(out, noUnit) - return nil + return nil +} + +type cliMetrics struct { + cfg configGetter +} + +func NewCLIMetrics(cfg configGetter) *cliMetrics { + return &cliMetrics{ + cfg: cfg, } +} - stats := make(map[string]any) +func (ms metricStore) Format(out io.Writer, sections []string, formatType string, noUnit bool) error { + // copy only the sections we want + want := map[string]metricSection{} + + // if explicitly asking for sections, we want to show empty tables + showEmpty := len(sections) > 0 + + // if no sections are specified, we want all of them + if len(sections) == 0 { + for section := range ms { + sections = append(sections, section) + } + } - stats["acquisition"] = mAcquis - stats["buckets"] = mBucket - stats["parsers"] = mParser - stats["lapi"] = mLapi - stats["lapi_machine"] = mLapiMachine - stats["lapi_bouncer"] = mLapiBouncer - stats["lapi_decisions"] = mLapiDecision - stats["decisions"] = mDecision - stats["alerts"] = mAlert - stats["stash"] = mStash + for _, section := range sections { + want[section] = ms[section] + } switch formatType { + case "human": + for section := range want { + want[section].Table(out, noUnit, showEmpty) + } case "json": - x, err := json.MarshalIndent(stats, "", " ") + x, err := json.MarshalIndent(want, "", " ") if err != nil { - return fmt.Errorf("failed to unmarshal metrics : %v", err) + return fmt.Errorf("failed to marshal metrics: %w", err) } out.Write(x) case "raw": - x, err := yaml.Marshal(stats) + x, err := yaml.Marshal(want) if err != nil { - return fmt.Errorf("failed to unmarshal metrics : %v", err) + return fmt.Errorf("failed to marshal metrics: %w", err) } out.Write(x) default: @@ -330,7 +306,7 @@ func FormatPrometheusMetrics(out io.Writer, url string, formatType string, noUni return nil } -func (cli *cliMetrics) run(url string, noUnit bool) error { +func (cli *cliMetrics) show(sections []string, url string, noUnit bool) error { cfg := cli.cfg() if url != "" { @@ -338,33 +314,55 @@ func (cli *cliMetrics) run(url string, noUnit bool) error { } if cfg.Prometheus == nil { - return fmt.Errorf("prometheus section missing, can't show metrics") + return ErrMissingConfig } if !cfg.Prometheus.Enabled { - return fmt.Errorf("prometheus is not enabled, can't show metrics") + return ErrMetricsDisabled + } + + ms := NewMetricStore() + + if err := ms.Fetch(cfg.Cscli.PrometheusUrl); err != nil { + return err } - if err := FormatPrometheusMetrics(color.Output, cfg.Cscli.PrometheusUrl, cfg.Cscli.Output, noUnit); err != nil { + // any section that we don't have in the store is an error + for _, section := range sections { + if _, ok := ms[section]; !ok { + return fmt.Errorf("unknown metrics type: %s", section) + } + } + + if err := ms.Format(color.Output, sections, cfg.Cscli.Output, noUnit); err != nil { return err } + return nil } func (cli *cliMetrics) NewCommand() *cobra.Command { var ( - url string + url string noUnit bool ) cmd := &cobra.Command{ - Use: "metrics", - Short: "Display crowdsec prometheus metrics.", - Long: `Fetch metrics from the prometheus server and display them in a human-friendly way`, + Use: "metrics", + Short: "Display crowdsec prometheus metrics.", + Long: `Fetch metrics from a Local API server and display them`, + Example: `# Show all Metrics, skip empty tables (same as "cecli metrics show") +cscli metrics + +# Show only some metrics, connect to a different url +cscli metrics --url http://lapi.local:6060/metrics show acquisition parsers + +# List available metric types +cscli metrics list`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, args []string) error { - return cli.run(url, noUnit) + return cli.show(nil, url, noUnit) }, } @@ -372,5 +370,126 @@ func (cli *cliMetrics) NewCommand() *cobra.Command { flags.StringVarP(&url, "url", "u", "", "Prometheus url (http://:/metrics)") flags.BoolVar(&noUnit, "no-unit", false, "Show the real number instead of formatted with units") + cmd.AddCommand(cli.newShowCmd()) + cmd.AddCommand(cli.newListCmd()) + + return cmd +} + +// expandAlias returns a list of sections. The input can be a list of sections or alias. +func (cli *cliMetrics) expandSectionGroups(args []string) []string { + ret := []string{} + + for _, section := range args { + switch section { + case "engine": + ret = append(ret, "acquisition", "parsers", "buckets", "stash", "whitelists") + case "lapi": + ret = append(ret, "alerts", "decisions", "lapi", "lapi-bouncer", "lapi-decisions", "lapi-machine") + case "appsec": + ret = append(ret, "appsec-engine", "appsec-rule") + default: + ret = append(ret, section) + } + } + + return ret +} + +func (cli *cliMetrics) newShowCmd() *cobra.Command { + var ( + url string + noUnit bool + ) + + cmd := &cobra.Command{ + Use: "show [type]...", + Short: "Display all or part of the available metrics.", + Long: `Fetch metrics from a Local API server and display them, optionally filtering on specific types.`, + Example: `# Show all Metrics, skip empty tables +cscli metrics show + +# Use an alias: "engine", "lapi" or "appsec" to show a group of metrics +cscli metrics show engine + +# Show some specific metrics, show empty tables, connect to a different url +cscli metrics show acquisition parsers buckets stash --url http://lapi.local:6060/metrics + +# Show metrics in json format +cscli metrics show acquisition parsers buckets stash -o json`, + // Positional args are optional + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + args = cli.expandSectionGroups(args) + return cli.show(args, url, noUnit) + }, + } + + flags := cmd.Flags() + flags.StringVarP(&url, "url", "u", "", "Metrics url (http://:/metrics)") + flags.BoolVar(&noUnit, "no-unit", false, "Show the real number instead of formatted with units") + + return cmd +} + +func (cli *cliMetrics) list() error { + type metricType struct { + Type string `json:"type" yaml:"type"` + Title string `json:"title" yaml:"title"` + Description string `json:"description" yaml:"description"` + } + + var allMetrics []metricType + + ms := NewMetricStore() + for _, section := range maptools.SortedKeys(ms) { + title, description := ms[section].Description() + allMetrics = append(allMetrics, metricType{ + Type: section, + Title: title, + Description: description, + }) + } + + switch cli.cfg().Cscli.Output { + case "human": + t := newTable(color.Output) + t.SetRowLines(true) + t.SetHeaders("Type", "Title", "Description") + + for _, metric := range allMetrics { + t.AddRow(metric.Type, metric.Title, metric.Description) + } + + t.Render() + case "json": + x, err := json.MarshalIndent(allMetrics, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal metric types: %w", err) + } + fmt.Println(string(x)) + case "raw": + x, err := yaml.Marshal(allMetrics) + if err != nil { + return fmt.Errorf("failed to marshal metric types: %w", err) + } + fmt.Println(string(x)) + } + + return nil +} + +func (cli *cliMetrics) newListCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "List available types of metrics.", + Long: `List available types of metrics.`, + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.list() + }, + } + return cmd } diff --git a/cmd/crowdsec-cli/metrics_table.go b/cmd/crowdsec-cli/metrics_table.go index 835277aa4ee..da6ea3d9f1d 100644 --- a/cmd/crowdsec-cli/metrics_table.go +++ b/cmd/crowdsec-cli/metrics_table.go @@ -4,22 +4,29 @@ import ( "fmt" "io" "sort" + "strconv" "github.com/aquasecurity/table" log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/go-cs-lib/maptools" ) +// ErrNilTable means a nil pointer was passed instead of a table instance. This is a programming error. +var ErrNilTable = fmt.Errorf("nil table") + func lapiMetricsToTable(t *table.Table, stats map[string]map[string]map[string]int) int { // stats: machine -> route -> method -> count - // sort keys to keep consistent order when printing machineKeys := []string{} for k := range stats { machineKeys = append(machineKeys, k) } + sort.Strings(machineKeys) numRows := 0 + for _, machine := range machineKeys { // oneRow: route -> method -> count machineRow := stats[machine] @@ -31,41 +38,77 @@ func lapiMetricsToTable(t *table.Table, stats map[string]map[string]map[string]i methodName, } if count != 0 { - row = append(row, fmt.Sprintf("%d", count)) + row = append(row, strconv.Itoa(count)) } else { row = append(row, "-") } + t.AddRow(row...) numRows++ } } } + return numRows } -func metricsToTable(t *table.Table, stats map[string]map[string]int, keys []string, noUnit bool) (int, error) { +func wlMetricsToTable(t *table.Table, stats map[string]map[string]map[string]int, noUnit bool) (int, error) { if t == nil { - return 0, fmt.Errorf("nil table") + return 0, ErrNilTable } - // sort keys to keep consistent order when printing - sortedKeys := []string{} - for k := range stats { - sortedKeys = append(sortedKeys, k) + + numRows := 0 + + for _, name := range maptools.SortedKeys(stats) { + for _, reason := range maptools.SortedKeys(stats[name]) { + row := []string{ + name, + reason, + "-", + "-", + } + + for _, action := range maptools.SortedKeys(stats[name][reason]) { + value := stats[name][reason][action] + + switch action { + case "whitelisted": + row[3] = strconv.Itoa(value) + case "hits": + row[2] = strconv.Itoa(value) + default: + log.Debugf("unexpected counter '%s' for whitelists = %d", action, value) + } + } + + t.AddRow(row...) + numRows++ + } + } + + return numRows, nil +} + +func metricsToTable(t *table.Table, stats map[string]map[string]int, keys []string, noUnit bool) (int, error) { + if t == nil { + return 0, ErrNilTable } - sort.Strings(sortedKeys) numRows := 0 - for _, alabel := range sortedKeys { + + for _, alabel := range maptools.SortedKeys(stats) { astats, ok := stats[alabel] if !ok { continue } + row := []string{ alabel, } + for _, sl := range keys { if v, ok := astats[sl]; ok && v != 0 { - numberToShow := fmt.Sprintf("%d", v) + numberToShow := strconv.Itoa(v) if !noUnit { numberToShow = formatNumber(v) } @@ -75,13 +118,29 @@ func metricsToTable(t *table.Table, stats map[string]map[string]int, keys []stri row = append(row, "-") } } + t.AddRow(row...) numRows++ } + return numRows, nil } -func (s statBucket) table(out io.Writer, noUnit bool) { +func (s statBucket) Description() (string, string) { + return "Bucket Metrics", + `Measure events in different scenarios. Current count is the number of buckets during metrics collection. ` + + `Overflows are past event-producing buckets, while Expired are the ones that didn’t receive enough events to Overflow.` +} + +func (s statBucket) Process(bucket, metric string, val int) { + if _, ok := s[bucket]; !ok { + s[bucket] = make(map[string]int) + } + + s[bucket][metric] += val +} + +func (s statBucket) Table(out io.Writer, noUnit bool, showEmpty bool) { t := newTable(out) t.SetRowLines(false) t.SetHeaders("Bucket", "Current Count", "Overflows", "Instantiated", "Poured", "Expired") @@ -91,60 +150,159 @@ func (s statBucket) table(out io.Writer, noUnit bool) { if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil { log.Warningf("while collecting bucket stats: %s", err) - } else if numRows > 0 { - renderTableTitle(out, "\nBucket Metrics:") + } else if numRows > 0 || showEmpty { + title, _ := s.Description() + renderTableTitle(out, "\n"+title+":") t.Render() } } -func (s statAcquis) table(out io.Writer, noUnit bool) { +func (s statAcquis) Description() (string, string) { + return "Acquisition Metrics", + `Measures the lines read, parsed, and unparsed per datasource. ` + + `Zero read lines indicate a misconfigured or inactive datasource. ` + + `Zero parsed lines mean the parser(s) failed. ` + + `Non-zero parsed lines are fine as crowdsec selects relevant lines.` +} + +func (s statAcquis) Process(source, metric string, val int) { + if _, ok := s[source]; !ok { + s[source] = make(map[string]int) + } + + s[source][metric] += val +} + +func (s statAcquis) Table(out io.Writer, noUnit bool, showEmpty bool) { t := newTable(out) t.SetRowLines(false) - t.SetHeaders("Source", "Lines read", "Lines parsed", "Lines unparsed", "Lines poured to bucket") + t.SetHeaders("Source", "Lines read", "Lines parsed", "Lines unparsed", "Lines poured to bucket", "Lines whitelisted") t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - keys := []string{"reads", "parsed", "unparsed", "pour"} + keys := []string{"reads", "parsed", "unparsed", "pour", "whitelisted"} if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil { log.Warningf("while collecting acquis stats: %s", err) - } else if numRows > 0 { - renderTableTitle(out, "\nAcquisition Metrics:") + } else if numRows > 0 || showEmpty { + title, _ := s.Description() + renderTableTitle(out, "\n"+title+":") t.Render() } } -func (s statAppsecEngine) table(out io.Writer, noUnit bool) { +func (s statAppsecEngine) Description() (string, string) { + return "Appsec Metrics", + `Measures the number of parsed and blocked requests by the AppSec Component.` +} + +func (s statAppsecEngine) Process(appsecEngine, metric string, val int) { + if _, ok := s[appsecEngine]; !ok { + s[appsecEngine] = make(map[string]int) + } + + s[appsecEngine][metric] += val +} + +func (s statAppsecEngine) Table(out io.Writer, noUnit bool, showEmpty bool) { t := newTable(out) t.SetRowLines(false) t.SetHeaders("Appsec Engine", "Processed", "Blocked") t.SetAlignment(table.AlignLeft, table.AlignLeft) + keys := []string{"processed", "blocked"} + if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil { log.Warningf("while collecting appsec stats: %s", err) - } else if numRows > 0 { - renderTableTitle(out, "\nAppsec Metrics:") + } else if numRows > 0 || showEmpty { + title, _ := s.Description() + renderTableTitle(out, "\n"+title+":") t.Render() } } -func (s statAppsecRule) table(out io.Writer, noUnit bool) { +func (s statAppsecRule) Description() (string, string) { + return "Appsec Rule Metrics", + `Provides “per AppSec Component” information about the number of matches for loaded AppSec Rules.` +} + +func (s statAppsecRule) Process(appsecEngine, appsecRule string, metric string, val int) { + if _, ok := s[appsecEngine]; !ok { + s[appsecEngine] = make(map[string]map[string]int) + } + + if _, ok := s[appsecEngine][appsecRule]; !ok { + s[appsecEngine][appsecRule] = make(map[string]int) + } + + s[appsecEngine][appsecRule][metric] += val +} + +func (s statAppsecRule) Table(out io.Writer, noUnit bool, showEmpty bool) { for appsecEngine, appsecEngineRulesStats := range s { t := newTable(out) t.SetRowLines(false) t.SetHeaders("Rule ID", "Triggered") t.SetAlignment(table.AlignLeft, table.AlignLeft) + keys := []string{"triggered"} + if numRows, err := metricsToTable(t, appsecEngineRulesStats, keys, noUnit); err != nil { log.Warningf("while collecting appsec rules stats: %s", err) - } else if numRows > 0 { + } else if numRows > 0 || showEmpty { renderTableTitle(out, fmt.Sprintf("\nAppsec '%s' Rules Metrics:", appsecEngine)) t.Render() } } +} + +func (s statWhitelist) Description() (string, string) { + return "Whitelist Metrics", + `Tracks the number of events processed and possibly whitelisted by each parser whitelist.` +} + +func (s statWhitelist) Process(whitelist, reason, metric string, val int) { + if _, ok := s[whitelist]; !ok { + s[whitelist] = make(map[string]map[string]int) + } + + if _, ok := s[whitelist][reason]; !ok { + s[whitelist][reason] = make(map[string]int) + } + + s[whitelist][reason][metric] += val +} + +func (s statWhitelist) Table(out io.Writer, noUnit bool, showEmpty bool) { + t := newTable(out) + t.SetRowLines(false) + t.SetHeaders("Whitelist", "Reason", "Hits", "Whitelisted") + t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) + + if numRows, err := wlMetricsToTable(t, s, noUnit); err != nil { + log.Warningf("while collecting parsers stats: %s", err) + } else if numRows > 0 || showEmpty { + title, _ := s.Description() + renderTableTitle(out, "\n"+title+":") + t.Render() + } +} +func (s statParser) Description() (string, string) { + return "Parser Metrics", + `Tracks the number of events processed by each parser and indicates success of failure. ` + + `Zero parsed lines means the parer(s) failed. ` + + `Non-zero unparsed lines are fine as crowdsec select relevant lines.` } -func (s statParser) table(out io.Writer, noUnit bool) { +func (s statParser) Process(parser, metric string, val int) { + if _, ok := s[parser]; !ok { + s[parser] = make(map[string]int) + } + + s[parser][metric] += val +} + +func (s statParser) Table(out io.Writer, noUnit bool, showEmpty bool) { t := newTable(out) t.SetRowLines(false) t.SetHeaders("Parsers", "Hits", "Parsed", "Unparsed") @@ -154,84 +312,124 @@ func (s statParser) table(out io.Writer, noUnit bool) { if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil { log.Warningf("while collecting parsers stats: %s", err) - } else if numRows > 0 { - renderTableTitle(out, "\nParser Metrics:") + } else if numRows > 0 || showEmpty { + title, _ := s.Description() + renderTableTitle(out, "\n"+title+":") t.Render() } } -func (s statStash) table(out io.Writer) { +func (s statStash) Description() (string, string) { + return "Parser Stash Metrics", + `Tracks the status of stashes that might be created by various parsers and scenarios.` +} + +func (s statStash) Process(name, mtype string, val int) { + s[name] = struct { + Type string + Count int + }{ + Type: mtype, + Count: val, + } +} + +func (s statStash) Table(out io.Writer, noUnit bool, showEmpty bool) { t := newTable(out) t.SetRowLines(false) t.SetHeaders("Name", "Type", "Items") t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) // unfortunately, we can't reuse metricsToTable as the structure is too different :/ - sortedKeys := []string{} - for k := range s { - sortedKeys = append(sortedKeys, k) - } - sort.Strings(sortedKeys) - numRows := 0 - for _, alabel := range sortedKeys { + + for _, alabel := range maptools.SortedKeys(s) { astats := s[alabel] row := []string{ alabel, astats.Type, - fmt.Sprintf("%d", astats.Count), + strconv.Itoa(astats.Count), } t.AddRow(row...) numRows++ } - if numRows > 0 { - renderTableTitle(out, "\nParser Stash Metrics:") + + if numRows > 0 || showEmpty { + title, _ := s.Description() + renderTableTitle(out, "\n"+title+":") t.Render() } } -func (s statLapi) table(out io.Writer) { +func (s statLapi) Description() (string, string) { + return "Local API Metrics", + `Monitors the requests made to local API routes.` +} + +func (s statLapi) Process(route, method string, val int) { + if _, ok := s[route]; !ok { + s[route] = make(map[string]int) + } + + s[route][method] += val +} + +func (s statLapi) Table(out io.Writer, noUnit bool, showEmpty bool) { t := newTable(out) t.SetRowLines(false) t.SetHeaders("Route", "Method", "Hits") t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) // unfortunately, we can't reuse metricsToTable as the structure is too different :/ - sortedKeys := []string{} - for k := range s { - sortedKeys = append(sortedKeys, k) - } - sort.Strings(sortedKeys) - numRows := 0 - for _, alabel := range sortedKeys { + + for _, alabel := range maptools.SortedKeys(s) { astats := s[alabel] subKeys := []string{} for skey := range astats { subKeys = append(subKeys, skey) } + sort.Strings(subKeys) for _, sl := range subKeys { row := []string{ alabel, sl, - fmt.Sprintf("%d", astats[sl]), + strconv.Itoa(astats[sl]), } t.AddRow(row...) numRows++ } } - if numRows > 0 { - renderTableTitle(out, "\nLocal API Metrics:") + if numRows > 0 || showEmpty { + title, _ := s.Description() + renderTableTitle(out, "\n"+title+":") t.Render() } } -func (s statLapiMachine) table(out io.Writer) { +func (s statLapiMachine) Description() (string, string) { + return "Local API Machines Metrics", + `Tracks the number of calls to the local API from each registered machine.` +} + +func (s statLapiMachine) Process(machine, route, method string, val int) { + if _, ok := s[machine]; !ok { + s[machine] = make(map[string]map[string]int) + } + + if _, ok := s[machine][route]; !ok { + s[machine][route] = make(map[string]int) + } + + s[machine][route][method] += val +} + +func (s statLapiMachine) Table(out io.Writer, noUnit bool, showEmpty bool) { t := newTable(out) t.SetRowLines(false) t.SetHeaders("Machine", "Route", "Method", "Hits") @@ -239,13 +437,31 @@ func (s statLapiMachine) table(out io.Writer) { numRows := lapiMetricsToTable(t, s) - if numRows > 0 { - renderTableTitle(out, "\nLocal API Machines Metrics:") + if numRows > 0 || showEmpty { + title, _ := s.Description() + renderTableTitle(out, "\n"+title+":") t.Render() } } -func (s statLapiBouncer) table(out io.Writer) { +func (s statLapiBouncer) Description() (string, string) { + return "Local API Bouncers Metrics", + `Tracks total hits to remediation component related API routes.` +} + +func (s statLapiBouncer) Process(bouncer, route, method string, val int) { + if _, ok := s[bouncer]; !ok { + s[bouncer] = make(map[string]map[string]int) + } + + if _, ok := s[bouncer][route]; !ok { + s[bouncer][route] = make(map[string]int) + } + + s[bouncer][route][method] += val +} + +func (s statLapiBouncer) Table(out io.Writer, noUnit bool, showEmpty bool) { t := newTable(out) t.SetRowLines(false) t.SetHeaders("Bouncer", "Route", "Method", "Hits") @@ -253,41 +469,88 @@ func (s statLapiBouncer) table(out io.Writer) { numRows := lapiMetricsToTable(t, s) - if numRows > 0 { - renderTableTitle(out, "\nLocal API Bouncers Metrics:") + if numRows > 0 || showEmpty { + title, _ := s.Description() + renderTableTitle(out, "\n"+title+":") t.Render() } } -func (s statLapiDecision) table(out io.Writer) { +func (s statLapiDecision) Description() (string, string) { + return "Local API Bouncers Decisions", + `Tracks the number of empty/non-empty answers from LAPI to bouncers that are working in "live" mode.` +} + +func (s statLapiDecision) Process(bouncer, fam string, val int) { + if _, ok := s[bouncer]; !ok { + s[bouncer] = struct { + NonEmpty int + Empty int + }{} + } + + x := s[bouncer] + + switch fam { + case "cs_lapi_decisions_ko_total": + x.Empty += val + case "cs_lapi_decisions_ok_total": + x.NonEmpty += val + } + + s[bouncer] = x +} + +func (s statLapiDecision) Table(out io.Writer, noUnit bool, showEmpty bool) { t := newTable(out) t.SetRowLines(false) t.SetHeaders("Bouncer", "Empty answers", "Non-empty answers") t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) numRows := 0 + for bouncer, hits := range s { t.AddRow( bouncer, - fmt.Sprintf("%d", hits.Empty), - fmt.Sprintf("%d", hits.NonEmpty), + strconv.Itoa(hits.Empty), + strconv.Itoa(hits.NonEmpty), ) numRows++ } - if numRows > 0 { - renderTableTitle(out, "\nLocal API Bouncers Decisions:") + if numRows > 0 || showEmpty { + title, _ := s.Description() + renderTableTitle(out, "\n"+title+":") t.Render() } } -func (s statDecision) table(out io.Writer) { +func (s statDecision) Description() (string, string) { + return "Local API Decisions", + `Provides information about all currently active decisions. ` + + `Includes both local (crowdsec) and global decisions (CAPI), and lists subscriptions (lists).` +} + +func (s statDecision) Process(reason, origin, action string, val int) { + if _, ok := s[reason]; !ok { + s[reason] = make(map[string]map[string]int) + } + + if _, ok := s[reason][origin]; !ok { + s[reason][origin] = make(map[string]int) + } + + s[reason][origin][action] += val +} + +func (s statDecision) Table(out io.Writer, noUnit bool, showEmpty bool) { t := newTable(out) t.SetRowLines(false) t.SetHeaders("Reason", "Origin", "Action", "Count") t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) numRows := 0 + for reason, origins := range s { for origin, actions := range origins { for action, hits := range actions { @@ -295,36 +558,48 @@ func (s statDecision) table(out io.Writer) { reason, origin, action, - fmt.Sprintf("%d", hits), + strconv.Itoa(hits), ) numRows++ } } } - if numRows > 0 { - renderTableTitle(out, "\nLocal API Decisions:") + if numRows > 0 || showEmpty { + title, _ := s.Description() + renderTableTitle(out, "\n"+title+":") t.Render() } } -func (s statAlert) table(out io.Writer) { +func (s statAlert) Description() (string, string) { + return "Local API Alerts", + `Tracks the total number of past and present alerts for the installed scenarios.` +} + +func (s statAlert) Process(reason string, val int) { + s[reason] += val +} + +func (s statAlert) Table(out io.Writer, noUnit bool, showEmpty bool) { t := newTable(out) t.SetRowLines(false) t.SetHeaders("Reason", "Count") t.SetAlignment(table.AlignLeft, table.AlignLeft) numRows := 0 + for scenario, hits := range s { t.AddRow( scenario, - fmt.Sprintf("%d", hits), + strconv.Itoa(hits), ) numRows++ } - if numRows > 0 { - renderTableTitle(out, "\nLocal API Alerts:") + if numRows > 0 || showEmpty { + title, _ := s.Description() + renderTableTitle(out, "\n"+title+":") t.Render() } } diff --git a/cmd/crowdsec-cli/papi.go b/cmd/crowdsec-cli/papi.go index 04223ef93ab..e18af94d4bb 100644 --- a/cmd/crowdsec-cli/papi.go +++ b/cmd/crowdsec-cli/papi.go @@ -10,19 +10,18 @@ import ( "github.com/crowdsecurity/go-cs-lib/ptr" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/apiserver" "github.com/crowdsecurity/crowdsec/pkg/database" - - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" ) type cliPapi struct { cfg configGetter } -func NewCLIPapi(getconfig configGetter) *cliPapi { +func NewCLIPapi(cfg configGetter) *cliPapi { return &cliPapi{ - cfg: getconfig, + cfg: cfg, } } @@ -43,6 +42,7 @@ func (cli *cliPapi) NewCommand() *cobra.Command { if err := require.PAPI(cfg); err != nil { return err } + return nil }, } diff --git a/cmd/crowdsec-cli/simulation.go b/cmd/crowdsec-cli/simulation.go index a6e710c5747..6ccac761727 100644 --- a/cmd/crowdsec-cli/simulation.go +++ b/cmd/crowdsec-cli/simulation.go @@ -3,23 +3,23 @@ package main import ( "fmt" "os" + "slices" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "gopkg.in/yaml.v2" - "slices" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -type cliSimulation struct{ +type cliSimulation struct { cfg configGetter } -func NewCLISimulation(getconfig configGetter) *cliSimulation { +func NewCLISimulation(cfg configGetter) *cliSimulation { return &cliSimulation{ - cfg: getconfig, + cfg: cfg, } } @@ -38,6 +38,7 @@ cscli simulation disable crowdsecurity/ssh-bf`, if cli.cfg().Cscli.SimulationConfig == nil { return fmt.Errorf("no simulation configured") } + return nil }, PersistentPostRun: func(cmd *cobra.Command, _ []string) { diff --git a/cmd/crowdsec-cli/support.go b/cmd/crowdsec-cli/support.go index e0a2fa9db90..661950fa8f6 100644 --- a/cmd/crowdsec-cli/support.go +++ b/cmd/crowdsec-cli/support.go @@ -66,10 +66,15 @@ func collectMetrics() ([]byte, []byte, error) { } humanMetrics := bytes.NewBuffer(nil) - err := FormatPrometheusMetrics(humanMetrics, csConfig.Cscli.PrometheusUrl, "human", false) - if err != nil { - return nil, nil, fmt.Errorf("could not fetch promtheus metrics: %s", err) + ms := NewMetricStore() + + if err := ms.Fetch(csConfig.Cscli.PrometheusUrl); err != nil { + return nil, nil, fmt.Errorf("could not fetch prometheus metrics: %s", err) + } + + if err := ms.Format(humanMetrics, nil, "human", false); err != nil { + return nil, nil, err } req, err := http.NewRequest(http.MethodGet, csConfig.Cscli.PrometheusUrl, nil) diff --git a/cmd/crowdsec/metrics.go b/cmd/crowdsec/metrics.go index ca893872edb..fa2d8d5de32 100644 --- a/cmd/crowdsec/metrics.go +++ b/cmd/crowdsec/metrics.go @@ -161,7 +161,7 @@ func registerPrometheus(config *csconfig.PrometheusCfg) { leaky.BucketsUnderflow, leaky.BucketsCanceled, leaky.BucketsInstantiation, leaky.BucketsOverflow, v1.LapiRouteHits, leaky.BucketsCurrentCount, - cache.CacheMetrics, exprhelpers.RegexpCacheMetrics, + cache.CacheMetrics, exprhelpers.RegexpCacheMetrics, parser.NodesWlHitsOk, parser.NodesWlHits, ) } else { log.Infof("Loading prometheus collectors") @@ -170,7 +170,7 @@ func registerPrometheus(config *csconfig.PrometheusCfg) { globalCsInfo, globalParsingHistogram, globalPourHistogram, v1.LapiRouteHits, v1.LapiMachineHits, v1.LapiBouncerHits, v1.LapiNilDecisions, v1.LapiNonNilDecisions, v1.LapiResponseTime, leaky.BucketsPour, leaky.BucketsUnderflow, leaky.BucketsCanceled, leaky.BucketsInstantiation, leaky.BucketsOverflow, leaky.BucketsCurrentCount, - globalActiveDecisions, globalAlerts, + globalActiveDecisions, globalAlerts, parser.NodesWlHitsOk, parser.NodesWlHits, cache.CacheMetrics, exprhelpers.RegexpCacheMetrics, ) diff --git a/pkg/acquisition/modules/appsec/appsec.go b/pkg/acquisition/modules/appsec/appsec.go index 030724fc3e9..4e2ff0bd22b 100644 --- a/pkg/acquisition/modules/appsec/appsec.go +++ b/pkg/acquisition/modules/appsec/appsec.go @@ -354,15 +354,17 @@ func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) { w.InChan <- parsedRequest + /* + response is a copy of w.AppSecRuntime.Response that is safe to use. + As OutOfBand might still be running, the original one can be modified + */ response := <-parsedRequest.ResponseChannel - statusCode := http.StatusOK if response.InBandInterrupt { - statusCode = http.StatusForbidden AppsecBlockCounter.With(prometheus.Labels{"source": parsedRequest.RemoteAddrNormalized, "appsec_engine": parsedRequest.AppsecEngine}).Inc() } - appsecResponse := w.AppsecRuntime.GenerateResponse(response, logger) + statusCode, appsecResponse := w.AppsecRuntime.GenerateResponse(response, logger) logger.Debugf("Response: %+v", appsecResponse) rw.WriteHeader(statusCode) diff --git a/pkg/acquisition/modules/appsec/appsec_runner.go b/pkg/acquisition/modules/appsec/appsec_runner.go index a9d74aa8f63..cc7264aa2c8 100644 --- a/pkg/acquisition/modules/appsec/appsec_runner.go +++ b/pkg/acquisition/modules/appsec/appsec_runner.go @@ -226,7 +226,8 @@ func (r *AppsecRunner) handleInBandInterrupt(request *appsec.ParsedRequest) { if in := request.Tx.Interruption(); in != nil { r.logger.Debugf("inband rules matched : %d", in.RuleID) r.AppsecRuntime.Response.InBandInterrupt = true - r.AppsecRuntime.Response.HTTPResponseCode = r.AppsecRuntime.Config.BlockedHTTPCode + r.AppsecRuntime.Response.BouncerHTTPResponseCode = r.AppsecRuntime.Config.BouncerBlockedHTTPCode + r.AppsecRuntime.Response.UserHTTPResponseCode = r.AppsecRuntime.Config.UserBlockedHTTPCode r.AppsecRuntime.Response.Action = r.AppsecRuntime.DefaultRemediation if _, ok := r.AppsecRuntime.RemediationById[in.RuleID]; ok { @@ -252,7 +253,9 @@ func (r *AppsecRunner) handleInBandInterrupt(request *appsec.ParsedRequest) { r.logger.Errorf("unable to generate appsec event : %s", err) return } - r.outChan <- *appsecOvlfw + if appsecOvlfw != nil { + r.outChan <- *appsecOvlfw + } } // Should the in band match trigger an event ? diff --git a/pkg/acquisition/modules/appsec/appsec_test.go b/pkg/acquisition/modules/appsec/appsec_test.go index 2a58580137d..25aea0c78ea 100644 --- a/pkg/acquisition/modules/appsec/appsec_test.go +++ b/pkg/acquisition/modules/appsec/appsec_test.go @@ -1,6 +1,7 @@ package appsecacquisition import ( + "net/http" "net/url" "testing" "time" @@ -21,16 +22,21 @@ Missing tests (wip): */ type appsecRuleTest struct { - name string - expected_load_ok bool - inband_rules []appsec_rule.CustomRule - outofband_rules []appsec_rule.CustomRule - on_load []appsec.Hook - pre_eval []appsec.Hook - post_eval []appsec.Hook - on_match []appsec.Hook - input_request appsec.ParsedRequest - output_asserts func(events []types.Event, responses []appsec.AppsecTempResponse) + name string + expected_load_ok bool + inband_rules []appsec_rule.CustomRule + outofband_rules []appsec_rule.CustomRule + on_load []appsec.Hook + pre_eval []appsec.Hook + post_eval []appsec.Hook + on_match []appsec.Hook + BouncerBlockedHTTPCode int + UserBlockedHTTPCode int + UserPassedHTTPCode int + DefaultRemediation string + DefaultPassAction string + input_request appsec.ParsedRequest + output_asserts func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) } func TestAppsecOnMatchHooks(t *testing.T) { @@ -53,13 +59,14 @@ func TestAppsecOnMatchHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Len(t, events, 2) require.Equal(t, types.APPSEC, events[0].Type) require.Equal(t, types.LOG, events[1].Type) require.Len(t, responses, 1) - require.Equal(t, 403, responses[0].HTTPResponseCode) - require.Equal(t, "ban", responses[0].Action) + require.Equal(t, 403, responses[0].BouncerHTTPResponseCode) + require.Equal(t, 403, responses[0].UserHTTPResponseCode) + require.Equal(t, appsec.BanRemediation, responses[0].Action) }, }, @@ -84,17 +91,18 @@ func TestAppsecOnMatchHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Len(t, events, 2) require.Equal(t, types.APPSEC, events[0].Type) require.Equal(t, types.LOG, events[1].Type) require.Len(t, responses, 1) - require.Equal(t, 413, responses[0].HTTPResponseCode) - require.Equal(t, "ban", responses[0].Action) + require.Equal(t, 403, responses[0].BouncerHTTPResponseCode) + require.Equal(t, 413, responses[0].UserHTTPResponseCode) + require.Equal(t, appsec.BanRemediation, responses[0].Action) }, }, { - name: "on_match: change action to another standard one (log)", + name: "on_match: change action to a non standard one (log)", expected_load_ok: true, inband_rules: []appsec_rule.CustomRule{ { @@ -114,7 +122,7 @@ func TestAppsecOnMatchHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Len(t, events, 2) require.Equal(t, types.APPSEC, events[0].Type) require.Equal(t, types.LOG, events[1].Type) @@ -143,16 +151,16 @@ func TestAppsecOnMatchHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Len(t, events, 2) require.Equal(t, types.APPSEC, events[0].Type) require.Equal(t, types.LOG, events[1].Type) require.Len(t, responses, 1) - require.Equal(t, "allow", responses[0].Action) + require.Equal(t, appsec.AllowRemediation, responses[0].Action) }, }, { - name: "on_match: change action to another standard one (deny/ban/block)", + name: "on_match: change action to another standard one (ban)", expected_load_ok: true, inband_rules: []appsec_rule.CustomRule{ { @@ -164,7 +172,7 @@ func TestAppsecOnMatchHooks(t *testing.T) { }, }, on_match: []appsec.Hook{ - {Filter: "IsInBand == true", Apply: []string{"SetRemediation('deny')"}}, + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('ban')"}}, }, input_request: appsec.ParsedRequest{ RemoteAddr: "1.2.3.4", @@ -172,10 +180,10 @@ func TestAppsecOnMatchHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Len(t, responses, 1) //note: SetAction normalizes deny, ban and block to ban - require.Equal(t, "ban", responses[0].Action) + require.Equal(t, appsec.BanRemediation, responses[0].Action) }, }, { @@ -199,10 +207,10 @@ func TestAppsecOnMatchHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Len(t, responses, 1) //note: SetAction normalizes deny, ban and block to ban - require.Equal(t, "captcha", responses[0].Action) + require.Equal(t, appsec.CaptchaRemediation, responses[0].Action) }, }, { @@ -226,7 +234,7 @@ func TestAppsecOnMatchHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Len(t, events, 2) require.Equal(t, types.APPSEC, events[0].Type) require.Equal(t, types.LOG, events[1].Type) @@ -255,11 +263,11 @@ func TestAppsecOnMatchHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Len(t, events, 1) require.Equal(t, types.LOG, events[0].Type) require.Len(t, responses, 1) - require.Equal(t, "ban", responses[0].Action) + require.Equal(t, appsec.BanRemediation, responses[0].Action) }, }, { @@ -283,11 +291,11 @@ func TestAppsecOnMatchHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Len(t, events, 1) require.Equal(t, types.APPSEC, events[0].Type) require.Len(t, responses, 1) - require.Equal(t, "ban", responses[0].Action) + require.Equal(t, appsec.BanRemediation, responses[0].Action) }, }, } @@ -328,7 +336,7 @@ func TestAppsecPreEvalHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Empty(t, events) require.Len(t, responses, 1) require.False(t, responses[0].InBandInterrupt) @@ -356,7 +364,7 @@ func TestAppsecPreEvalHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Len(t, events, 2) require.Equal(t, types.APPSEC, events[0].Type) @@ -391,7 +399,7 @@ func TestAppsecPreEvalHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Empty(t, events) require.Len(t, responses, 1) require.False(t, responses[0].InBandInterrupt) @@ -419,7 +427,7 @@ func TestAppsecPreEvalHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Empty(t, events) require.Len(t, responses, 1) require.False(t, responses[0].InBandInterrupt) @@ -447,7 +455,7 @@ func TestAppsecPreEvalHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Empty(t, events) require.Len(t, responses, 1) require.False(t, responses[0].InBandInterrupt) @@ -472,7 +480,7 @@ func TestAppsecPreEvalHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Len(t, events, 1) require.Equal(t, types.LOG, events[0].Type) require.True(t, events[0].Appsec.HasOutBandMatches) @@ -506,7 +514,7 @@ func TestAppsecPreEvalHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Len(t, events, 2) require.Len(t, responses, 1) require.Equal(t, "foobar", responses[0].Action) @@ -533,7 +541,7 @@ func TestAppsecPreEvalHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Len(t, events, 2) require.Len(t, responses, 1) require.Equal(t, "foobar", responses[0].Action) @@ -560,10 +568,12 @@ func TestAppsecPreEvalHooks(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Len(t, events, 2) require.Len(t, responses, 1) require.Equal(t, "foobar", responses[0].Action) + require.Equal(t, "foobar", appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) }, }, } @@ -574,6 +584,473 @@ func TestAppsecPreEvalHooks(t *testing.T) { }) } } + +func TestAppsecRemediationConfigHooks(t *testing.T) { + + tests := []appsecRuleTest{ + { + name: "Basic matching rule", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "SetRemediation", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + on_match: []appsec.Hook{{Apply: []string{"SetRemediation('captcha')"}}}, //rule ID is generated at runtime. If you change rule, it will break the test (: + + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.CaptchaRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "SetRemediation", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + on_match: []appsec.Hook{{Apply: []string{"SetReturnCode(418)"}}}, //rule ID is generated at runtime. If you change rule, it will break the test (: + + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} +func TestOnMatchRemediationHooks(t *testing.T) { + tests := []appsecRuleTest{ + { + name: "set remediation to allow with on_match hook", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('allow')"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "set remediation to captcha + custom user code with on_match hook", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: appsec.AllowRemediation, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')", "SetReturnCode(418)"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + spew.Dump(responses) + spew.Dump(appsecResponse) + + log.Errorf("http status : %d", statusCode) + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + require.Equal(t, http.StatusForbidden, statusCode) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} + +func TestAppsecDefaultPassRemediation(t *testing.T) { + + tests := []appsecRuleTest{ + { + name: "Basic non-matching rule", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Args: url.Values{"foo": []string{"tutu"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "DefaultPassAction: pass", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Args: url.Values{"foo": []string{"tutu"}}, + }, + DefaultPassAction: "allow", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "DefaultPassAction: captcha", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Args: url.Values{"foo": []string{"tutu"}}, + }, + DefaultPassAction: "captcha", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.CaptchaRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) //@tko: body is captcha, but as it's 200, captcha won't be showed to user + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "DefaultPassHTTPCode: 200", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Args: url.Values{"foo": []string{"tutu"}}, + }, + UserPassedHTTPCode: 200, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "DefaultPassHTTPCode: 200", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Args: url.Values{"foo": []string{"tutu"}}, + }, + UserPassedHTTPCode: 418, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} + +func TestAppsecDefaultRemediation(t *testing.T) { + + tests := []appsecRuleTest{ + { + name: "Basic matching rule", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "default remediation to ban (default)", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: "ban", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "default remediation to allow", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: "allow", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "default remediation to captcha", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: "captcha", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.CaptchaRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "custom user HTTP code", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + UserBlockedHTTPCode: 418, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + }, + }, + { + name: "custom remediation + HTTP code", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + UserBlockedHTTPCode: 418, + DefaultRemediation: "foobar", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, "foobar", responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, "foobar", appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} + func TestAppsecRuleMatches(t *testing.T) { /* @@ -601,7 +1078,7 @@ func TestAppsecRuleMatches(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Len(t, events, 2) require.Equal(t, types.APPSEC, events[0].Type) @@ -632,13 +1109,172 @@ func TestAppsecRuleMatches(t *testing.T) { URI: "/urllll", Args: url.Values{"foo": []string{"tutu"}}, }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { require.Empty(t, events) require.Len(t, responses, 1) require.False(t, responses[0].InBandInterrupt) require.False(t, responses[0].OutOfBandInterrupt) }, }, + { + name: "default remediation to allow", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: "allow", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "default remediation to captcha", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: "captcha", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.CaptchaRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "no default remediation / custom user HTTP code", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + UserBlockedHTTPCode: 418, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + }, + }, + { + name: "no match but try to set remediation to captcha with on_match hook", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"bla"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + }, + }, + { + name: "no match but try to set user HTTP code with on_match hook", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetReturnCode(418)"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"bla"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + }, + }, + { + name: "no match but try to set remediation with pre_eval hook", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + pre_eval: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediationByName('rule42', 'captcha')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"bla"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + }, + }, } for _, test := range tests { @@ -678,7 +1314,16 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) { outofbandRules = append(outofbandRules, strRule) } - appsecCfg := appsec.AppsecConfig{Logger: logger, OnLoad: test.on_load, PreEval: test.pre_eval, PostEval: test.post_eval, OnMatch: test.on_match} + appsecCfg := appsec.AppsecConfig{Logger: logger, + OnLoad: test.on_load, + PreEval: test.pre_eval, + PostEval: test.post_eval, + OnMatch: test.on_match, + BouncerBlockedHTTPCode: test.BouncerBlockedHTTPCode, + UserBlockedHTTPCode: test.UserBlockedHTTPCode, + UserPassedHTTPCode: test.UserPassedHTTPCode, + DefaultRemediation: test.DefaultRemediation, + DefaultPassAction: test.DefaultPassAction} AppsecRuntime, err := appsecCfg.Build() if err != nil { t.Fatalf("unable to build appsec runtime : %s", err) @@ -724,8 +1369,10 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) { runner.handleRequest(&input) time.Sleep(50 * time.Millisecond) + + http_status, appsecResponse := AppsecRuntime.GenerateResponse(OutputResponses[0], logger) log.Infof("events : %s", spew.Sdump(OutputEvents)) log.Infof("responses : %s", spew.Sdump(OutputResponses)) - test.output_asserts(OutputEvents, OutputResponses) + test.output_asserts(OutputEvents, OutputResponses, appsecResponse, http_status) } diff --git a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go index 8451a86fcdf..d2af4e8af28 100644 --- a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go +++ b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go @@ -25,6 +25,7 @@ type LokiClient struct { t *tomb.Tomb fail_start time.Time currentTickerInterval time.Duration + requestHeaders map[string]string } type Config struct { @@ -116,7 +117,7 @@ func (lc *LokiClient) queryRange(uri string, ctx context.Context, c chan *LokiQu case <-lc.t.Dying(): return lc.t.Err() case <-ticker.C: - resp, err := http.Get(uri) + resp, err := lc.Get(uri) if err != nil { if ok := lc.shouldRetry(); !ok { return errors.Wrapf(err, "error querying range") @@ -127,6 +128,7 @@ func (lc *LokiClient) queryRange(uri string, ctx context.Context, c chan *LokiQu } if resp.StatusCode != http.StatusOK { + lc.Logger.Warnf("bad HTTP response code for query range: %d", resp.StatusCode) body, _ := io.ReadAll(resp.Body) resp.Body.Close() if ok := lc.shouldRetry(); !ok { @@ -215,7 +217,7 @@ func (lc *LokiClient) Ready(ctx context.Context) error { return lc.t.Err() case <-tick.C: lc.Logger.Debug("Checking if Loki is ready") - resp, err := http.Get(url) + resp, err := lc.Get(url) if err != nil { lc.Logger.Warnf("Error checking if Loki is ready: %s", err) continue @@ -251,10 +253,9 @@ func (lc *LokiClient) Tail(ctx context.Context) (chan *LokiResponse, error) { } requestHeader := http.Header{} - for k, v := range lc.config.Headers { + for k, v := range lc.requestHeaders { requestHeader.Add(k, v) } - requestHeader.Set("User-Agent", "Crowdsec "+cwversion.VersionStr()) lc.Logger.Infof("Connecting to %s", u) conn, _, err := dialer.Dial(u, requestHeader) @@ -293,16 +294,6 @@ func (lc *LokiClient) QueryRange(ctx context.Context, infinite bool) chan *LokiQ lc.Logger.Debugf("Since: %s (%s)", lc.config.Since, time.Now().Add(-lc.config.Since)) - requestHeader := http.Header{} - for k, v := range lc.config.Headers { - requestHeader.Add(k, v) - } - - if lc.config.Username != "" || lc.config.Password != "" { - requestHeader.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(lc.config.Username+":"+lc.config.Password))) - } - - requestHeader.Set("User-Agent", "Crowdsec "+cwversion.VersionStr()) lc.Logger.Infof("Connecting to %s", url) lc.t.Go(func() error { return lc.queryRange(url, ctx, c, infinite) @@ -310,6 +301,26 @@ func (lc *LokiClient) QueryRange(ctx context.Context, infinite bool) chan *LokiQ return c } +// Create a wrapper for http.Get to be able to set headers and auth +func (lc *LokiClient) Get(url string) (*http.Response, error) { + request, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + for k, v := range lc.requestHeaders { + request.Header.Add(k, v) + } + return http.DefaultClient.Do(request) +} + func NewLokiClient(config Config) *LokiClient { - return &LokiClient{Logger: log.WithField("component", "lokiclient"), config: config} + headers := make(map[string]string) + for k, v := range config.Headers { + headers[k] = v + } + if config.Username != "" || config.Password != "" { + headers["Authorization"] = "Basic " + base64.StdEncoding.EncodeToString([]byte(config.Username+":"+config.Password)) + } + headers["User-Agent"] = "Crowdsec " + cwversion.VersionStr() + return &LokiClient{Logger: log.WithField("component", "lokiclient"), config: config, requestHeaders: headers} } diff --git a/pkg/acquisition/modules/loki/loki_test.go b/pkg/acquisition/modules/loki/loki_test.go index fae2e3aa98f..6cac1c0fec3 100644 --- a/pkg/acquisition/modules/loki/loki_test.go +++ b/pkg/acquisition/modules/loki/loki_test.go @@ -276,10 +276,17 @@ func feedLoki(logger *log.Entry, n int, title string) error { if err != nil { return err } - resp, err := http.Post("http://127.0.0.1:3100/loki/api/v1/push", "application/json", bytes.NewBuffer(buff)) + req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:3100/loki/api/v1/push", bytes.NewBuffer(buff)) if err != nil { return err } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Scope-OrgID", "1234") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() if resp.StatusCode != http.StatusNoContent { b, _ := io.ReadAll(resp.Body) logger.Error(string(b)) @@ -306,6 +313,8 @@ mode: cat source: loki url: http://127.0.0.1:3100 query: '{server="demo",key="%s"}' +headers: + x-scope-orgid: "1234" since: 1h `, title), }, @@ -362,26 +371,26 @@ func TestStreamingAcquisition(t *testing.T) { }{ { name: "Bad port", - config: ` -mode: tail + config: `mode: tail source: loki -url: http://127.0.0.1:3101 +url: "http://127.0.0.1:3101" +headers: + x-scope-orgid: "1234" query: > - {server="demo"} -`, // No Loki server here + {server="demo"}`, // No Loki server here expectedErr: "", streamErr: `loki is not ready: context deadline exceeded`, expectedLines: 0, }, { name: "ok", - config: ` -mode: tail + config: `mode: tail source: loki -url: http://127.0.0.1:3100 +url: "http://127.0.0.1:3100" +headers: + x-scope-orgid: "1234" query: > - {server="demo"} -`, + {server="demo"}`, expectedErr: "", streamErr: "", expectedLines: 20, @@ -456,6 +465,8 @@ func TestStopStreaming(t *testing.T) { mode: tail source: loki url: http://127.0.0.1:3100 +headers: + x-scope-orgid: "1234" query: > {server="demo"} ` diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index ae7645e1b85..41ee15b4417 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -66,7 +66,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { validCert, extractedCN, err := a.TlsAuth.ValidateCert(c) if !validCert { - logger.Errorf("invalid client certificate: %s", err) + logger.Error(err) return nil } diff --git a/pkg/apiserver/middlewares/v1/tls_auth.go b/pkg/apiserver/middlewares/v1/tls_auth.go index 904f6cd445a..bd2c4bb30e7 100644 --- a/pkg/apiserver/middlewares/v1/tls_auth.go +++ b/pkg/apiserver/middlewares/v1/tls_auth.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto" "crypto/x509" + "encoding/pem" "fmt" "io" "net/http" @@ -19,14 +20,13 @@ import ( type TLSAuth struct { AllowedOUs []string CrlPath string - revokationCache map[string]cacheEntry + revocationCache map[string]cacheEntry cacheExpiration time.Duration logger *log.Entry } type cacheEntry struct { revoked bool - err error timestamp time.Time } @@ -89,10 +89,12 @@ func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool { return false } -func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) { - if cert.OCSPServer == nil || (cert.OCSPServer != nil && len(cert.OCSPServer) == 0) { +// isOCSPRevoked checks if the client certificate is revoked by any of the OCSP servers present in the certificate. +// It returns a boolean indicating if the certificate is revoked and a boolean indicating if the OCSP check was successful and could be cached. +func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, bool) { + if cert.OCSPServer == nil || len(cert.OCSPServer) == 0 { ta.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification") - return false, nil + return false, true } for _, server := range cert.OCSPServer { @@ -104,9 +106,10 @@ func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificat switch ocspResponse.Status { case ocsp.Good: - return false, nil + return false, true case ocsp.Revoked: - return true, fmt.Errorf("client certificate is revoked by server %s", server) + ta.logger.Errorf("TLSAuth: client certificate is revoked by server %s", server) + return true, true case ocsp.Unknown: log.Debugf("unknow OCSP status for server %s", server) continue @@ -115,83 +118,82 @@ func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificat log.Infof("Could not get any valid OCSP response, assuming the cert is revoked") - return true, nil + return true, false } -func (ta *TLSAuth) isCRLRevoked(cert *x509.Certificate) (bool, error) { +// isCRLRevoked checks if the client certificate is revoked by the CRL present in the CrlPath. +// It returns a boolean indicating if the certificate is revoked and a boolean indicating if the CRL check was successful and could be cached. +func (ta *TLSAuth) isCRLRevoked(cert *x509.Certificate) (bool, bool) { if ta.CrlPath == "" { - ta.logger.Warn("no crl_path, skipping CRL check") - return false, nil + ta.logger.Info("no crl_path, skipping CRL check") + return false, true } crlContent, err := os.ReadFile(ta.CrlPath) if err != nil { - ta.logger.Warnf("could not read CRL file, skipping check: %s", err) - return false, nil + ta.logger.Errorf("could not read CRL file, skipping check: %s", err) + return false, false } - crl, err := x509.ParseCRL(crlContent) + crlBinary, rest := pem.Decode(crlContent) + if len(rest) > 0 { + ta.logger.Warn("CRL file contains more than one PEM block, ignoring the rest") + } + + crl, err := x509.ParseRevocationList(crlBinary.Bytes) if err != nil { - ta.logger.Warnf("could not parse CRL file, skipping check: %s", err) - return false, nil + ta.logger.Errorf("could not parse CRL file, skipping check: %s", err) + return false, false } - if crl.HasExpired(time.Now().UTC()) { + now := time.Now().UTC() + + if now.After(crl.NextUpdate) { ta.logger.Warn("CRL has expired, will still validate the cert against it.") } - for _, revoked := range crl.TBSCertList.RevokedCertificates { + if now.Before(crl.ThisUpdate) { + ta.logger.Warn("CRL is not yet valid, will still validate the cert against it.") + } + + for _, revoked := range crl.RevokedCertificateEntries { if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 { - return true, fmt.Errorf("client certificate is revoked by CRL") + ta.logger.Warn("client certificate is revoked by CRL") + return true, true } } - return false, nil + return false, true } func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) { sn := cert.SerialNumber.String() - if cacheValue, ok := ta.revokationCache[sn]; ok { + if cacheValue, ok := ta.revocationCache[sn]; ok { if time.Now().UTC().Sub(cacheValue.timestamp) < ta.cacheExpiration { - ta.logger.Debugf("TLSAuth: using cached value for cert %s: %t | %s", sn, cacheValue.revoked, cacheValue.err) - return cacheValue.revoked, cacheValue.err - } else { - ta.logger.Debugf("TLSAuth: cached value expired, removing from cache") - delete(ta.revokationCache, sn) + ta.logger.Debugf("TLSAuth: using cached value for cert %s: %t", sn, cacheValue.revoked) + return cacheValue.revoked, nil } + + ta.logger.Debugf("TLSAuth: cached value expired, removing from cache") + delete(ta.revocationCache, sn) } else { ta.logger.Tracef("TLSAuth: no cached value for cert %s", sn) } - revoked, err := ta.isOCSPRevoked(cert, issuer) - if err != nil { - ta.revokationCache[sn] = cacheEntry{ - revoked: revoked, - err: err, - timestamp: time.Now().UTC(), - } + revokedByOCSP, cacheOCSP := ta.isOCSPRevoked(cert, issuer) - return true, err - } + revokedByCRL, cacheCRL := ta.isCRLRevoked(cert) - if revoked { - ta.revokationCache[sn] = cacheEntry{ + revoked := revokedByOCSP || revokedByCRL + + if cacheOCSP && cacheCRL { + ta.revocationCache[sn] = cacheEntry{ revoked: revoked, - err: err, timestamp: time.Now().UTC(), } - - return true, nil - } - - revoked, err = ta.isCRLRevoked(cert) - ta.revokationCache[sn] = cacheEntry{ - revoked: revoked, - err: err, - timestamp: time.Now().UTC(), } - return revoked, err + return revoked, nil } func (ta *TLSAuth) isInvalid(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) { @@ -265,11 +267,11 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) { revoked, err := ta.isInvalid(clientCert, c.Request.TLS.VerifiedChains[0][1]) if err != nil { ta.logger.Errorf("TLSAuth: error checking if client certificate is revoked: %s", err) - return false, "", fmt.Errorf("could not check for client certification revokation status: %w", err) + return false, "", fmt.Errorf("could not check for client certification revocation status: %w", err) } if revoked { - return false, "", fmt.Errorf("client certificate is revoked") + return false, "", fmt.Errorf("client certificate for CN=%s OU=%s is revoked", clientCert.Subject.CommonName, clientCert.Subject.OrganizationalUnit) } ta.logger.Debugf("client OU %v is allowed vs required OU %v", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs) @@ -282,7 +284,7 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) { func NewTLSAuth(allowedOus []string, crlPath string, cacheExpiration time.Duration, logger *log.Entry) (*TLSAuth, error) { ta := &TLSAuth{ - revokationCache: map[string]cacheEntry{}, + revocationCache: map[string]cacheEntry{}, cacheExpiration: cacheExpiration, CrlPath: crlPath, logger: logger, diff --git a/pkg/appsec/appsec.go b/pkg/appsec/appsec.go index ec7e7bef3b6..554fc3b7123 100644 --- a/pkg/appsec/appsec.go +++ b/pkg/appsec/appsec.go @@ -2,6 +2,7 @@ package appsec import ( "fmt" + "net/http" "os" "regexp" @@ -30,6 +31,12 @@ const ( hookOnMatch ) +const ( + BanRemediation = "ban" + CaptchaRemediation = "captcha" + AllowRemediation = "allow" +) + func (h *Hook) Build(hookStage int) error { ctx := map[string]interface{}{} @@ -62,12 +69,13 @@ func (h *Hook) Build(hookStage int) error { } type AppsecTempResponse struct { - InBandInterrupt bool - OutOfBandInterrupt bool - Action string //allow, deny, captcha, log - HTTPResponseCode int - SendEvent bool //do we send an internal event on rule match - SendAlert bool //do we send an alert on rule match + InBandInterrupt bool + OutOfBandInterrupt bool + Action string //allow, deny, captcha, log + UserHTTPResponseCode int //The response code to send to the user + BouncerHTTPResponseCode int //The response code to send to the remediation component + SendEvent bool //do we send an internal event on rule match + SendAlert bool //do we send an alert on rule match } type AppsecSubEngineOpts struct { @@ -110,31 +118,33 @@ type AppsecRuntimeConfig struct { } type AppsecConfig struct { - Name string `yaml:"name"` - OutOfBandRules []string `yaml:"outofband_rules"` - InBandRules []string `yaml:"inband_rules"` - DefaultRemediation string `yaml:"default_remediation"` - DefaultPassAction string `yaml:"default_pass_action"` - BlockedHTTPCode int `yaml:"blocked_http_code"` - PassedHTTPCode int `yaml:"passed_http_code"` - OnLoad []Hook `yaml:"on_load"` - PreEval []Hook `yaml:"pre_eval"` - PostEval []Hook `yaml:"post_eval"` - OnMatch []Hook `yaml:"on_match"` - VariablesTracking []string `yaml:"variables_tracking"` - InbandOptions AppsecSubEngineOpts `yaml:"inband_options"` - OutOfBandOptions AppsecSubEngineOpts `yaml:"outofband_options"` + Name string `yaml:"name"` + OutOfBandRules []string `yaml:"outofband_rules"` + InBandRules []string `yaml:"inband_rules"` + DefaultRemediation string `yaml:"default_remediation"` + DefaultPassAction string `yaml:"default_pass_action"` + BouncerBlockedHTTPCode int `yaml:"blocked_http_code"` //returned to the bouncer + BouncerPassedHTTPCode int `yaml:"passed_http_code"` //returned to the bouncer + UserBlockedHTTPCode int `yaml:"user_blocked_http_code"` //returned to the user + UserPassedHTTPCode int `yaml:"user_passed_http_code"` //returned to the user + + OnLoad []Hook `yaml:"on_load"` + PreEval []Hook `yaml:"pre_eval"` + PostEval []Hook `yaml:"post_eval"` + OnMatch []Hook `yaml:"on_match"` + VariablesTracking []string `yaml:"variables_tracking"` + InbandOptions AppsecSubEngineOpts `yaml:"inband_options"` + OutOfBandOptions AppsecSubEngineOpts `yaml:"outofband_options"` LogLevel *log.Level `yaml:"log_level"` Logger *log.Entry `yaml:"-"` } func (w *AppsecRuntimeConfig) ClearResponse() { - w.Logger.Debugf("#-> %p", w) w.Response = AppsecTempResponse{} - w.Logger.Debugf("-> %p", w.Config) w.Response.Action = w.Config.DefaultPassAction - w.Response.HTTPResponseCode = w.Config.PassedHTTPCode + w.Response.BouncerHTTPResponseCode = w.Config.BouncerPassedHTTPCode + w.Response.UserHTTPResponseCode = w.Config.UserPassedHTTPCode w.Response.SendEvent = true w.Response.SendAlert = true } @@ -191,24 +201,35 @@ func (wc *AppsecConfig) GetDataDir() string { func (wc *AppsecConfig) Build() (*AppsecRuntimeConfig, error) { ret := &AppsecRuntimeConfig{Logger: wc.Logger.WithField("component", "appsec_runtime_config")} - //set the defaults - switch wc.DefaultRemediation { - case "": - wc.DefaultRemediation = "ban" - case "ban", "captcha", "log": - //those are the officially supported remediation(s) - default: - wc.Logger.Warningf("default '%s' remediation of %s is none of [ban,captcha,log] ensure bouncer compatbility!", wc.DefaultRemediation, wc.Name) + + if wc.BouncerBlockedHTTPCode == 0 { + wc.BouncerBlockedHTTPCode = http.StatusForbidden + } + if wc.BouncerPassedHTTPCode == 0 { + wc.BouncerPassedHTTPCode = http.StatusOK } - if wc.BlockedHTTPCode == 0 { - wc.BlockedHTTPCode = 403 + + if wc.UserBlockedHTTPCode == 0 { + wc.UserBlockedHTTPCode = http.StatusForbidden } - if wc.PassedHTTPCode == 0 { - wc.PassedHTTPCode = 200 + if wc.UserPassedHTTPCode == 0 { + wc.UserPassedHTTPCode = http.StatusOK } if wc.DefaultPassAction == "" { - wc.DefaultPassAction = "allow" + wc.DefaultPassAction = AllowRemediation } + if wc.DefaultRemediation == "" { + wc.DefaultRemediation = BanRemediation + } + + //set the defaults + switch wc.DefaultRemediation { + case BanRemediation, CaptchaRemediation, AllowRemediation: + //those are the officially supported remediation(s) + default: + wc.Logger.Warningf("default '%s' remediation of %s is none of [%s,%s,%s] ensure bouncer compatbility!", wc.DefaultRemediation, wc.Name, BanRemediation, CaptchaRemediation, AllowRemediation) + } + ret.Name = wc.Name ret.Config = wc ret.DefaultRemediation = wc.DefaultRemediation @@ -553,27 +574,13 @@ func (w *AppsecRuntimeConfig) SetActionByName(name string, action string) error func (w *AppsecRuntimeConfig) SetAction(action string) error { //log.Infof("setting to %s", action) w.Logger.Debugf("setting action to %s", action) - switch action { - case "allow": - w.Response.Action = action - w.Response.HTTPResponseCode = w.Config.PassedHTTPCode - //@tko how should we handle this ? it seems bouncer only understand bans, but it might be misleading ? - case "deny", "ban", "block": - w.Response.Action = "ban" - case "log": - w.Response.Action = action - w.Response.HTTPResponseCode = w.Config.PassedHTTPCode - case "captcha": - w.Response.Action = action - default: - w.Response.Action = action - } + w.Response.Action = action return nil } func (w *AppsecRuntimeConfig) SetHTTPCode(code int) error { w.Logger.Debugf("setting http code to %d", code) - w.Response.HTTPResponseCode = code + w.Response.UserHTTPResponseCode = code return nil } @@ -582,24 +589,23 @@ type BodyResponse struct { HTTPStatus int `json:"http_status"` } -func (w *AppsecRuntimeConfig) GenerateResponse(response AppsecTempResponse, logger *log.Entry) BodyResponse { - resp := BodyResponse{} - //if there is no interrupt, we should allow with default code - if !response.InBandInterrupt { - resp.Action = w.Config.DefaultPassAction - resp.HTTPStatus = w.Config.PassedHTTPCode - return resp - } - resp.Action = response.Action - if resp.Action == "" { - resp.Action = w.Config.DefaultRemediation - } - logger.Debugf("action is %s", resp.Action) +func (w *AppsecRuntimeConfig) GenerateResponse(response AppsecTempResponse, logger *log.Entry) (int, BodyResponse) { + var bouncerStatusCode int - resp.HTTPStatus = response.HTTPResponseCode - if resp.HTTPStatus == 0 { - resp.HTTPStatus = w.Config.BlockedHTTPCode + resp := BodyResponse{Action: response.Action} + if response.Action == AllowRemediation { + resp.HTTPStatus = w.Config.UserPassedHTTPCode + bouncerStatusCode = w.Config.BouncerPassedHTTPCode + } else { //ban, captcha and anything else + resp.HTTPStatus = response.UserHTTPResponseCode + if resp.HTTPStatus == 0 { + resp.HTTPStatus = w.Config.UserBlockedHTTPCode + } + bouncerStatusCode = response.BouncerHTTPResponseCode + if bouncerStatusCode == 0 { + bouncerStatusCode = w.Config.BouncerBlockedHTTPCode + } } - logger.Debugf("http status is %d", resp.HTTPStatus) - return resp + + return bouncerStatusCode, resp } diff --git a/pkg/parser/node.go b/pkg/parser/node.go index 23ed20511c3..fe5432ce938 100644 --- a/pkg/parser/node.go +++ b/pkg/parser/node.go @@ -168,9 +168,9 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri NodesHits.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name}).Inc() } exprErr := error(nil) - isWhitelisted := n.CheckIPsWL(p.ParseIPSources()) + isWhitelisted := n.CheckIPsWL(p) if !isWhitelisted { - isWhitelisted, exprErr = n.CheckExprWL(cachedExprEnv) + isWhitelisted, exprErr = n.CheckExprWL(cachedExprEnv, p) } if exprErr != nil { // Previous code returned nil if there was an error, so we keep this behavior diff --git a/pkg/parser/runtime.go b/pkg/parser/runtime.go index 4f4f6a0f3d0..afdf88dc873 100644 --- a/pkg/parser/runtime.go +++ b/pkg/parser/runtime.go @@ -221,6 +221,24 @@ var NodesHitsKo = prometheus.NewCounterVec( []string{"source", "type", "name"}, ) +// + +var NodesWlHitsOk = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_node_wl_hits_ok_total", + Help: "Total events successfully whitelisted by node.", + }, + []string{"source", "type", "name", "reason"}, +) + +var NodesWlHits = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_node_wl_hits_total", + Help: "Total events processed by whitelist node.", + }, + []string{"source", "type", "name", "reason"}, +) + func stageidx(stage string, stages []string) int { for i, v := range stages { if stage == v { diff --git a/pkg/parser/whitelist.go b/pkg/parser/whitelist.go index 027a9a2858a..f3739a49438 100644 --- a/pkg/parser/whitelist.go +++ b/pkg/parser/whitelist.go @@ -8,6 +8,7 @@ import ( "github.com/antonmedv/expr/vm" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" + "github.com/prometheus/client_golang/prometheus" ) type Whitelist struct { @@ -36,11 +37,13 @@ func (n *Node) ContainsIPLists() bool { return len(n.Whitelist.B_Ips) > 0 || len(n.Whitelist.B_Cidrs) > 0 } -func (n *Node) CheckIPsWL(srcs []net.IP) bool { +func (n *Node) CheckIPsWL(p *types.Event) bool { + srcs := p.ParseIPSources() isWhitelisted := false if !n.ContainsIPLists() { return isWhitelisted } + NodesWlHits.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name, "reason": n.Whitelist.Reason}).Inc() for _, src := range srcs { if isWhitelisted { break @@ -62,15 +65,19 @@ func (n *Node) CheckIPsWL(srcs []net.IP) bool { n.Logger.Tracef("whitelist: %s not in [%s]", src, v) } } + if isWhitelisted { + NodesWlHitsOk.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name, "reason": n.Whitelist.Reason}).Inc() + } return isWhitelisted } -func (n *Node) CheckExprWL(cachedExprEnv map[string]interface{}) (bool, error) { +func (n *Node) CheckExprWL(cachedExprEnv map[string]interface{}, p *types.Event) (bool, error) { isWhitelisted := false if !n.ContainsExprLists() { return false, nil } + NodesWlHits.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name, "reason": n.Whitelist.Reason}).Inc() /* run whitelist expression tests anyway */ for eidx, e := range n.Whitelist.B_Exprs { //if we already know the event is whitelisted, skip the rest of the expressions @@ -94,6 +101,9 @@ func (n *Node) CheckExprWL(cachedExprEnv map[string]interface{}) (bool, error) { n.Logger.Errorf("unexpected type %t (%v) while running '%s'", output, output, n.Whitelist.Exprs[eidx]) } } + if isWhitelisted { + NodesWlHitsOk.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name, "reason": n.Whitelist.Reason}).Inc() + } return isWhitelisted, nil } diff --git a/pkg/parser/whitelist_test.go b/pkg/parser/whitelist_test.go index 8796aaedafe..501c655243d 100644 --- a/pkg/parser/whitelist_test.go +++ b/pkg/parser/whitelist_test.go @@ -289,9 +289,9 @@ func TestWhitelistCheck(t *testing.T) { var err error node.Whitelist = tt.whitelist node.CompileWLs() - isWhitelisted := node.CheckIPsWL(tt.event.ParseIPSources()) + isWhitelisted := node.CheckIPsWL(tt.event) if !isWhitelisted { - isWhitelisted, err = node.CheckExprWL(map[string]interface{}{"evt": tt.event}) + isWhitelisted, err = node.CheckExprWL(map[string]interface{}{"evt": tt.event}, tt.event) } require.NoError(t, err) require.Equal(t, tt.expected, isWhitelisted) diff --git a/test/bats/01_cscli.bats b/test/bats/01_cscli.bats index 3a5b4aad04c..60a65b98d58 100644 --- a/test/bats/01_cscli.bats +++ b/test/bats/01_cscli.bats @@ -273,15 +273,6 @@ teardown() { assert_output 'failed to authenticate to Local API (LAPI): API error: incorrect Username or Password' } -@test "cscli metrics" { - rune -0 ./instance-crowdsec start - rune -0 cscli lapi status - rune -0 cscli metrics - assert_output --partial "Route" - assert_output --partial '/v1/watchers/login' - assert_output --partial "Local API Metrics:" -} - @test "'cscli completion' with or without configuration file" { rune -0 cscli completion bash assert_output --partial "# bash completion for cscli" diff --git a/test/bats/08_metrics.bats b/test/bats/08_metrics.bats index 0275d7fd4a0..8bf30812cff 100644 --- a/test/bats/08_metrics.bats +++ b/test/bats/08_metrics.bats @@ -25,7 +25,7 @@ teardown() { @test "cscli metrics (crowdsec not running)" { rune -1 cscli metrics # crowdsec is down - assert_stderr --partial 'failed to fetch prometheus metrics: executing GET request for URL \"http://127.0.0.1:6060/metrics\" failed: Get \"http://127.0.0.1:6060/metrics\": dial tcp 127.0.0.1:6060: connect: connection refused' + assert_stderr --partial 'failed to fetch metrics: executing GET request for URL \"http://127.0.0.1:6060/metrics\" failed: Get \"http://127.0.0.1:6060/metrics\": dial tcp 127.0.0.1:6060: connect: connection refused' } @test "cscli metrics (bad configuration)" { @@ -59,3 +59,57 @@ teardown() { rune -1 cscli metrics assert_stderr --partial "prometheus is not enabled, can't show metrics" } + +@test "cscli metrics" { + rune -0 ./instance-crowdsec start + rune -0 cscli lapi status + rune -0 cscli metrics + assert_output --partial "Route" + assert_output --partial '/v1/watchers/login' + assert_output --partial "Local API Metrics:" + + rune -0 cscli metrics -o json + rune -0 jq 'keys' <(output) + assert_output --partial '"alerts",' + assert_output --partial '"parsers",' + + rune -0 cscli metrics -o raw + assert_output --partial 'alerts: {}' + assert_output --partial 'parsers: {}' +} + +@test "cscli metrics list" { + rune -0 cscli metrics list + assert_output --regexp "Type.*Title.*Description" + + rune -0 cscli metrics list -o json + rune -0 jq -c '.[] | [.type,.title]' <(output) + assert_line '["acquisition","Acquisition Metrics"]' + + rune -0 cscli metrics list -o raw + assert_line "- type: acquisition" + assert_line " title: Acquisition Metrics" +} + +@test "cscli metrics show" { + rune -0 ./instance-crowdsec start + rune -0 cscli lapi status + + assert_equal "$(cscli metrics)" "$(cscli metrics show)" + + rune -1 cscli metrics show foobar + assert_stderr --partial "unknown metrics type: foobar" + + rune -0 cscli metrics show lapi + assert_output --partial "Local API Metrics:" + assert_output --regexp "Route.*Method.*Hits" + assert_output --regexp "/v1/watchers/login.*POST" + + rune -0 cscli metrics show lapi -o json + rune -0 jq -c '.lapi."/v1/watchers/login" | keys' <(output) + assert_json '["POST"]' + + rune -0 cscli metrics show lapi -o raw + assert_line 'lapi:' + assert_line ' /v1/watchers/login:' +} diff --git a/test/bats/11_bouncers_tls.bats b/test/bats/11_bouncers_tls.bats index 8fb4579259d..2c39aae3079 100644 --- a/test/bats/11_bouncers_tls.bats +++ b/test/bats/11_bouncers_tls.bats @@ -90,7 +90,10 @@ teardown() { } @test "simulate one bouncer request with a revoked certificate" { + truncate_log rune -0 curl -i -s --cert "${tmpdir}/bouncer_revoked.pem" --key "${tmpdir}/bouncer_revoked-key.pem" --cacert "${tmpdir}/bundle.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42 + assert_log --partial "client certificate is revoked by CRL" + assert_log --partial "client certificate for CN=localhost OU=[bouncer-ou] is revoked" assert_output --partial "access forbidden" rune -0 cscli bouncers list -o json assert_output "[]" diff --git a/test/bats/30_machines_tls.bats b/test/bats/30_machines_tls.bats index 535435336ba..311293ca70c 100644 --- a/test/bats/30_machines_tls.bats +++ b/test/bats/30_machines_tls.bats @@ -132,13 +132,15 @@ teardown() { ' config_set "${CONFIG_DIR}/local_api_credentials.yaml" 'del(.login,.password)' ./instance-crowdsec start + rune -1 cscli lapi status rune -0 cscli machines list -o json assert_output '[]' } @test "revoked cert for agent" { + truncate_log config_set "${CONFIG_DIR}/local_api_credentials.yaml" ' - .ca_cert_path=strenv(tmpdir) + "/bundle.pem" | + .ca_cert_path=strenv(tmpdir) + "/bundle.pem" | .key_path=strenv(tmpdir) + "/agent_revoked-key.pem" | .cert_path=strenv(tmpdir) + "/agent_revoked.pem" | .url="https://127.0.0.1:8080" @@ -146,6 +148,9 @@ teardown() { config_set "${CONFIG_DIR}/local_api_credentials.yaml" 'del(.login,.password)' ./instance-crowdsec start + rune -1 cscli lapi status + assert_log --partial "client certificate is revoked by CRL" + assert_log --partial "client certificate for CN=localhost OU=[agent-ou] is revoked" rune -0 cscli machines list -o json assert_output '[]' }