From 58a1d7164f99bacbbe27c0d32b14bf63e27b4274 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Fri, 9 Feb 2024 17:39:50 +0100 Subject: [PATCH 01/20] refact "cscli lapi" (#2825) --- cmd/crowdsec-cli/lapi.go | 271 ++++++++++++++++++++++++--------------- cmd/crowdsec-cli/main.go | 2 +- 2 files changed, 167 insertions(+), 106 deletions(-) diff --git a/cmd/crowdsec-cli/lapi.go b/cmd/crowdsec-cli/lapi.go index ce59ac370cd..0bb4a31b72a 100644 --- a/cmd/crowdsec-cli/lapi.go +++ b/cmd/crowdsec-cli/lapi.go @@ -6,6 +6,7 @@ import ( "fmt" "net/url" "os" + "slices" "sort" "strings" @@ -13,7 +14,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "gopkg.in/yaml.v2" - "slices" "github.com/crowdsecurity/go-cs-lib/version" @@ -29,15 +29,27 @@ import ( const LAPIURLPrefix = "v1" -func runLapiStatus(cmd *cobra.Command, args []string) error { - password := strfmt.Password(csConfig.API.Client.Credentials.Password) - apiurl, err := url.Parse(csConfig.API.Client.Credentials.URL) - login := csConfig.API.Client.Credentials.Login +type cliLapi struct { + cfg configGetter +} + +func NewCLILapi(cfg configGetter) *cliLapi { + return &cliLapi{ + cfg: cfg, + } +} + +func (cli *cliLapi) status() error { + cfg := cli.cfg() + password := strfmt.Password(cfg.API.Client.Credentials.Password) + login := cfg.API.Client.Credentials.Login + + apiurl, err := url.Parse(cfg.API.Client.Credentials.URL) if err != nil { return fmt.Errorf("parsing api url: %w", err) } - hub, err := require.Hub(csConfig, nil, nil) + hub, err := require.Hub(cfg, nil, nil) if err != nil { return err } @@ -54,13 +66,14 @@ func runLapiStatus(cmd *cobra.Command, args []string) error { if err != nil { return fmt.Errorf("init default client: %w", err) } + t := models.WatcherAuthRequest{ MachineID: &login, Password: &password, Scenarios: scenarios, } - log.Infof("Loaded credentials from %s", csConfig.API.Client.CredentialsFilePath) + log.Infof("Loaded credentials from %s", cfg.API.Client.CredentialsFilePath) log.Infof("Trying to authenticate with username %s on %s", login, apiurl) _, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) @@ -69,26 +82,15 @@ func runLapiStatus(cmd *cobra.Command, args []string) error { } log.Infof("You can successfully interact with Local API (LAPI)") + return nil } -func runLapiRegister(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() +func (cli *cliLapi) register(apiURL string, outputFile string, machine string) error { + var err error - apiURL, err := flags.GetString("url") - if err != nil { - return err - } - - outputFile, err := flags.GetString("file") - if err != nil { - return err - } - - lapiUser, err := flags.GetString("machine") - if err != nil { - return err - } + lapiUser := machine + cfg := cli.cfg() if lapiUser == "" { lapiUser, err = generateID("") @@ -96,12 +98,15 @@ func runLapiRegister(cmd *cobra.Command, args []string) error { return fmt.Errorf("unable to generate machine id: %w", err) } } + password := strfmt.Password(generatePassword(passwordLength)) + if apiURL == "" { - if csConfig.API.Client == nil || csConfig.API.Client.Credentials == nil || csConfig.API.Client.Credentials.URL == "" { + if cfg.API.Client == nil || cfg.API.Client.Credentials == nil || cfg.API.Client.Credentials.URL == "" { return fmt.Errorf("no Local API URL. 
Please provide it in your configuration or with the -u parameter") } - apiURL = csConfig.API.Client.Credentials.URL + + apiURL = cfg.API.Client.Credentials.URL } /*URL needs to end with /, but user doesn't care*/ if !strings.HasSuffix(apiURL, "/") { @@ -111,10 +116,12 @@ func runLapiRegister(cmd *cobra.Command, args []string) error { if !strings.HasPrefix(apiURL, "http://") && !strings.HasPrefix(apiURL, "https://") { apiURL = "http://" + apiURL } + apiurl, err := url.Parse(apiURL) if err != nil { return fmt.Errorf("parsing api url: %w", err) } + _, err = apiclient.RegisterClient(&apiclient.Config{ MachineID: lapiUser, Password: password, @@ -130,138 +137,142 @@ func runLapiRegister(cmd *cobra.Command, args []string) error { log.Printf("Successfully registered to Local API (LAPI)") var dumpFile string + if outputFile != "" { dumpFile = outputFile - } else if csConfig.API.Client.CredentialsFilePath != "" { - dumpFile = csConfig.API.Client.CredentialsFilePath + } else if cfg.API.Client.CredentialsFilePath != "" { + dumpFile = cfg.API.Client.CredentialsFilePath } else { dumpFile = "" } + apiCfg := csconfig.ApiCredentialsCfg{ Login: lapiUser, Password: password.String(), URL: apiURL, } + apiConfigDump, err := yaml.Marshal(apiCfg) if err != nil { return fmt.Errorf("unable to marshal api credentials: %w", err) } + if dumpFile != "" { err = os.WriteFile(dumpFile, apiConfigDump, 0o600) if err != nil { return fmt.Errorf("write api credentials to '%s' failed: %w", dumpFile, err) } + log.Printf("Local API credentials written to '%s'", dumpFile) } else { fmt.Printf("%s\n", string(apiConfigDump)) } + log.Warning(ReloadMessage()) return nil } -func NewLapiStatusCmd() *cobra.Command { +func (cli *cliLapi) newStatusCmd() *cobra.Command { cmdLapiStatus := &cobra.Command{ Use: "status", Short: "Check authentication to Local API (LAPI)", Args: cobra.MinimumNArgs(0), DisableAutoGenTag: true, - RunE: runLapiStatus, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.status() + }, } return cmdLapiStatus } -func NewLapiRegisterCmd() *cobra.Command { - cmdLapiRegister := &cobra.Command{ +func (cli *cliLapi) newRegisterCmd() *cobra.Command { + var ( + apiURL string + outputFile string + machine string + ) + + cmd := &cobra.Command{ Use: "register", Short: "Register a machine to Local API (LAPI)", Long: `Register your machine to the Local API (LAPI). Keep in mind the machine needs to be validated by an administrator on LAPI side to be effective.`, Args: cobra.MinimumNArgs(0), DisableAutoGenTag: true, - RunE: runLapiRegister, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.register(apiURL, outputFile, machine) + }, } - flags := cmdLapiRegister.Flags() - flags.StringP("url", "u", "", "URL of the API (ie. http://127.0.0.1)") - flags.StringP("file", "f", "", "output file destination") - flags.String("machine", "", "Name of the machine to register with") + flags := cmd.Flags() + flags.StringVarP(&apiURL, "url", "u", "", "URL of the API (ie. 
http://127.0.0.1)") + flags.StringVarP(&outputFile, "file", "f", "", "output file destination") + flags.StringVar(&machine, "machine", "", "Name of the machine to register with") - return cmdLapiRegister + return cmd } -func NewLapiCmd() *cobra.Command { - cmdLapi := &cobra.Command{ +func (cli *cliLapi) NewCommand() *cobra.Command { + cmd := &cobra.Command{ Use: "lapi [action]", Short: "Manage interaction with Local API (LAPI)", Args: cobra.MinimumNArgs(1), DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadAPIClient(); err != nil { + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + if err := cli.cfg().LoadAPIClient(); err != nil { return fmt.Errorf("loading api client: %w", err) } return nil }, } - cmdLapi.AddCommand(NewLapiRegisterCmd()) - cmdLapi.AddCommand(NewLapiStatusCmd()) - cmdLapi.AddCommand(NewLapiContextCmd()) + cmd.AddCommand(cli.newRegisterCmd()) + cmd.AddCommand(cli.newStatusCmd()) + cmd.AddCommand(cli.newContextCmd()) - return cmdLapi + return cmd } -func AddContext(key string, values []string) error { +func (cli *cliLapi) addContext(key string, values []string) error { + cfg := cli.cfg() + if err := alertcontext.ValidateContextExpr(key, values); err != nil { - return fmt.Errorf("invalid context configuration :%s", err) + return fmt.Errorf("invalid context configuration: %w", err) } - if _, ok := csConfig.Crowdsec.ContextToSend[key]; !ok { - csConfig.Crowdsec.ContextToSend[key] = make([]string, 0) + + if _, ok := cfg.Crowdsec.ContextToSend[key]; !ok { + cfg.Crowdsec.ContextToSend[key] = make([]string, 0) log.Infof("key '%s' added", key) } - data := csConfig.Crowdsec.ContextToSend[key] + + data := cfg.Crowdsec.ContextToSend[key] + for _, val := range values { if !slices.Contains(data, val) { log.Infof("value '%s' added to key '%s'", val, key) data = append(data, val) } - csConfig.Crowdsec.ContextToSend[key] = data + + cfg.Crowdsec.ContextToSend[key] = data } - if err := csConfig.Crowdsec.DumpContextConfigFile(); err != nil { + + if err := cfg.Crowdsec.DumpContextConfigFile(); err != nil { return err } return nil } -func NewLapiContextCmd() *cobra.Command { - cmdContext := &cobra.Command{ - Use: "context [command]", - Short: "Manage context to send with alerts", - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadCrowdsec(); err != nil { - fileNotFoundMessage := fmt.Sprintf("failed to open context file: open %s: no such file or directory", csConfig.Crowdsec.ConsoleContextPath) - if err.Error() != fileNotFoundMessage { - return fmt.Errorf("unable to load CrowdSec agent configuration: %w", err) - } - } - if csConfig.DisableAgent { - return errors.New("agent is disabled and lapi context can only be used on the agent") - } - - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - printHelp(cmd) - }, - } +func (cli *cliLapi) newContextAddCmd() *cobra.Command { + var ( + keyToAdd string + valuesToAdd []string + ) - var keyToAdd string - var valuesToAdd []string - cmdContextAdd := &cobra.Command{ + cmd := &cobra.Command{ Use: "add", Short: "Add context to send with alerts. 
You must specify the output key with the expr value you want", Example: `cscli lapi context add --key source_ip --value evt.Meta.source_ip @@ -269,18 +280,18 @@ cscli lapi context add --key file_source --value evt.Line.Src cscli lapi context add --value evt.Meta.source_ip --value evt.Meta.target_user `, DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - hub, err := require.Hub(csConfig, nil, nil) + RunE: func(_ *cobra.Command, _ []string) error { + hub, err := require.Hub(cli.cfg(), nil, nil) if err != nil { return err } - if err = alertcontext.LoadConsoleContext(csConfig, hub); err != nil { + if err = alertcontext.LoadConsoleContext(cli.cfg(), hub); err != nil { return fmt.Errorf("while loading context: %w", err) } if keyToAdd != "" { - if err := AddContext(keyToAdd, valuesToAdd); err != nil { + if err := cli.addContext(keyToAdd, valuesToAdd); err != nil { return err } return nil @@ -290,7 +301,7 @@ cscli lapi context add --value evt.Meta.source_ip --value evt.Meta.target_user keySlice := strings.Split(v, ".") key := keySlice[len(keySlice)-1] value := []string{v} - if err := AddContext(key, value); err != nil { + if err := cli.addContext(key, value); err != nil { return err } } @@ -298,31 +309,37 @@ cscli lapi context add --value evt.Meta.source_ip --value evt.Meta.target_user return nil }, } - cmdContextAdd.Flags().StringVarP(&keyToAdd, "key", "k", "", "The key of the different values to send") - cmdContextAdd.Flags().StringSliceVar(&valuesToAdd, "value", []string{}, "The expr fields to associate with the key") - cmdContextAdd.MarkFlagRequired("value") - cmdContext.AddCommand(cmdContextAdd) - cmdContextStatus := &cobra.Command{ + flags := cmd.Flags() + flags.StringVarP(&keyToAdd, "key", "k", "", "The key of the different values to send") + flags.StringSliceVar(&valuesToAdd, "value", []string{}, "The expr fields to associate with the key") + cmd.MarkFlagRequired("value") + + return cmd +} + +func (cli *cliLapi) newContextStatusCmd() *cobra.Command { + cmd := &cobra.Command{ Use: "status", Short: "List context to send with alerts", DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - hub, err := require.Hub(csConfig, nil, nil) + RunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + hub, err := require.Hub(cfg, nil, nil) if err != nil { return err } - if err = alertcontext.LoadConsoleContext(csConfig, hub); err != nil { + if err = alertcontext.LoadConsoleContext(cfg, hub); err != nil { return fmt.Errorf("while loading context: %w", err) } - if len(csConfig.Crowdsec.ContextToSend) == 0 { + if len(cfg.Crowdsec.ContextToSend) == 0 { fmt.Println("No context found on this agent. 
You can use 'cscli lapi context add' to add context to your alerts.") return nil } - dump, err := yaml.Marshal(csConfig.Crowdsec.ContextToSend) + dump, err := yaml.Marshal(cfg.Crowdsec.ContextToSend) if err != nil { return fmt.Errorf("unable to show context status: %w", err) } @@ -332,10 +349,14 @@ cscli lapi context add --value evt.Meta.source_ip --value evt.Meta.target_user return nil }, } - cmdContext.AddCommand(cmdContextStatus) + return cmd +} + +func (cli *cliLapi) newContextDetectCmd() *cobra.Command { var detectAll bool - cmdContextDetect := &cobra.Command{ + + cmd := &cobra.Command{ Use: "detect", Short: "Detect available fields from the installed parsers", Example: `cscli lapi context detect --all @@ -343,6 +364,7 @@ cscli lapi context detect crowdsecurity/sshd-logs `, DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, args []string) error { + cfg := cli.cfg() if !detectAll && len(args) == 0 { log.Infof("Please provide parsers to detect or --all flag.") printHelp(cmd) @@ -355,13 +377,13 @@ cscli lapi context detect crowdsecurity/sshd-logs return fmt.Errorf("failed to init expr helpers: %w", err) } - hub, err := require.Hub(csConfig, nil, nil) + hub, err := require.Hub(cfg, nil, nil) if err != nil { return err } csParsers := parser.NewParsers(hub) - if csParsers, err = parser.LoadParsers(csConfig, csParsers); err != nil { + if csParsers, err = parser.LoadParsers(cfg, csParsers); err != nil { return fmt.Errorf("unable to load parsers: %w", err) } @@ -418,47 +440,85 @@ cscli lapi context detect crowdsecurity/sshd-logs return nil }, } - cmdContextDetect.Flags().BoolVarP(&detectAll, "all", "a", false, "Detect evt field for all installed parser") - cmdContext.AddCommand(cmdContextDetect) + cmd.Flags().BoolVarP(&detectAll, "all", "a", false, "Detect evt field for all installed parser") + + return cmd +} - cmdContextDelete := &cobra.Command{ +func (cli *cliLapi) newContextDeleteCmd() *cobra.Command { + cmd := &cobra.Command{ Use: "delete", DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { - filePath := csConfig.Crowdsec.ConsoleContextPath + filePath := cli.cfg().Crowdsec.ConsoleContextPath if filePath == "" { filePath = "the context file" } - fmt.Printf("Command \"delete\" is deprecated, please manually edit %s.", filePath) + fmt.Printf("Command 'delete' is deprecated, please manually edit %s.", filePath) + + return nil + }, + } + + return cmd +} + +func (cli *cliLapi) newContextCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "context [command]", + Short: "Manage context to send with alerts", + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := cfg.LoadCrowdsec(); err != nil { + fileNotFoundMessage := fmt.Sprintf("failed to open context file: open %s: no such file or directory", cfg.Crowdsec.ConsoleContextPath) + if err.Error() != fileNotFoundMessage { + return fmt.Errorf("unable to load CrowdSec agent configuration: %w", err) + } + } + if cfg.DisableAgent { + return errors.New("agent is disabled and lapi context can only be used on the agent") + } + return nil }, + Run: func(cmd *cobra.Command, _ []string) { + printHelp(cmd) + }, } - cmdContext.AddCommand(cmdContextDelete) - return cmdContext + cmd.AddCommand(cli.newContextAddCmd()) + cmd.AddCommand(cli.newContextStatusCmd()) + cmd.AddCommand(cli.newContextDetectCmd()) + cmd.AddCommand(cli.newContextDeleteCmd()) + + return cmd } -func detectStaticField(GrokStatics []parser.ExtraField) []string { +func detectStaticField(grokStatics 
[]parser.ExtraField) []string { ret := make([]string, 0) - for _, static := range GrokStatics { + for _, static := range grokStatics { if static.Parsed != "" { fieldName := fmt.Sprintf("evt.Parsed.%s", static.Parsed) if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } } + if static.Meta != "" { fieldName := fmt.Sprintf("evt.Meta.%s", static.Meta) if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } } + if static.TargetByName != "" { fieldName := static.TargetByName if !strings.HasPrefix(fieldName, "evt.") { fieldName = "evt." + fieldName } + if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } @@ -526,6 +586,7 @@ func detectSubNode(node parser.Node, parserCTX parser.UnixParserCtx) []string { } } } + if subnode.Grok.RegexpName != "" { grokCompiled, err := parserCTX.Grok.Get(subnode.Grok.RegexpName) if err == nil { diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index 62b85e63047..b0855fb047e 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -241,7 +241,7 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall cmd.AddCommand(NewCLIBouncers(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIMachines(cli.cfg).NewCommand()) cmd.AddCommand(NewCLICapi().NewCommand()) - cmd.AddCommand(NewLapiCmd()) + cmd.AddCommand(NewCLILapi(cli.cfg).NewCommand()) cmd.AddCommand(NewCompletionCmd()) cmd.AddCommand(NewConsoleCmd()) cmd.AddCommand(NewCLIExplain().NewCommand()) From 2853410576456471b3f0efef223c4f7bb04600ab Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Fri, 9 Feb 2024 17:51:29 +0100 Subject: [PATCH 02/20] refact "cscli alerts" (#2827) --- cmd/crowdsec-cli/alerts.go | 181 ++++++++++++++++++++++--------------- cmd/crowdsec-cli/main.go | 2 +- 2 files changed, 109 insertions(+), 74 deletions(-) diff --git a/cmd/crowdsec-cli/alerts.go b/cmd/crowdsec-cli/alerts.go index 4ab71be5bbf..ce304bcc777 100644 --- a/cmd/crowdsec-cli/alerts.go +++ b/cmd/crowdsec-cli/alerts.go @@ -29,39 +29,46 @@ import ( func DecisionsFromAlert(alert *models.Alert) string { ret := "" - var decMap = make(map[string]int) + decMap := make(map[string]int) + for _, decision := range alert.Decisions { k := *decision.Type if *decision.Simulated { k = fmt.Sprintf("(simul)%s", k) } + v := decMap[k] decMap[k] = v + 1 } + for k, v := range decMap { if len(ret) > 0 { ret += " " } + ret += fmt.Sprintf("%s:%d", k, v) } + return ret } -func alertsToTable(alerts *models.GetAlertsResponse, printMachine bool) error { - switch csConfig.Cscli.Output { +func (cli *cliAlerts) alertsToTable(alerts *models.GetAlertsResponse, printMachine bool) error { + switch cli.cfg().Cscli.Output { case "raw": csvwriter := csv.NewWriter(os.Stdout) header := []string{"id", "scope", "value", "reason", "country", "as", "decisions", "created_at"} + if printMachine { header = append(header, "machine") } - err := csvwriter.Write(header) - if err != nil { + + if err := csvwriter.Write(header); err != nil { return err } + for _, alertItem := range *alerts { row := []string{ - fmt.Sprintf("%d", alertItem.ID), + strconv.FormatInt(alertItem.ID, 10), *alertItem.Source.Scope, *alertItem.Source.Value, *alertItem.Scenario, @@ -73,11 +80,12 @@ func alertsToTable(alerts *models.GetAlertsResponse, printMachine bool) error { if printMachine { row = append(row, alertItem.MachineID) } - err := csvwriter.Write(row) - if err != nil { + + if err := csvwriter.Write(row); err != nil { return err } } + csvwriter.Flush() case "json": if *alerts == nil { 
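
The refactor applied across these patches follows one recurring shape: package-level command builders that read the global csConfig become methods on a small per-command struct holding a configGetter, and each cobra RunE hook becomes a thin closure over such a method. The sketch below is an editorial illustration of that shape, not part of any patch; Config and the "thing" command are made-up stand-ins, and configGetter is assumed, from the cli.cfg() calls visible in the diffs, to be a function returning the loaded configuration.

    package main

    import (
    	"fmt"

    	"github.com/spf13/cobra"
    )

    // Config is a placeholder for the real *csconfig.Config.
    type Config struct {
    	OutputFormat string
    }

    // configGetter hands the configuration to a command only when asked for it.
    type configGetter func() *Config

    type cliThing struct {
    	cfg configGetter
    }

    func NewCLIThing(cfg configGetter) *cliThing {
    	return &cliThing{cfg: cfg}
    }

    // status is a method instead of a package function: it reaches the
    // configuration through cli.cfg() rather than through a global variable.
    func (cli *cliThing) status() error {
    	cfg := cli.cfg()
    	fmt.Printf("output format: %s\n", cfg.OutputFormat)
    	return nil
    }

    func (cli *cliThing) NewCommand() *cobra.Command {
    	return &cobra.Command{
    		Use:               "thing",
    		Short:             "example command built with the configGetter pattern",
    		DisableAutoGenTag: true,
    		RunE: func(_ *cobra.Command, _ []string) error {
    			return cli.status()
    		},
    	}
    }

    func main() {
    	cfg := &Config{OutputFormat: "human"}
    	getter := func() *Config { return cfg }

    	root := &cobra.Command{Use: "example"}
    	root.AddCommand(NewCLIThing(getter).NewCommand())

    	if err := root.Execute(); err != nil {
    		fmt.Println(err)
    	}
    }

The practical effect, visible throughout the diffs, is that handlers no longer touch the csConfig global, which should make each cli* type self-contained and simpler to wire up from main.go (cmd.AddCommand(NewCLILapi(cli.cfg).NewCommand()) and friends).
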
@@ -86,6 +94,7 @@ func alertsToTable(alerts *models.GetAlertsResponse, printMachine bool) error { fmt.Println("[]") return nil } + x, _ := json.MarshalIndent(alerts, "", " ") fmt.Print(string(x)) case "human": @@ -93,8 +102,10 @@ func alertsToTable(alerts *models.GetAlertsResponse, printMachine bool) error { fmt.Println("No active alerts") return nil } + alertsTable(color.Output, alerts, printMachine) } + return nil } @@ -116,13 +127,13 @@ var alertTemplate = ` ` -func displayOneAlert(alert *models.Alert, withDetail bool) error { +func (cli *cliAlerts) displayOneAlert(alert *models.Alert, withDetail bool) error { tmpl, err := template.New("alert").Parse(alertTemplate) if err != nil { return err } - err = tmpl.Execute(os.Stdout, alert) - if err != nil { + + if err = tmpl.Execute(os.Stdout, alert); err != nil { return err } @@ -133,14 +144,17 @@ func displayOneAlert(alert *models.Alert, withDetail bool) error { sort.Slice(alert.Meta, func(i, j int) bool { return alert.Meta[i].Key < alert.Meta[j].Key }) + table := newTable(color.Output) table.SetRowLines(false) table.SetHeaders("Key", "Value") + for _, meta := range alert.Meta { var valSlice []string if err := json.Unmarshal([]byte(meta.Value), &valSlice); err != nil { - return fmt.Errorf("unknown context value type '%s' : %s", meta.Value, err) + return fmt.Errorf("unknown context value type '%s': %w", meta.Value, err) } + for _, value := range valSlice { table.AddRow( meta.Key, @@ -148,11 +162,13 @@ func displayOneAlert(alert *models.Alert, withDetail bool) error { ) } } + table.Render() } if withDetail { fmt.Printf("\n - Events :\n") + for _, event := range alert.Events { alertEventTable(color.Output, event) } @@ -163,10 +179,13 @@ func displayOneAlert(alert *models.Alert, withDetail bool) error { type cliAlerts struct{ client *apiclient.ApiClient + cfg configGetter } -func NewCLIAlerts() *cliAlerts { - return &cliAlerts{} +func NewCLIAlerts(getconfig configGetter) *cliAlerts { + return &cliAlerts{ + cfg: getconfig, + } } func (cli *cliAlerts) NewCommand() *cobra.Command { @@ -176,18 +195,18 @@ func (cli *cliAlerts) NewCommand() *cobra.Command { Args: cobra.MinimumNArgs(1), DisableAutoGenTag: true, Aliases: []string{"alert"}, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - var err error - if err := csConfig.LoadAPIClient(); err != nil { + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := cfg.LoadAPIClient(); err != nil { return fmt.Errorf("loading api client: %w", err) } - apiURL, err := url.Parse(csConfig.API.Client.Credentials.URL) + apiURL, err := url.Parse(cfg.API.Client.Credentials.URL) if err != nil { return fmt.Errorf("parsing api url %s: %w", apiURL, err) } cli.client, err = apiclient.NewClient(&apiclient.Config{ - MachineID: csConfig.API.Client.Credentials.Login, - Password: strfmt.Password(csConfig.API.Client.Credentials.Password), + MachineID: cfg.API.Client.Credentials.Login, + Password: strfmt.Password(cfg.API.Client.Credentials.Password), UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", @@ -196,6 +215,7 @@ func (cli *cliAlerts) NewCommand() *cobra.Command { if err != nil { return fmt.Errorf("new api client: %w", err) } + return nil }, } @@ -221,8 +241,10 @@ func (cli *cliAlerts) NewListCmd() *cobra.Command { IncludeCAPI: new(bool), OriginEquals: new(string), } + limit := new(int) contained := new(bool) + var printMachine bool cmd := &cobra.Command{ @@ -234,9 +256,7 @@ cscli alerts list --range 1.2.3.0/24 cscli alerts 
list -s crowdsecurity/ssh-bf cscli alerts list --type ban`, DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - var err error - + RunE: func(cmd *cobra.Command, _ []string) error { if err := manageCliDecisionAlerts(alertListFilter.IPEquals, alertListFilter.RangeEquals, alertListFilter.ScopeEquals, alertListFilter.ValueEquals); err != nil { printHelp(cmd) @@ -304,40 +324,43 @@ cscli alerts list --type ban`, alerts, _, err := cli.client.Alerts.List(context.Background(), alertListFilter) if err != nil { - return fmt.Errorf("unable to list alerts: %v", err) + return fmt.Errorf("unable to list alerts: %w", err) } - err = alertsToTable(alerts, printMachine) - if err != nil { - return fmt.Errorf("unable to list alerts: %v", err) + if err = cli.alertsToTable(alerts, printMachine); err != nil { + return fmt.Errorf("unable to list alerts: %w", err) } return nil }, } - cmd.Flags().SortFlags = false - cmd.Flags().BoolVarP(alertListFilter.IncludeCAPI, "all", "a", false, "Include decisions from Central API") - cmd.Flags().StringVar(alertListFilter.Until, "until", "", "restrict to alerts older than until (ie. 4h, 30d)") - cmd.Flags().StringVar(alertListFilter.Since, "since", "", "restrict to alerts newer than since (ie. 4h, 30d)") - cmd.Flags().StringVarP(alertListFilter.IPEquals, "ip", "i", "", "restrict to alerts from this source ip (shorthand for --scope ip --value )") - cmd.Flags().StringVarP(alertListFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)") - cmd.Flags().StringVarP(alertListFilter.RangeEquals, "range", "r", "", "restrict to alerts from this range (shorthand for --scope range --value )") - cmd.Flags().StringVar(alertListFilter.TypeEquals, "type", "", "restrict to alerts with given decision type (ie. ban, captcha)") - cmd.Flags().StringVar(alertListFilter.ScopeEquals, "scope", "", "restrict to alerts of this scope (ie. ip,range)") - cmd.Flags().StringVarP(alertListFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") - cmd.Flags().StringVar(alertListFilter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ","))) - cmd.Flags().BoolVar(contained, "contained", false, "query decisions contained by range") - cmd.Flags().BoolVarP(&printMachine, "machine", "m", false, "print machines that sent alerts") - cmd.Flags().IntVarP(limit, "limit", "l", 50, "limit size of alerts list table (0 to view all alerts)") + + flags := cmd.Flags() + flags.SortFlags = false + flags.BoolVarP(alertListFilter.IncludeCAPI, "all", "a", false, "Include decisions from Central API") + flags.StringVar(alertListFilter.Until, "until", "", "restrict to alerts older than until (ie. 4h, 30d)") + flags.StringVar(alertListFilter.Since, "since", "", "restrict to alerts newer than since (ie. 4h, 30d)") + flags.StringVarP(alertListFilter.IPEquals, "ip", "i", "", "restrict to alerts from this source ip (shorthand for --scope ip --value )") + flags.StringVarP(alertListFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)") + flags.StringVarP(alertListFilter.RangeEquals, "range", "r", "", "restrict to alerts from this range (shorthand for --scope range --value )") + flags.StringVar(alertListFilter.TypeEquals, "type", "", "restrict to alerts with given decision type (ie. ban, captcha)") + flags.StringVar(alertListFilter.ScopeEquals, "scope", "", "restrict to alerts of this scope (ie. 
ip,range)") + flags.StringVarP(alertListFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") + flags.StringVar(alertListFilter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ","))) + flags.BoolVar(contained, "contained", false, "query decisions contained by range") + flags.BoolVarP(&printMachine, "machine", "m", false, "print machines that sent alerts") + flags.IntVarP(limit, "limit", "l", 50, "limit size of alerts list table (0 to view all alerts)") return cmd } func (cli *cliAlerts) NewDeleteCmd() *cobra.Command { - var ActiveDecision *bool - var AlertDeleteAll bool - var delAlertByID string - contained := new(bool) + var ( + ActiveDecision *bool + AlertDeleteAll bool + delAlertByID string + ) + var alertDeleteFilter = apiclient.AlertsDeleteOpts{ ScopeEquals: new(string), ValueEquals: new(string), @@ -345,6 +368,9 @@ func (cli *cliAlerts) NewDeleteCmd() *cobra.Command { IPEquals: new(string), RangeEquals: new(string), } + + contained := new(bool) + cmd := &cobra.Command{ Use: "delete [filters] [--all]", Short: `Delete alerts @@ -355,7 +381,7 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`, DisableAutoGenTag: true, Aliases: []string{"remove"}, Args: cobra.ExactArgs(0), - PreRunE: func(cmd *cobra.Command, args []string) error { + PreRunE: func(cmd *cobra.Command, _ []string) error { if AlertDeleteAll { return nil } @@ -368,11 +394,11 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`, return nil }, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(cmd *cobra.Command, _ []string) error { var err error if !AlertDeleteAll { - if err := manageCliDecisionAlerts(alertDeleteFilter.IPEquals, alertDeleteFilter.RangeEquals, + if err = manageCliDecisionAlerts(alertDeleteFilter.IPEquals, alertDeleteFilter.RangeEquals, alertDeleteFilter.ScopeEquals, alertDeleteFilter.ValueEquals); err != nil { printHelp(cmd) return err @@ -410,12 +436,12 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`, if delAlertByID == "" { alerts, _, err = cli.client.Alerts.Delete(context.Background(), alertDeleteFilter) if err != nil { - return fmt.Errorf("unable to delete alerts : %v", err) + return fmt.Errorf("unable to delete alerts: %w", err) } } else { alerts, _, err = cli.client.Alerts.DeleteOne(context.Background(), delAlertByID) if err != nil { - return fmt.Errorf("unable to delete alert: %v", err) + return fmt.Errorf("unable to delete alert: %w", err) } } log.Infof("%s alert(s) deleted", alerts.NbDeleted) @@ -423,26 +449,31 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`, return nil }, } - cmd.Flags().SortFlags = false - cmd.Flags().StringVar(alertDeleteFilter.ScopeEquals, "scope", "", "the scope (ie. ip,range)") - cmd.Flags().StringVarP(alertDeleteFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") - cmd.Flags().StringVarP(alertDeleteFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. 
crowdsecurity/ssh-bf)") - cmd.Flags().StringVarP(alertDeleteFilter.IPEquals, "ip", "i", "", "Source ip (shorthand for --scope ip --value )") - cmd.Flags().StringVarP(alertDeleteFilter.RangeEquals, "range", "r", "", "Range source ip (shorthand for --scope range --value )") - cmd.Flags().StringVar(&delAlertByID, "id", "", "alert ID") - cmd.Flags().BoolVarP(&AlertDeleteAll, "all", "a", false, "delete all alerts") - cmd.Flags().BoolVar(contained, "contained", false, "query decisions contained by range") + + flags := cmd.Flags() + flags.SortFlags = false + flags.StringVar(alertDeleteFilter.ScopeEquals, "scope", "", "the scope (ie. ip,range)") + flags.StringVarP(alertDeleteFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") + flags.StringVarP(alertDeleteFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)") + flags.StringVarP(alertDeleteFilter.IPEquals, "ip", "i", "", "Source ip (shorthand for --scope ip --value )") + flags.StringVarP(alertDeleteFilter.RangeEquals, "range", "r", "", "Range source ip (shorthand for --scope range --value )") + flags.StringVar(&delAlertByID, "id", "", "alert ID") + flags.BoolVarP(&AlertDeleteAll, "all", "a", false, "delete all alerts") + flags.BoolVar(contained, "contained", false, "query decisions contained by range") + return cmd } func (cli *cliAlerts) NewInspectCmd() *cobra.Command { var details bool + cmd := &cobra.Command{ Use: `inspect "alert_id"`, Short: `Show info about an alert`, Example: `cscli alerts inspect 123`, DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, args []string) error { + cfg := cli.cfg() if len(args) == 0 { printHelp(cmd) return fmt.Errorf("missing alert_id") @@ -454,31 +485,32 @@ func (cli *cliAlerts) NewInspectCmd() *cobra.Command { } alert, _, err := cli.client.Alerts.GetByID(context.Background(), id) if err != nil { - return fmt.Errorf("can't find alert with id %s: %s", alertID, err) + return fmt.Errorf("can't find alert with id %s: %w", alertID, err) } - switch csConfig.Cscli.Output { + switch cfg.Cscli.Output { case "human": - if err := displayOneAlert(alert, details); err != nil { + if err := cli.displayOneAlert(alert, details); err != nil { continue } case "json": data, err := json.MarshalIndent(alert, "", " ") if err != nil { - return fmt.Errorf("unable to marshal alert with id %s: %s", alertID, err) + return fmt.Errorf("unable to marshal alert with id %s: %w", alertID, err) } fmt.Printf("%s\n", string(data)) case "raw": data, err := yaml.Marshal(alert) if err != nil { - return fmt.Errorf("unable to marshal alert with id %s: %s", alertID, err) + return fmt.Errorf("unable to marshal alert with id %s: %w", alertID, err) } - fmt.Printf("%s\n", string(data)) + fmt.Println(string(data)) } } return nil }, } + cmd.Flags().SortFlags = false cmd.Flags().BoolVarP(&details, "details", "d", false, "show alerts with events") @@ -486,27 +518,30 @@ func (cli *cliAlerts) NewInspectCmd() *cobra.Command { } func (cli *cliAlerts) NewFlushCmd() *cobra.Command { - var maxItems int - var maxAge string + var ( + maxItems int + maxAge string + ) + cmd := &cobra.Command{ Use: `flush`, Short: `Flush alerts /!\ This command can be used only on the same machine than the local API`, Example: `cscli alerts flush --max-items 1000 --max-age 7d`, DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - var err error - if err := require.LAPI(csConfig); err != nil { + RunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := require.LAPI(cfg); 
err != nil { return err } - db, err := database.NewClient(csConfig.DbConfig) + db, err := database.NewClient(cfg.DbConfig) if err != nil { - return fmt.Errorf("unable to create new database client: %s", err) + return fmt.Errorf("unable to create new database client: %w", err) } log.Info("Flushing alerts. !! This may take a long time !!") err = db.FlushAlerts(maxAge, maxItems) if err != nil { - return fmt.Errorf("unable to flush alerts: %s", err) + return fmt.Errorf("unable to flush alerts: %w", err) } log.Info("Alerts flushed") diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index b0855fb047e..55fcacee39c 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -236,7 +236,7 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall cmd.AddCommand(NewCLIMetrics(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIDashboard(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIDecisions(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIAlerts().NewCommand()) + cmd.AddCommand(NewCLIAlerts(cli.cfg).NewCommand()) cmd.AddCommand(NewCLISimulation(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIBouncers(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIMachines(cli.cfg).NewCommand()) From 5c83695177cd4044a8cc953978103377b63607f0 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Mon, 12 Feb 2024 11:23:17 +0100 Subject: [PATCH 03/20] refact "cscli explain" (#2835) --- cmd/crowdsec-cli/explain.go | 184 +++++++++++++++--------------------- cmd/crowdsec-cli/main.go | 2 +- go.mod | 4 +- go.sum | 14 +-- 4 files changed, 82 insertions(+), 122 deletions(-) diff --git a/cmd/crowdsec-cli/explain.go b/cmd/crowdsec-cli/explain.go index d21c1704930..ce323fd0ce1 100644 --- a/cmd/crowdsec-cli/explain.go +++ b/cmd/crowdsec-cli/explain.go @@ -16,33 +16,53 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/hubtest" ) -func GetLineCountForFile(filepath string) (int, error) { +func getLineCountForFile(filepath string) (int, error) { f, err := os.Open(filepath) if err != nil { return 0, err } defer f.Close() + lc := 0 fs := bufio.NewReader(f) + for { input, err := fs.ReadBytes('\n') if len(input) > 1 { lc++ } + if err != nil && err == io.EOF { break } } + return lc, nil } -type cliExplain struct{} +type cliExplain struct { + cfg configGetter + flags struct { + logFile string + dsn string + logLine string + logType string + details bool + skipOk bool + onlySuccessfulParsers bool + noClean bool + crowdsec string + labels string + } +} -func NewCLIExplain() *cliExplain { - return &cliExplain{} +func NewCLIExplain(cfg configGetter) *cliExplain { + return &cliExplain{ + cfg: cfg, + } } -func (cli cliExplain) NewCommand() *cobra.Command { +func (cli *cliExplain) NewCommand() *cobra.Command { cmd := &cobra.Command{ Use: "explain", Short: "Explain log pipeline", @@ -57,118 +77,50 @@ tail -n 5 myfile.log | cscli explain --type nginx -f - `, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: cli.run, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - logFile, err := flags.GetString("file") - if err != nil { - return err - } - - dsn, err := flags.GetString("dsn") - if err != nil { - return err - } - - logLine, err := flags.GetString("log") - if err != nil { - return err - } - - logType, err := flags.GetString("type") - if err != nil { - return err - } - - if logLine == "" && logFile == "" && dsn == "" { - printHelp(cmd) - fmt.Println() - return fmt.Errorf("please provide --log, --file or --dsn flag") - } - if logType 
== "" { - printHelp(cmd) - fmt.Println() - return fmt.Errorf("please provide --type flag") - } + RunE: func(_ *cobra.Command, _ []string) error { + return cli.run() + }, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { fileInfo, _ := os.Stdin.Stat() - if logFile == "-" && ((fileInfo.Mode() & os.ModeCharDevice) == os.ModeCharDevice) { + if cli.flags.logFile == "-" && ((fileInfo.Mode() & os.ModeCharDevice) == os.ModeCharDevice) { return fmt.Errorf("the option -f - is intended to work with pipes") } + return nil }, } flags := cmd.Flags() - flags.StringP("file", "f", "", "Log file to test") - flags.StringP("dsn", "d", "", "DSN to test") - flags.StringP("log", "l", "", "Log line to test") - flags.StringP("type", "t", "", "Type of the acquisition to test") - flags.String("labels", "", "Additional labels to add to the acquisition format (key:value,key2:value2)") - flags.BoolP("verbose", "v", false, "Display individual changes") - flags.Bool("failures", false, "Only show failed lines") - flags.Bool("only-successful-parsers", false, "Only show successful parsers") - flags.String("crowdsec", "crowdsec", "Path to crowdsec") - flags.Bool("no-clean", false, "Don't clean runtime environment after tests") + flags.StringVarP(&cli.flags.logFile, "file", "f", "", "Log file to test") + flags.StringVarP(&cli.flags.dsn, "dsn", "d", "", "DSN to test") + flags.StringVarP(&cli.flags.logLine, "log", "l", "", "Log line to test") + flags.StringVarP(&cli.flags.logType, "type", "t", "", "Type of the acquisition to test") + flags.StringVar(&cli.flags.labels, "labels", "", "Additional labels to add to the acquisition format (key:value,key2:value2)") + flags.BoolVarP(&cli.flags.details, "verbose", "v", false, "Display individual changes") + flags.BoolVar(&cli.flags.skipOk, "failures", false, "Only show failed lines") + flags.BoolVar(&cli.flags.onlySuccessfulParsers, "only-successful-parsers", false, "Only show successful parsers") + flags.StringVar(&cli.flags.crowdsec, "crowdsec", "crowdsec", "Path to crowdsec") + flags.BoolVar(&cli.flags.noClean, "no-clean", false, "Don't clean runtime environment after tests") + + cmd.MarkFlagRequired("type") + cmd.MarkFlagsOneRequired("log", "file", "dsn") return cmd } -func (cli cliExplain) run(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - logFile, err := flags.GetString("file") - if err != nil { - return err - } +func (cli *cliExplain) run() error { + logFile := cli.flags.logFile + logLine := cli.flags.logLine + logType := cli.flags.logType + dsn := cli.flags.dsn + labels := cli.flags.labels + crowdsec := cli.flags.crowdsec - dsn, err := flags.GetString("dsn") - if err != nil { - return err - } - - logLine, err := flags.GetString("log") - if err != nil { - return err - } - - logType, err := flags.GetString("type") - if err != nil { - return err - } - - opts := dumps.DumpOpts{} - - opts.Details, err = flags.GetBool("verbose") - if err != nil { - return err - } - - no_clean, err := flags.GetBool("no-clean") - if err != nil { - return err - } - - opts.SkipOk, err = flags.GetBool("failures") - if err != nil { - return err - } - - opts.ShowNotOkParsers, err = flags.GetBool("only-successful-parsers") - opts.ShowNotOkParsers = !opts.ShowNotOkParsers - if err != nil { - return err - } - - crowdsec, err := flags.GetString("crowdsec") - if err != nil { - return err - } - - labels, err := flags.GetString("labels") - if err != nil { - return err + opts := dumps.DumpOpts{ + Details: cli.flags.details, + SkipOk: cli.flags.skipOk, + ShowNotOkParsers: 
!cli.flags.onlySuccessfulParsers, } var f *os.File @@ -176,21 +128,25 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error { // using empty string fallback to /tmp dir, err := os.MkdirTemp("", "cscli_explain") if err != nil { - return fmt.Errorf("couldn't create a temporary directory to store cscli explain result: %s", err) + return fmt.Errorf("couldn't create a temporary directory to store cscli explain result: %w", err) } + defer func() { - if no_clean { + if cli.flags.noClean { return } + if _, err := os.Stat(dir); !os.IsNotExist(err) { if err := os.RemoveAll(dir); err != nil { log.Errorf("unable to delete temporary directory '%s': %s", dir, err) } } }() + // we create a temporary log file if a log line/stdin has been provided if logLine != "" || logFile == "-" { tmpFile := filepath.Join(dir, "cscli_test_tmp.log") + f, err = os.Create(tmpFile) if err != nil { return err @@ -220,6 +176,7 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error { log.Warnf("Failed to write %d lines to %s", errCount, tmpFile) } } + f.Close() // this is the file that was going to be read by crowdsec anyway logFile = tmpFile @@ -230,15 +187,20 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error { if err != nil { return fmt.Errorf("unable to get absolute path of '%s', exiting", logFile) } + dsn = fmt.Sprintf("file://%s", absolutePath) - lineCount, err := GetLineCountForFile(absolutePath) + + lineCount, err := getLineCountForFile(absolutePath) if err != nil { return err } + log.Debugf("file %s has %d lines", absolutePath, lineCount) + if lineCount == 0 { return fmt.Errorf("the log file is empty: %s", absolutePath) } + if lineCount > 100 { log.Warnf("%s contains %d lines. This may take a lot of resources.", absolutePath, lineCount) } @@ -249,15 +211,19 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error { } cmdArgs := []string{"-c", ConfigFilePath, "-type", logType, "-dsn", dsn, "-dump-data", dir, "-no-api"} + if labels != "" { log.Debugf("adding labels %s", labels) cmdArgs = append(cmdArgs, "-label", labels) } + crowdsecCmd := exec.Command(crowdsec, cmdArgs...) 
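
Patch 03 also changes how cscli explain reads its options: instead of calling flags.GetString for every flag inside the handler and validating the combinations by hand, the values are bound to struct fields with StringVarP and the presence rules are declared with MarkFlagRequired and MarkFlagsOneRequired (the latter lines up with the go.mod bump to spf13/cobra v1.8.0 in the same patch). The standalone sketch below shows that binding style with a hypothetical explain-like command; the flag names mirror the ones above, everything else is illustrative.

    package main

    import (
    	"fmt"

    	"github.com/spf13/cobra"
    )

    func newExplainLikeCmd() *cobra.Command {
    	// Flag values live in ordinary variables bound at definition time,
    	// so RunE never has to call flags.GetString and check its error.
    	var (
    		logFile string
    		logLine string
    		dsn     string
    		logType string
    	)

    	cmd := &cobra.Command{
    		Use: "explain-like",
    		RunE: func(_ *cobra.Command, _ []string) error {
    			fmt.Printf("type=%s file=%s line=%s dsn=%s\n", logType, logFile, logLine, dsn)
    			return nil
    		},
    	}

    	flags := cmd.Flags()
    	flags.StringVarP(&logFile, "file", "f", "", "log file to test")
    	flags.StringVarP(&logLine, "log", "l", "", "log line to test")
    	flags.StringVarP(&dsn, "dsn", "d", "", "DSN to test")
    	flags.StringVarP(&logType, "type", "t", "", "type of the acquisition to test")

    	// Declarative constraints replace the hand-rolled "please provide
    	// --log, --file or --dsn" and "please provide --type" checks.
    	_ = cmd.MarkFlagRequired("type")
    	cmd.MarkFlagsOneRequired("log", "file", "dsn")

    	return cmd
    }

    func main() {
    	if err := newExplainLikeCmd().Execute(); err != nil {
    		fmt.Println(err)
    	}
    }

With the constraints declared up front, cobra rejects a missing --type, or a call without any of --log, --file, --dsn, before RunE runs, so the handler body stays free of flag plumbing.
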
+ output, err := crowdsecCmd.CombinedOutput() if err != nil { fmt.Println(string(output)) - return fmt.Errorf("fail to run crowdsec for test: %v", err) + + return fmt.Errorf("fail to run crowdsec for test: %w", err) } parserDumpFile := filepath.Join(dir, hubtest.ParserResultFileName) @@ -265,12 +231,12 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error { parserDump, err := dumps.LoadParserDump(parserDumpFile) if err != nil { - return fmt.Errorf("unable to load parser dump result: %s", err) + return fmt.Errorf("unable to load parser dump result: %w", err) } bucketStateDump, err := dumps.LoadBucketPourDump(bucketStateDumpFile) if err != nil { - return fmt.Errorf("unable to load bucket dump result: %s", err) + return fmt.Errorf("unable to load bucket dump result: %w", err) } dumps.DumpTree(*parserDump, *bucketStateDump, opts) diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index 55fcacee39c..43998623566 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -244,7 +244,7 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall cmd.AddCommand(NewCLILapi(cli.cfg).NewCommand()) cmd.AddCommand(NewCompletionCmd()) cmd.AddCommand(NewConsoleCmd()) - cmd.AddCommand(NewCLIExplain().NewCommand()) + cmd.AddCommand(NewCLIExplain(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIHubTest().NewCommand()) cmd.AddCommand(NewCLINotifications().NewCommand()) cmd.AddCommand(NewCLISupport().NewCommand()) diff --git a/go.mod b/go.mod index d61c191c14f..e1da18387a5 100644 --- a/go.mod +++ b/go.mod @@ -77,7 +77,7 @@ require ( github.com/shirou/gopsutil/v3 v3.23.5 github.com/sirupsen/logrus v1.9.3 github.com/slack-go/slack v0.12.2 - github.com/spf13/cobra v1.7.0 + github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.8.4 github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26 github.com/wasilibs/go-re2 v1.3.0 @@ -108,7 +108,7 @@ require ( github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/corazawaf/libinjection-go v0.1.2 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect - github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect + github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect github.com/creack/pty v1.1.18 // indirect github.com/docker/distribution v2.8.2+incompatible // indirect github.com/docker/go-units v0.5.0 // indirect diff --git a/go.sum b/go.sum index f5f61594ecd..8fa2021316b 100644 --- a/go.sum +++ b/go.sum @@ -91,21 +91,17 @@ github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7 github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= -github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= -github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.3 h1:qMCsGGgs+MAzDFyp9LpAe1Lqy/fY/qCovCm0qnXZOBM= +github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= 
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= -github.com/crowdsecurity/coraza/v3 v3.0.0-20231213144607-41d5358da94f h1:FkOB9aDw0xzDd14pTarGRLsUNAymONq3dc7zhvsXElg= -github.com/crowdsecurity/coraza/v3 v3.0.0-20231213144607-41d5358da94f/go.mod h1:TrU7Li+z2RHNrPy0TKJ6R65V6Yzpan2sTIRryJJyJso= github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607 h1:hyrYw3h8clMcRL2u5ooZ3tmwnmJftmhb9Ws1MKmavvI= github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607/go.mod h1:br36fEqurGYZQGit+iDYsIzW0FF6VufMbDzyyLxEuPA= github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:r97WNVC30Uen+7WnLs4xDScS/Ex988+id2k6mDf8psU= github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:zpv7r+7KXwgVUZnUNjyP22zc/D7LKjyoY02weH2RBbk= -github.com/crowdsecurity/go-cs-lib v0.0.5 h1:eVLW+BRj3ZYn0xt5/xmgzfbbB8EBo32gM4+WpQQk2e8= -github.com/crowdsecurity/go-cs-lib v0.0.5/go.mod h1:8FMKNGsh3hMZi2SEv6P15PURhEJnZV431XjzzBSuf0k= github.com/crowdsecurity/go-cs-lib v0.0.6 h1:Ef6MylXe0GaJE9vrfvxEdbHb31+JUP1os+murPz7Pos= github.com/crowdsecurity/go-cs-lib v0.0.6/go.mod h1:8FMKNGsh3hMZi2SEv6P15PURhEJnZV431XjzzBSuf0k= github.com/crowdsecurity/grokky v0.2.1 h1:t4VYnDlAd0RjDM2SlILalbwfCrQxtJSMGdQOR0zwkE4= @@ -640,8 +636,8 @@ github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= github.com/spf13/cobra v1.4.0/go.mod h1:Wo4iy3BUC+X2Fybo0PDqwJIv3dNRiZLHQymsfxlB84g= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= +github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= +github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= @@ -809,8 +805,6 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= -golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= From bdecf38616723dddf30a7c776694cd020f8a6944 Mon Sep 17 00:00:00 2001 From: blotus Date: Mon, 12 Feb 2024 11:33:44 +0100 Subject: [PATCH 04/20] update codeql action to v3 (#2822) --- .github/workflows/codeql-analysis.yml | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 0904769dd60..4b262f13d09 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -48,10 +48,15 @@ jobs: with: # 
required to pick up tags for BUILD_VERSION fetch-depth: 0 + - name: "Set up Go" + uses: actions/setup-go@v5 + with: + go-version: "1.21.6" + cache-dependency-path: "**/go.sum" # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -71,14 +76,8 @@ jobs: # and modify them (or add more) to build your code if your project # uses a compiled language - - name: "Set up Go" - uses: actions/setup-go@v5 - with: - go-version: "1.21.6" - cache-dependency-path: "**/go.sum" - - run: | make clean build BUILD_RE2_WASM=1 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 From eada3739e6849cf6da085dfa4862dcbfad4deb10 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Mon, 12 Feb 2024 11:40:59 +0100 Subject: [PATCH 05/20] refact "cscli notifications" (#2833) --- cmd/crowdsec-cli/main.go | 2 +- cmd/crowdsec-cli/notifications.go | 155 ++++++++++++++++++------------ 2 files changed, 95 insertions(+), 62 deletions(-) diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index 43998623566..63b7211b39b 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -246,7 +246,7 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall cmd.AddCommand(NewConsoleCmd()) cmd.AddCommand(NewCLIExplain(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIHubTest().NewCommand()) - cmd.AddCommand(NewCLINotifications().NewCommand()) + cmd.AddCommand(NewCLINotifications(cli.cfg).NewCommand()) cmd.AddCommand(NewCLISupport().NewCommand()) cmd.AddCommand(NewCLIPapi(cli.cfg).NewCommand()) cmd.AddCommand(NewCLICollection().NewCommand()) diff --git a/cmd/crowdsec-cli/notifications.go b/cmd/crowdsec-cli/notifications.go index da436420d12..f12333a3942 100644 --- a/cmd/crowdsec-cli/notifications.go +++ b/cmd/crowdsec-cli/notifications.go @@ -23,14 +23,13 @@ import ( "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/version" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/csprofiles" - "github.com/crowdsecurity/crowdsec/pkg/types" - - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" ) type NotificationsCfg struct { @@ -39,13 +38,17 @@ type NotificationsCfg struct { ids []uint } -type cliNotifications struct{} +type cliNotifications struct { + cfg configGetter +} -func NewCLINotifications() *cliNotifications { - return &cliNotifications{} +func NewCLINotifications(cfg configGetter) *cliNotifications { + return &cliNotifications{ + cfg: cfg, + } } -func (cli cliNotifications) NewCommand() *cobra.Command { +func (cli *cliNotifications) NewCommand() *cobra.Command { cmd := &cobra.Command{ Use: "notifications [action]", Short: "Helper for notification plugin configuration", @@ -53,14 +56,15 @@ func (cli cliNotifications) NewCommand() *cobra.Command { Args: cobra.MinimumNArgs(1), Aliases: []string{"notifications", "notification"}, DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := require.LAPI(csConfig); err != nil { + 
PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := require.LAPI(cfg); err != nil { return err } - if err := csConfig.LoadAPIClient(); err != nil { + if err := cfg.LoadAPIClient(); err != nil { return fmt.Errorf("loading api client: %w", err) } - if err := require.Notifications(csConfig); err != nil { + if err := require.Notifications(cfg); err != nil { return err } @@ -76,67 +80,79 @@ func (cli cliNotifications) NewCommand() *cobra.Command { return cmd } -func getPluginConfigs() (map[string]csplugin.PluginConfig, error) { +func (cli *cliNotifications) getPluginConfigs() (map[string]csplugin.PluginConfig, error) { + cfg := cli.cfg() pcfgs := map[string]csplugin.PluginConfig{} wf := func(path string, info fs.FileInfo, err error) error { if info == nil { return fmt.Errorf("error while traversing directory %s: %w", path, err) } - name := filepath.Join(csConfig.ConfigPaths.NotificationDir, info.Name()) //Avoid calling info.Name() twice + + name := filepath.Join(cfg.ConfigPaths.NotificationDir, info.Name()) //Avoid calling info.Name() twice if (strings.HasSuffix(name, "yaml") || strings.HasSuffix(name, "yml")) && !(info.IsDir()) { ts, err := csplugin.ParsePluginConfigFile(name) if err != nil { return fmt.Errorf("loading notifification plugin configuration with %s: %w", name, err) } + for _, t := range ts { csplugin.SetRequiredFields(&t) pcfgs[t.Name] = t } } + return nil } - if err := filepath.Walk(csConfig.ConfigPaths.NotificationDir, wf); err != nil { + if err := filepath.Walk(cfg.ConfigPaths.NotificationDir, wf); err != nil { return nil, fmt.Errorf("while loading notifification plugin configuration: %w", err) } + return pcfgs, nil } -func getProfilesConfigs() (map[string]NotificationsCfg, error) { +func (cli *cliNotifications) getProfilesConfigs() (map[string]NotificationsCfg, error) { + cfg := cli.cfg() // A bit of a tricky stuf now: reconcile profiles and notification plugins - pcfgs, err := getPluginConfigs() + pcfgs, err := cli.getPluginConfigs() if err != nil { return nil, err } + ncfgs := map[string]NotificationsCfg{} for _, pc := range pcfgs { ncfgs[pc.Name] = NotificationsCfg{ Config: pc, } } - profiles, err := csprofiles.NewProfile(csConfig.API.Server.Profiles) + + profiles, err := csprofiles.NewProfile(cfg.API.Server.Profiles) if err != nil { return nil, fmt.Errorf("while extracting profiles from configuration: %w", err) } + for profileID, profile := range profiles { for _, notif := range profile.Cfg.Notifications { pc, ok := pcfgs[notif] if !ok { return nil, fmt.Errorf("notification plugin '%s' does not exist", notif) } + tmp, ok := ncfgs[pc.Name] if !ok { return nil, fmt.Errorf("notification plugin '%s' does not exist", pc.Name) } + tmp.Profiles = append(tmp.Profiles, profile.Cfg) tmp.ids = append(tmp.ids, uint(profileID)) ncfgs[pc.Name] = tmp } } + return ncfgs, nil } -func (cli cliNotifications) NewListCmd() *cobra.Command { +func (cli *cliNotifications) NewListCmd() *cobra.Command { cmd := &cobra.Command{ Use: "list", Short: "list active notifications plugins", @@ -144,21 +160,22 @@ func (cli cliNotifications) NewListCmd() *cobra.Command { Example: `cscli notifications list`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, arg []string) error { - ncfgs, err := getProfilesConfigs() + RunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + ncfgs, err := cli.getProfilesConfigs() if err != nil { return fmt.Errorf("can't build profiles configuration: %w", err) } - if csConfig.Cscli.Output == 
"human" { + if cfg.Cscli.Output == "human" { notificationListTable(color.Output, ncfgs) - } else if csConfig.Cscli.Output == "json" { + } else if cfg.Cscli.Output == "json" { x, err := json.MarshalIndent(ncfgs, "", " ") if err != nil { return fmt.Errorf("failed to marshal notification configuration: %w", err) } fmt.Printf("%s", string(x)) - } else if csConfig.Cscli.Output == "raw" { + } else if cfg.Cscli.Output == "raw" { csvwriter := csv.NewWriter(os.Stdout) err := csvwriter.Write([]string{"Name", "Type", "Profile name"}) if err != nil { @@ -176,6 +193,7 @@ func (cli cliNotifications) NewListCmd() *cobra.Command { } csvwriter.Flush() } + return nil }, } @@ -183,7 +201,7 @@ func (cli cliNotifications) NewListCmd() *cobra.Command { return cmd } -func (cli cliNotifications) NewInspectCmd() *cobra.Command { +func (cli *cliNotifications) NewInspectCmd() *cobra.Command { cmd := &cobra.Command{ Use: "inspect", Short: "Inspect active notifications plugin configuration", @@ -191,36 +209,32 @@ func (cli cliNotifications) NewInspectCmd() *cobra.Command { Example: `cscli notifications inspect `, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - PreRunE: func(cmd *cobra.Command, args []string) error { - if args[0] == "" { - return fmt.Errorf("please provide a plugin name to inspect") - } - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - ncfgs, err := getProfilesConfigs() + RunE: func(_ *cobra.Command, args []string) error { + cfg := cli.cfg() + ncfgs, err := cli.getProfilesConfigs() if err != nil { return fmt.Errorf("can't build profiles configuration: %w", err) } - cfg, ok := ncfgs[args[0]] + ncfg, ok := ncfgs[args[0]] if !ok { return fmt.Errorf("plugin '%s' does not exist or is not active", args[0]) } - if csConfig.Cscli.Output == "human" || csConfig.Cscli.Output == "raw" { - fmt.Printf(" - %15s: %15s\n", "Type", cfg.Config.Type) - fmt.Printf(" - %15s: %15s\n", "Name", cfg.Config.Name) - fmt.Printf(" - %15s: %15s\n", "Timeout", cfg.Config.TimeOut) - fmt.Printf(" - %15s: %15s\n", "Format", cfg.Config.Format) - for k, v := range cfg.Config.Config { + if cfg.Cscli.Output == "human" || cfg.Cscli.Output == "raw" { + fmt.Printf(" - %15s: %15s\n", "Type", ncfg.Config.Type) + fmt.Printf(" - %15s: %15s\n", "Name", ncfg.Config.Name) + fmt.Printf(" - %15s: %15s\n", "Timeout", ncfg.Config.TimeOut) + fmt.Printf(" - %15s: %15s\n", "Format", ncfg.Config.Format) + for k, v := range ncfg.Config.Config { fmt.Printf(" - %15s: %15v\n", k, v) } - } else if csConfig.Cscli.Output == "json" { + } else if cfg.Cscli.Output == "json" { x, err := json.MarshalIndent(cfg, "", " ") if err != nil { return fmt.Errorf("failed to marshal notification configuration: %w", err) } fmt.Printf("%s", string(x)) } + return nil }, } @@ -228,12 +242,13 @@ func (cli cliNotifications) NewInspectCmd() *cobra.Command { return cmd } -func (cli cliNotifications) NewTestCmd() *cobra.Command { +func (cli *cliNotifications) NewTestCmd() *cobra.Command { var ( pluginBroker csplugin.PluginBroker pluginTomb tomb.Tomb alertOverride string ) + cmd := &cobra.Command{ Use: "test [plugin name]", Short: "send a generic test alert to notification plugin", @@ -241,25 +256,26 @@ func (cli cliNotifications) NewTestCmd() *cobra.Command { Example: `cscli notifications test [plugin_name]`, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - PreRunE: func(cmd *cobra.Command, args []string) error { - pconfigs, err := getPluginConfigs() + PreRunE: func(_ *cobra.Command, args []string) error { + cfg := cli.cfg() + pconfigs, err := 
cli.getPluginConfigs() if err != nil { return fmt.Errorf("can't build profiles configuration: %w", err) } - cfg, ok := pconfigs[args[0]] + pcfg, ok := pconfigs[args[0]] if !ok { return fmt.Errorf("plugin name: '%s' does not exist", args[0]) } //Create a single profile with plugin name as notification name - return pluginBroker.Init(csConfig.PluginConfig, []*csconfig.ProfileCfg{ + return pluginBroker.Init(cfg.PluginConfig, []*csconfig.ProfileCfg{ { Notifications: []string{ - cfg.Name, + pcfg.Name, }, }, - }, csConfig.ConfigPaths) + }, cfg.ConfigPaths) }, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { pluginTomb.Go(func() error { pluginBroker.Run(&pluginTomb) return nil @@ -298,13 +314,16 @@ func (cli cliNotifications) NewTestCmd() *cobra.Command { if err := yaml.Unmarshal([]byte(alertOverride), alert); err != nil { return fmt.Errorf("failed to unmarshal alert override: %w", err) } + pluginBroker.PluginChannel <- csplugin.ProfileAlert{ ProfileID: uint(0), Alert: alert, } + //time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent pluginTomb.Kill(fmt.Errorf("terminating")) pluginTomb.Wait() + return nil }, } @@ -313,9 +332,11 @@ func (cli cliNotifications) NewTestCmd() *cobra.Command { return cmd } -func (cli cliNotifications) NewReinjectCmd() *cobra.Command { - var alertOverride string - var alert *models.Alert +func (cli *cliNotifications) NewReinjectCmd() *cobra.Command { + var ( + alertOverride string + alert *models.Alert + ) cmd := &cobra.Command{ Use: "reinject", @@ -328,25 +349,30 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not `, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - PreRunE: func(cmd *cobra.Command, args []string) error { + PreRunE: func(_ *cobra.Command, args []string) error { var err error - alert, err = FetchAlertFromArgString(args[0]) + alert, err = cli.fetchAlertFromArgString(args[0]) if err != nil { return err } + return nil }, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { var ( pluginBroker csplugin.PluginBroker pluginTomb tomb.Tomb ) + + cfg := cli.cfg() + if alertOverride != "" { if err := json.Unmarshal([]byte(alertOverride), alert); err != nil { return fmt.Errorf("can't unmarshal data in the alert flag: %w", err) } } - err := pluginBroker.Init(csConfig.PluginConfig, csConfig.API.Server.Profiles, csConfig.ConfigPaths) + + err := pluginBroker.Init(cfg.PluginConfig, cfg.API.Server.Profiles, cfg.ConfigPaths) if err != nil { return fmt.Errorf("can't initialize plugins: %w", err) } @@ -356,7 +382,7 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not return nil }) - profiles, err := csprofiles.NewProfile(csConfig.API.Server.Profiles) + profiles, err := csprofiles.NewProfile(cfg.API.Server.Profiles) if err != nil { return fmt.Errorf("cannot extract profiles from configuration: %w", err) } @@ -382,9 +408,9 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not default: time.Sleep(50 * time.Millisecond) log.Info("sleeping\n") - } } + if profile.Cfg.OnSuccess == "break" { log.Infof("The profile %s contains a 'on_success: break' so bailing out", profile.Cfg.Name) break @@ -393,6 +419,7 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not //time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent pluginTomb.Kill(fmt.Errorf("terminating")) pluginTomb.Wait() + return nil }, } @@ -401,18 +428,22 @@ cscli 
notifications reinject -a '{"remediation": true,"scenario":"not return cmd } -func FetchAlertFromArgString(toParse string) (*models.Alert, error) { +func (cli *cliNotifications) fetchAlertFromArgString(toParse string) (*models.Alert, error) { + cfg := cli.cfg() + id, err := strconv.Atoi(toParse) if err != nil { return nil, fmt.Errorf("bad alert id %s", toParse) } - apiURL, err := url.Parse(csConfig.API.Client.Credentials.URL) + + apiURL, err := url.Parse(cfg.API.Client.Credentials.URL) if err != nil { return nil, fmt.Errorf("error parsing the URL of the API: %w", err) } + client, err := apiclient.NewClient(&apiclient.Config{ - MachineID: csConfig.API.Client.Credentials.Login, - Password: strfmt.Password(csConfig.API.Client.Credentials.Password), + MachineID: cfg.API.Client.Credentials.Login, + Password: strfmt.Password(cfg.API.Client.Credentials.Password), UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", @@ -420,9 +451,11 @@ func FetchAlertFromArgString(toParse string) (*models.Alert, error) { if err != nil { return nil, fmt.Errorf("error creating the client for the API: %w", err) } + alert, _, err := client.Alerts.GetByID(context.Background(), id) if err != nil { return nil, fmt.Errorf("can't find alert with id %d: %w", id, err) } + return alert, nil } From a6a4d460d7069a67369906fbe4447eed601b4942 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Mon, 12 Feb 2024 11:45:58 +0100 Subject: [PATCH 06/20] refact "cscli console" (#2834) --- cmd/crowdsec-cli/console.go | 253 ++++++++++++++++++------------ cmd/crowdsec-cli/console_table.go | 14 +- cmd/crowdsec-cli/main.go | 2 +- 3 files changed, 160 insertions(+), 109 deletions(-) diff --git a/cmd/crowdsec-cli/console.go b/cmd/crowdsec-cli/console.go index dcd6fb37f62..b1912825c06 100644 --- a/cmd/crowdsec-cli/console.go +++ b/cmd/crowdsec-cli/console.go @@ -25,32 +25,53 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -func NewConsoleCmd() *cobra.Command { - var cmdConsole = &cobra.Command{ +type cliConsole struct { + cfg configGetter +} + +func NewCLIConsole(cfg configGetter) *cliConsole { + return &cliConsole{ + cfg: cfg, + } +} + +func (cli *cliConsole) NewCommand() *cobra.Command { + var cmd = &cobra.Command{ Use: "console [action]", Short: "Manage interaction with Crowdsec console (https://app.crowdsec.net)", Args: cobra.MinimumNArgs(1), DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := require.LAPI(csConfig); err != nil { + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := require.LAPI(cfg); err != nil { return err } - if err := require.CAPI(csConfig); err != nil { + if err := require.CAPI(cfg); err != nil { return err } - if err := require.CAPIRegistered(csConfig); err != nil { + if err := require.CAPIRegistered(cfg); err != nil { return err } + return nil }, } + cmd.AddCommand(cli.newEnrollCmd()) + cmd.AddCommand(cli.newEnableCmd()) + cmd.AddCommand(cli.newDisableCmd()) + cmd.AddCommand(cli.newStatusCmd()) + + return cmd +} + +func (cli *cliConsole) newEnrollCmd() *cobra.Command { name := "" overwrite := false tags := []string{} opts := []string{} - cmdEnroll := &cobra.Command{ + cmd := &cobra.Command{ Use: "enroll [enroll-key]", Short: "Enroll this instance to https://app.crowdsec.net [requires local API]", Long: ` @@ -66,96 +87,107 @@ After running this command your will need to validate the enrollment in the weba valid options are : %s,all (see 'cscli 
console status' for details)`, strings.Join(csconfig.CONSOLE_CONFIGS, ",")), Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - password := strfmt.Password(csConfig.API.Server.OnlineClient.Credentials.Password) - apiURL, err := url.Parse(csConfig.API.Server.OnlineClient.Credentials.URL) + RunE: func(_ *cobra.Command, args []string) error { + cfg := cli.cfg() + password := strfmt.Password(cfg.API.Server.OnlineClient.Credentials.Password) + + apiURL, err := url.Parse(cfg.API.Server.OnlineClient.Credentials.URL) if err != nil { - return fmt.Errorf("could not parse CAPI URL: %s", err) + return fmt.Errorf("could not parse CAPI URL: %w", err) } - hub, err := require.Hub(csConfig, nil, nil) + hub, err := require.Hub(cfg, nil, nil) if err != nil { return err } scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) if err != nil { - return fmt.Errorf("failed to get installed scenarios: %s", err) + return fmt.Errorf("failed to get installed scenarios: %w", err) } if len(scenarios) == 0 { scenarios = make([]string, 0) } - enable_opts := []string{csconfig.SEND_MANUAL_SCENARIOS, csconfig.SEND_TAINTED_SCENARIOS} + enableOpts := []string{csconfig.SEND_MANUAL_SCENARIOS, csconfig.SEND_TAINTED_SCENARIOS} if len(opts) != 0 { for _, opt := range opts { valid := false if opt == "all" { - enable_opts = csconfig.CONSOLE_CONFIGS + enableOpts = csconfig.CONSOLE_CONFIGS break } - for _, available_opt := range csconfig.CONSOLE_CONFIGS { - if opt == available_opt { + for _, availableOpt := range csconfig.CONSOLE_CONFIGS { + if opt == availableOpt { valid = true enable := true - for _, enabled_opt := range enable_opts { - if opt == enabled_opt { + for _, enabledOpt := range enableOpts { + if opt == enabledOpt { enable = false continue } } if enable { - enable_opts = append(enable_opts, opt) + enableOpts = append(enableOpts, opt) } + break } } if !valid { return fmt.Errorf("option %s doesn't exist", opt) - } } } c, _ := apiclient.NewClient(&apiclient.Config{ - MachineID: csConfig.API.Server.OnlineClient.Credentials.Login, + MachineID: cli.cfg().API.Server.OnlineClient.Credentials.Login, Password: password, Scenarios: scenarios, UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v3", }) + resp, err := c.Auth.EnrollWatcher(context.Background(), args[0], name, tags, overwrite) if err != nil { - return fmt.Errorf("could not enroll instance: %s", err) + return fmt.Errorf("could not enroll instance: %w", err) } + if resp.Response.StatusCode == 200 && !overwrite { log.Warning("Instance already enrolled. You can use '--overwrite' to force enroll") return nil } - if err := SetConsoleOpts(enable_opts, true); err != nil { + if err := cli.setConsoleOpts(enableOpts, true); err != nil { return err } - for _, opt := range enable_opts { + for _, opt := range enableOpts { log.Infof("Enabled %s : %s", opt, csconfig.CONSOLE_CONFIGS_HELP[opt]) } + log.Info("Watcher successfully enrolled. 
Visit https://app.crowdsec.net to accept it.") log.Info("Please restart crowdsec after accepting the enrollment.") + return nil }, } - cmdEnroll.Flags().StringVarP(&name, "name", "n", "", "Name to display in the console") - cmdEnroll.Flags().BoolVarP(&overwrite, "overwrite", "", false, "Force enroll the instance") - cmdEnroll.Flags().StringSliceVarP(&tags, "tags", "t", tags, "Tags to display in the console") - cmdEnroll.Flags().StringSliceVarP(&opts, "enable", "e", opts, "Enable console options") - cmdConsole.AddCommand(cmdEnroll) - var enableAll, disableAll bool + flags := cmd.Flags() + flags.StringVarP(&name, "name", "n", "", "Name to display in the console") + flags.BoolVarP(&overwrite, "overwrite", "", false, "Force enroll the instance") + flags.StringSliceVarP(&tags, "tags", "t", tags, "Tags to display in the console") + flags.StringSliceVarP(&opts, "enable", "e", opts, "Enable console options") + + return cmd +} + +func (cli *cliConsole) newEnableCmd() *cobra.Command { + var enableAll bool - cmdEnable := &cobra.Command{ + cmd := &cobra.Command{ Use: "enable [option]", Short: "Enable a console option", Example: "sudo cscli console enable tainted", @@ -163,9 +195,9 @@ After running this command your will need to validate the enrollment in the weba Enable given information push to the central API. Allows to empower the console`, ValidArgs: csconfig.CONSOLE_CONFIGS, DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, args []string) error { if enableAll { - if err := SetConsoleOpts(csconfig.CONSOLE_CONFIGS, true); err != nil { + if err := cli.setConsoleOpts(csconfig.CONSOLE_CONFIGS, true); err != nil { return err } log.Infof("All features have been enabled successfully") @@ -173,19 +205,26 @@ Enable given information push to the central API. Allows to empower the console` if len(args) == 0 { return fmt.Errorf("you must specify at least one feature to enable") } - if err := SetConsoleOpts(args, true); err != nil { + if err := cli.setConsoleOpts(args, true); err != nil { return err } log.Infof("%v have been enabled", args) } + log.Infof(ReloadMessage()) + return nil }, } - cmdEnable.Flags().BoolVarP(&enableAll, "all", "a", false, "Enable all console options") - cmdConsole.AddCommand(cmdEnable) + cmd.Flags().BoolVarP(&enableAll, "all", "a", false, "Enable all console options") + + return cmd +} - cmdDisable := &cobra.Command{ +func (cli *cliConsole) newDisableCmd() *cobra.Command { + var disableAll bool + + cmd := &cobra.Command{ Use: "disable [option]", Short: "Disable a console option", Example: "sudo cscli console disable tainted", @@ -193,47 +232,52 @@ Enable given information push to the central API. 
Allows to empower the console` Disable given information push to the central API.`, ValidArgs: csconfig.CONSOLE_CONFIGS, DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, args []string) error { if disableAll { - if err := SetConsoleOpts(csconfig.CONSOLE_CONFIGS, false); err != nil { + if err := cli.setConsoleOpts(csconfig.CONSOLE_CONFIGS, false); err != nil { return err } log.Infof("All features have been disabled") } else { - if err := SetConsoleOpts(args, false); err != nil { + if err := cli.setConsoleOpts(args, false); err != nil { return err } log.Infof("%v have been disabled", args) } log.Infof(ReloadMessage()) + return nil }, } - cmdDisable.Flags().BoolVarP(&disableAll, "all", "a", false, "Disable all console options") - cmdConsole.AddCommand(cmdDisable) + cmd.Flags().BoolVarP(&disableAll, "all", "a", false, "Disable all console options") + + return cmd +} - cmdConsoleStatus := &cobra.Command{ +func (cli *cliConsole) newStatusCmd() *cobra.Command { + cmd := &cobra.Command{ Use: "status", Short: "Shows status of the console options", Example: `sudo cscli console status`, DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - switch csConfig.Cscli.Output { + RunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + consoleCfg := cfg.API.Server.ConsoleConfig + switch cfg.Cscli.Output { case "human": - cmdConsoleStatusTable(color.Output, *csConfig) + cmdConsoleStatusTable(color.Output, *consoleCfg) case "json": - c := csConfig.API.Server.ConsoleConfig out := map[string](*bool){ - csconfig.SEND_MANUAL_SCENARIOS: c.ShareManualDecisions, - csconfig.SEND_CUSTOM_SCENARIOS: c.ShareCustomScenarios, - csconfig.SEND_TAINTED_SCENARIOS: c.ShareTaintedScenarios, - csconfig.SEND_CONTEXT: c.ShareContext, - csconfig.CONSOLE_MANAGEMENT: c.ConsoleManagement, + csconfig.SEND_MANUAL_SCENARIOS: consoleCfg.ShareManualDecisions, + csconfig.SEND_CUSTOM_SCENARIOS: consoleCfg.ShareCustomScenarios, + csconfig.SEND_TAINTED_SCENARIOS: consoleCfg.ShareTaintedScenarios, + csconfig.SEND_CONTEXT: consoleCfg.ShareContext, + csconfig.CONSOLE_MANAGEMENT: consoleCfg.ConsoleManagement, } data, err := json.MarshalIndent(out, "", " ") if err != nil { - return fmt.Errorf("failed to marshal configuration: %s", err) + return fmt.Errorf("failed to marshal configuration: %w", err) } fmt.Println(string(data)) case "raw": @@ -244,11 +288,11 @@ Disable given information push to the central API.`, } rows := [][]string{ - {csconfig.SEND_MANUAL_SCENARIOS, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareManualDecisions)}, - {csconfig.SEND_CUSTOM_SCENARIOS, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareCustomScenarios)}, - {csconfig.SEND_TAINTED_SCENARIOS, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios)}, - {csconfig.SEND_CONTEXT, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareContext)}, - {csconfig.CONSOLE_MANAGEMENT, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ConsoleManagement)}, + {csconfig.SEND_MANUAL_SCENARIOS, fmt.Sprintf("%t", *consoleCfg.ShareManualDecisions)}, + {csconfig.SEND_CUSTOM_SCENARIOS, fmt.Sprintf("%t", *consoleCfg.ShareCustomScenarios)}, + {csconfig.SEND_TAINTED_SCENARIOS, fmt.Sprintf("%t", *consoleCfg.ShareTaintedScenarios)}, + {csconfig.SEND_CONTEXT, fmt.Sprintf("%t", *consoleCfg.ShareContext)}, + {csconfig.CONSOLE_MANAGEMENT, fmt.Sprintf("%t", *consoleCfg.ConsoleManagement)}, } for _, row := range rows { err = csvwriter.Write(row) @@ -258,132 
+302,137 @@ Disable given information push to the central API.`, } csvwriter.Flush() } + return nil }, } - cmdConsole.AddCommand(cmdConsoleStatus) - return cmdConsole + return cmd } -func dumpConsoleConfig(c *csconfig.LocalApiServerCfg) error { - out, err := yaml.Marshal(c.ConsoleConfig) +func (cli *cliConsole) dumpConfig() error { + serverCfg := cli.cfg().API.Server + + out, err := yaml.Marshal(serverCfg.ConsoleConfig) if err != nil { - return fmt.Errorf("while marshaling ConsoleConfig (for %s): %w", c.ConsoleConfigPath, err) + return fmt.Errorf("while marshaling ConsoleConfig (for %s): %w", serverCfg.ConsoleConfigPath, err) } - if c.ConsoleConfigPath == "" { - c.ConsoleConfigPath = csconfig.DefaultConsoleConfigFilePath - log.Debugf("Empty console_path, defaulting to %s", c.ConsoleConfigPath) + if serverCfg.ConsoleConfigPath == "" { + serverCfg.ConsoleConfigPath = csconfig.DefaultConsoleConfigFilePath + log.Debugf("Empty console_path, defaulting to %s", serverCfg.ConsoleConfigPath) } - if err := os.WriteFile(c.ConsoleConfigPath, out, 0o600); err != nil { - return fmt.Errorf("while dumping console config to %s: %w", c.ConsoleConfigPath, err) + if err := os.WriteFile(serverCfg.ConsoleConfigPath, out, 0o600); err != nil { + return fmt.Errorf("while dumping console config to %s: %w", serverCfg.ConsoleConfigPath, err) } return nil } -func SetConsoleOpts(args []string, wanted bool) error { +func (cli *cliConsole) setConsoleOpts(args []string, wanted bool) error { + cfg := cli.cfg() + consoleCfg := cfg.API.Server.ConsoleConfig + for _, arg := range args { switch arg { case csconfig.CONSOLE_MANAGEMENT: /*for each flag check if it's already set before setting it*/ - if csConfig.API.Server.ConsoleConfig.ConsoleManagement != nil { - if *csConfig.API.Server.ConsoleConfig.ConsoleManagement == wanted { + if consoleCfg.ConsoleManagement != nil { + if *consoleCfg.ConsoleManagement == wanted { log.Debugf("%s already set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) } else { log.Infof("%s set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) - *csConfig.API.Server.ConsoleConfig.ConsoleManagement = wanted + *consoleCfg.ConsoleManagement = wanted } } else { log.Infof("%s set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) - csConfig.API.Server.ConsoleConfig.ConsoleManagement = ptr.Of(wanted) + consoleCfg.ConsoleManagement = ptr.Of(wanted) } - if csConfig.API.Server.OnlineClient.Credentials != nil { + if cfg.API.Server.OnlineClient.Credentials != nil { changed := false - if wanted && csConfig.API.Server.OnlineClient.Credentials.PapiURL == "" { + if wanted && cfg.API.Server.OnlineClient.Credentials.PapiURL == "" { changed = true - csConfig.API.Server.OnlineClient.Credentials.PapiURL = types.PAPIBaseURL - } else if !wanted && csConfig.API.Server.OnlineClient.Credentials.PapiURL != "" { + cfg.API.Server.OnlineClient.Credentials.PapiURL = types.PAPIBaseURL + } else if !wanted && cfg.API.Server.OnlineClient.Credentials.PapiURL != "" { changed = true - csConfig.API.Server.OnlineClient.Credentials.PapiURL = "" + cfg.API.Server.OnlineClient.Credentials.PapiURL = "" } if changed { - fileContent, err := yaml.Marshal(csConfig.API.Server.OnlineClient.Credentials) + fileContent, err := yaml.Marshal(cfg.API.Server.OnlineClient.Credentials) if err != nil { - return fmt.Errorf("cannot marshal credentials: %s", err) + return fmt.Errorf("cannot marshal credentials: %w", err) } - log.Infof("Updating credentials file: %s", csConfig.API.Server.OnlineClient.CredentialsFilePath) + log.Infof("Updating credentials file: %s", 
cfg.API.Server.OnlineClient.CredentialsFilePath) - err = os.WriteFile(csConfig.API.Server.OnlineClient.CredentialsFilePath, fileContent, 0o600) + err = os.WriteFile(cfg.API.Server.OnlineClient.CredentialsFilePath, fileContent, 0o600) if err != nil { - return fmt.Errorf("cannot write credentials file: %s", err) + return fmt.Errorf("cannot write credentials file: %w", err) } } } case csconfig.SEND_CUSTOM_SCENARIOS: /*for each flag check if it's already set before setting it*/ - if csConfig.API.Server.ConsoleConfig.ShareCustomScenarios != nil { - if *csConfig.API.Server.ConsoleConfig.ShareCustomScenarios == wanted { + if consoleCfg.ShareCustomScenarios != nil { + if *consoleCfg.ShareCustomScenarios == wanted { log.Debugf("%s already set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) } else { log.Infof("%s set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) - *csConfig.API.Server.ConsoleConfig.ShareCustomScenarios = wanted + *consoleCfg.ShareCustomScenarios = wanted } } else { log.Infof("%s set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) - csConfig.API.Server.ConsoleConfig.ShareCustomScenarios = ptr.Of(wanted) + consoleCfg.ShareCustomScenarios = ptr.Of(wanted) } case csconfig.SEND_TAINTED_SCENARIOS: /*for each flag check if it's already set before setting it*/ - if csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios != nil { - if *csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios == wanted { + if consoleCfg.ShareTaintedScenarios != nil { + if *consoleCfg.ShareTaintedScenarios == wanted { log.Debugf("%s already set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) } else { log.Infof("%s set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) - *csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios = wanted + *consoleCfg.ShareTaintedScenarios = wanted } } else { log.Infof("%s set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) - csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios = ptr.Of(wanted) + consoleCfg.ShareTaintedScenarios = ptr.Of(wanted) } case csconfig.SEND_MANUAL_SCENARIOS: /*for each flag check if it's already set before setting it*/ - if csConfig.API.Server.ConsoleConfig.ShareManualDecisions != nil { - if *csConfig.API.Server.ConsoleConfig.ShareManualDecisions == wanted { + if consoleCfg.ShareManualDecisions != nil { + if *consoleCfg.ShareManualDecisions == wanted { log.Debugf("%s already set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) } else { log.Infof("%s set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) - *csConfig.API.Server.ConsoleConfig.ShareManualDecisions = wanted + *consoleCfg.ShareManualDecisions = wanted } } else { log.Infof("%s set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) - csConfig.API.Server.ConsoleConfig.ShareManualDecisions = ptr.Of(wanted) + consoleCfg.ShareManualDecisions = ptr.Of(wanted) } case csconfig.SEND_CONTEXT: /*for each flag check if it's already set before setting it*/ - if csConfig.API.Server.ConsoleConfig.ShareContext != nil { - if *csConfig.API.Server.ConsoleConfig.ShareContext == wanted { + if consoleCfg.ShareContext != nil { + if *consoleCfg.ShareContext == wanted { log.Debugf("%s already set to %t", csconfig.SEND_CONTEXT, wanted) } else { log.Infof("%s set to %t", csconfig.SEND_CONTEXT, wanted) - *csConfig.API.Server.ConsoleConfig.ShareContext = wanted + *consoleCfg.ShareContext = wanted } } else { log.Infof("%s set to %t", csconfig.SEND_CONTEXT, wanted) - csConfig.API.Server.ConsoleConfig.ShareContext = ptr.Of(wanted) + consoleCfg.ShareContext = ptr.Of(wanted) } default: return fmt.Errorf("unknown flag %s", arg) } } 
- if err := dumpConsoleConfig(csConfig.API.Server); err != nil { - return fmt.Errorf("failed writing console config: %s", err) + if err := cli.dumpConfig(); err != nil { + return fmt.Errorf("failed writing console config: %w", err) } return nil diff --git a/cmd/crowdsec-cli/console_table.go b/cmd/crowdsec-cli/console_table.go index 2a221e36f07..e71ea8113fb 100644 --- a/cmd/crowdsec-cli/console_table.go +++ b/cmd/crowdsec-cli/console_table.go @@ -9,7 +9,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/csconfig" ) -func cmdConsoleStatusTable(out io.Writer, csConfig csconfig.Config) { +func cmdConsoleStatusTable(out io.Writer, consoleCfg csconfig.ConsoleConfig) { t := newTable(out) t.SetRowLines(false) @@ -18,28 +18,30 @@ func cmdConsoleStatusTable(out io.Writer, csConfig csconfig.Config) { for _, option := range csconfig.CONSOLE_CONFIGS { activated := string(emoji.CrossMark) + switch option { case csconfig.SEND_CUSTOM_SCENARIOS: - if *csConfig.API.Server.ConsoleConfig.ShareCustomScenarios { + if *consoleCfg.ShareCustomScenarios { activated = string(emoji.CheckMarkButton) } case csconfig.SEND_MANUAL_SCENARIOS: - if *csConfig.API.Server.ConsoleConfig.ShareManualDecisions { + if *consoleCfg.ShareManualDecisions { activated = string(emoji.CheckMarkButton) } case csconfig.SEND_TAINTED_SCENARIOS: - if *csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios { + if *consoleCfg.ShareTaintedScenarios { activated = string(emoji.CheckMarkButton) } case csconfig.SEND_CONTEXT: - if *csConfig.API.Server.ConsoleConfig.ShareContext { + if *consoleCfg.ShareContext { activated = string(emoji.CheckMarkButton) } case csconfig.CONSOLE_MANAGEMENT: - if *csConfig.API.Server.ConsoleConfig.ConsoleManagement { + if *consoleCfg.ConsoleManagement { activated = string(emoji.CheckMarkButton) } } + t.AddRow(option, activated, csconfig.CONSOLE_CONFIGS_HELP[option]) } diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index 63b7211b39b..27ac17d554f 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -243,7 +243,7 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall cmd.AddCommand(NewCLICapi().NewCommand()) cmd.AddCommand(NewCLILapi(cli.cfg).NewCommand()) cmd.AddCommand(NewCompletionCmd()) - cmd.AddCommand(NewConsoleCmd()) + cmd.AddCommand(NewCLIConsole(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIExplain(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIHubTest().NewCommand()) cmd.AddCommand(NewCLINotifications(cli.cfg).NewCommand()) From 4561eb787be6e27693195807ba61181018aa6755 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Mon, 12 Feb 2024 20:15:16 +0100 Subject: [PATCH 07/20] bats: color formatter in CI (#2838) --- .github/workflows/bats-hub.yml | 5 +- .github/workflows/bats-mysql.yml | 5 +- .github/workflows/bats-postgres.yml | 5 +- .github/workflows/bats-sqlite-coverage.yml | 3 +- test/lib/color-formatter | 355 +++++++++++++++++++++ 5 files changed, 359 insertions(+), 14 deletions(-) create mode 100755 test/lib/color-formatter diff --git a/.github/workflows/bats-hub.yml b/.github/workflows/bats-hub.yml index fe45210ae96..7764da84812 100644 --- a/.github/workflows/bats-hub.yml +++ b/.github/workflows/bats-hub.yml @@ -8,9 +8,6 @@ on: GIST_BADGES_ID: required: true -env: - PREFIX_TEST_NAMES_WITH_FILE: true - jobs: build: strategy: @@ -50,7 +47,7 @@ jobs: - name: "Run hub tests" run: | ./test/bin/generate-hub-tests - ./test/run-tests test/dyn-bats/${{ matrix.test-file }} + ./test/run-tests ./test/dyn-bats/${{ 
matrix.test-file }} --formatter $(pwd)/test/lib/color-formatter - name: "Collect hub coverage" run: ./test/bin/collect-hub-coverage >> $GITHUB_ENV diff --git a/.github/workflows/bats-mysql.yml b/.github/workflows/bats-mysql.yml index 902c25ba329..243da6eb25d 100644 --- a/.github/workflows/bats-mysql.yml +++ b/.github/workflows/bats-mysql.yml @@ -7,9 +7,6 @@ on: required: true type: string -env: - PREFIX_TEST_NAMES_WITH_FILE: true - jobs: build: name: "Functional tests" @@ -58,7 +55,7 @@ jobs: MYSQL_USER: root - name: "Run tests" - run: make bats-test + run: ./test/run-tests ./test/bats --formatter $(pwd)/test/lib/color-formatter env: DB_BACKEND: mysql MYSQL_HOST: 127.0.0.1 diff --git a/.github/workflows/bats-postgres.yml b/.github/workflows/bats-postgres.yml index e15f1e410c1..07d3cd8d2f1 100644 --- a/.github/workflows/bats-postgres.yml +++ b/.github/workflows/bats-postgres.yml @@ -3,9 +3,6 @@ name: (sub) Bats / Postgres on: workflow_call: -env: - PREFIX_TEST_NAMES_WITH_FILE: true - jobs: build: name: "Functional tests" @@ -67,7 +64,7 @@ jobs: PGUSER: postgres - name: "Run tests (DB_BACKEND: pgx)" - run: make bats-test + run: ./test/run-tests ./test/bats --formatter $(pwd)/test/lib/color-formatter env: DB_BACKEND: pgx PGHOST: 127.0.0.1 diff --git a/.github/workflows/bats-sqlite-coverage.yml b/.github/workflows/bats-sqlite-coverage.yml index 36194555e1d..46a5dd8bc86 100644 --- a/.github/workflows/bats-sqlite-coverage.yml +++ b/.github/workflows/bats-sqlite-coverage.yml @@ -4,7 +4,6 @@ on: workflow_call: env: - PREFIX_TEST_NAMES_WITH_FILE: true TEST_COVERAGE: true jobs: @@ -42,7 +41,7 @@ jobs: make clean bats-build bats-fixture BUILD_STATIC=1 - name: "Run tests" - run: make bats-test + run: ./test/run-tests ./test/bats --formatter $(pwd)/test/lib/color-formatter - name: "Collect coverage data" run: | diff --git a/test/lib/color-formatter b/test/lib/color-formatter new file mode 100755 index 00000000000..aee8d750698 --- /dev/null +++ b/test/lib/color-formatter @@ -0,0 +1,355 @@ +#!/usr/bin/env bash + +# +# Taken from pretty formatter, minus the cursor movements. +# Used in gihtub workflows CI where color is allowed. +# + +set -e + +# shellcheck source=lib/bats-core/formatter.bash +source "$BATS_ROOT/lib/bats-core/formatter.bash" + +BASE_PATH=. 
+BATS_ENABLE_TIMING= + +while [[ "$#" -ne 0 ]]; do + case "$1" in + -T) + BATS_ENABLE_TIMING="-T" + ;; + --base-path) + shift + normalize_base_path BASE_PATH "$1" + ;; + esac + shift +done + +update_count_column_width() { + count_column_width=$((${#count} * 2 + 2)) + if [[ -n "$BATS_ENABLE_TIMING" ]]; then + # additional space for ' in %s sec' + count_column_width=$((count_column_width + ${#SECONDS} + 8)) + fi + # also update dependent value + update_count_column_left +} + +update_screen_width() { + screen_width="$(tput cols)" + # also update dependent value + update_count_column_left +} + +update_count_column_left() { + count_column_left=$((screen_width - count_column_width)) +} + +# avoid unset variables +count=0 +screen_width=80 +update_count_column_width +#update_screen_width +test_result= + +#trap update_screen_width WINCH + +begin() { + test_result= # reset to avoid carrying over result state from previous test + line_backoff_count=0 + #go_to_column 0 + #update_count_column_width + #buffer_with_truncation $((count_column_left - 1)) ' %s' "$name" + #clear_to_end_of_line + #go_to_column $count_column_left + #if [[ -n "$BATS_ENABLE_TIMING" ]]; then + # buffer "%${#count}s/${count} in %s sec" "$index" "$SECONDS" + #else + # buffer "%${#count}s/${count}" "$index" + #fi + #go_to_column 1 + buffer "%${#count}s" "$index" +} + +finish_test() { + #move_up $line_backoff_count + #go_to_column 0 + buffer "$@" + if [[ -n "${TIMEOUT-}" ]]; then + set_color 2 + if [[ -n "$BATS_ENABLE_TIMING" ]]; then + buffer ' [%s (timeout: %s)]' "$TIMING" "$TIMEOUT" + else + buffer ' [timeout: %s]' "$TIMEOUT" + fi + else + if [[ -n "$BATS_ENABLE_TIMING" ]]; then + set_color 2 + buffer ' [%s]' "$TIMING" + fi + fi + advance + move_down $((line_backoff_count - 1)) +} + +pass() { + local TIMING="${1:-}" + finish_test ' ✓ %s' "$name" + test_result=pass +} + +skip() { + local reason="$1" TIMING="${2:-}" + if [[ -n "$reason" ]]; then + reason=": $reason" + fi + finish_test ' - %s (skipped%s)' "$name" "$reason" + test_result=skip +} + +fail() { + local TIMING="${1:-}" + set_color 1 bold + finish_test ' ✗ %s' "$name" + test_result=fail +} + +timeout() { + local TIMING="${1:-}" + set_color 3 bold + TIMEOUT="${2:-}" finish_test ' ✗ %s' "$name" + test_result=timeout +} + +log() { + case ${test_result} in + pass) + clear_color + ;; + fail) + set_color 1 + ;; + timeout) + set_color 3 + ;; + esac + buffer ' %s\n' "$1" + clear_color +} + +summary() { + if [ "$failures" -eq 0 ]; then + set_color 2 bold + else + set_color 1 bold + fi + + buffer '\n%d test' "$count" + if [[ "$count" -ne 1 ]]; then + buffer 's' + fi + + buffer ', %d failure' "$failures" + if [[ "$failures" -ne 1 ]]; then + buffer 's' + fi + + if [[ "$skipped" -gt 0 ]]; then + buffer ', %d skipped' "$skipped" + fi + + if ((timed_out > 0)); then + buffer ', %d timed out' "$timed_out" + fi + + not_run=$((count - passed - failures - skipped - timed_out)) + if [[ "$not_run" -gt 0 ]]; then + buffer ', %d not run' "$not_run" + fi + + if [[ -n "$BATS_ENABLE_TIMING" ]]; then + buffer " in $SECONDS seconds" + fi + + buffer '\n' + clear_color +} + +buffer_with_truncation() { + local width="$1" + shift + local string + + # shellcheck disable=SC2059 + printf -v 'string' -- "$@" + + if [[ "${#string}" -gt "$width" ]]; then + buffer '%s...' 
"${string:0:$((width - 4))}" + else + buffer '%s' "$string" + fi +} + +move_up() { + if [[ $1 -gt 0 ]]; then # avoid moving if we got 0 + buffer '\x1B[%dA' "$1" + fi +} + +move_down() { + if [[ $1 -gt 0 ]]; then # avoid moving if we got 0 + buffer '\x1B[%dB' "$1" + fi +} + +go_to_column() { + local column="$1" + buffer '\x1B[%dG' $((column + 1)) +} + +clear_to_end_of_line() { + buffer '\x1B[K' +} + +advance() { + clear_to_end_of_line + buffer '\n' + clear_color +} + +set_color() { + local color="$1" + local weight=22 + + if [[ "${2:-}" == 'bold' ]]; then + weight=1 + fi + buffer '\x1B[%d;%dm' "$((30 + color))" "$weight" +} + +clear_color() { + buffer '\x1B[0m' +} + +_buffer= + +buffer() { + local content + # shellcheck disable=SC2059 + printf -v content -- "$@" + _buffer+="$content" +} + +prefix_buffer_with() { + local old_buffer="$_buffer" + _buffer='' + "$@" + _buffer="$_buffer$old_buffer" +} + +flush() { + printf '%s' "$_buffer" + _buffer= +} + +finish() { + flush + printf '\n' +} + +trap finish EXIT +trap '' INT + +bats_tap_stream_plan() { + count="$1" + index=0 + passed=0 + failures=0 + skipped=0 + timed_out=0 + name= + update_count_column_width +} + +bats_tap_stream_begin() { + index="$1" + name="$2" + begin + flush +} + +bats_tap_stream_ok() { + index="$1" + name="$2" + ((++passed)) + + pass "${BATS_FORMATTER_TEST_DURATION:-}" +} + +bats_tap_stream_skipped() { + index="$1" + name="$2" + ((++skipped)) + skip "$3" "${BATS_FORMATTER_TEST_DURATION:-}" +} + +bats_tap_stream_not_ok() { + index="$1" + name="$2" + + if [[ ${BATS_FORMATTER_TEST_TIMEOUT-x} != x ]]; then + timeout "${BATS_FORMATTER_TEST_DURATION:-}" "${BATS_FORMATTER_TEST_TIMEOUT}s" + ((++timed_out)) + else + fail "${BATS_FORMATTER_TEST_DURATION:-}" + ((++failures)) + fi + +} + +bats_tap_stream_comment() { # + local scope=$2 + # count the lines we printed after the begin text, + if [[ $line_backoff_count -eq 0 && $scope == begin ]]; then + # if this is the first line after begin, go down one line + buffer "\n" + ((++line_backoff_count)) # prefix-increment to avoid "error" due to returning 0 + fi + + ((++line_backoff_count)) + ((line_backoff_count += ${#1} / screen_width)) # account for linebreaks due to length + log "$1" +} + +bats_tap_stream_suite() { + #test_file="$1" + line_backoff_count=0 + index= + # indicate filename for failures + local file_name="${1#"$BASE_PATH"}" + name="File $file_name" + set_color 4 bold + buffer "%s\n" "$file_name" + clear_color +} + +line_backoff_count=0 +bats_tap_stream_unknown() { # + local scope=$2 + # count the lines we printed after the begin text, (or after suite, in case of syntax errors) + if [[ $line_backoff_count -eq 0 && ($scope == begin || $scope == suite) ]]; then + # if this is the first line after begin, go down one line + buffer "\n" + ((++line_backoff_count)) # prefix-increment to avoid "error" due to returning 0 + fi + + ((++line_backoff_count)) + ((line_backoff_count += ${#1} / screen_width)) # account for linebreaks due to length + buffer "%s\n" "$1" + flush +} + +bats_parse_internal_extended_tap + +summary From d34fb7e8a85deaa31697290cf583824911fa6913 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Tue, 13 Feb 2024 14:22:19 +0100 Subject: [PATCH 08/20] log processor: share apiclient in output goroutines (#2836) --- .golangci.yml | 10 ++- cmd/crowdsec/api.go | 4 +- cmd/crowdsec/crowdsec.go | 56 ++++++++++++--- cmd/crowdsec/lapiclient.go | 92 +++++++++++++++++++++++++ cmd/crowdsec/metrics.go | 13 ++-- cmd/crowdsec/output.go | 105 
+++++------------------------ cmd/crowdsec/run_in_svc.go | 12 ++-- cmd/crowdsec/run_in_svc_windows.go | 7 +- cmd/crowdsec/serve.go | 27 ++++++-- test/bats/01_crowdsec.bats | 3 + test/bats/40_live-ban.bats | 21 ++++-- 11 files changed, 229 insertions(+), 121 deletions(-) create mode 100644 cmd/crowdsec/lapiclient.go diff --git a/.golangci.yml b/.golangci.yml index 3161b2c0aaf..e605ac079d4 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -11,7 +11,7 @@ run: linters-settings: cyclop: # lower this after refactoring - max-complexity: 70 + max-complexity: 53 gci: sections: @@ -26,7 +26,7 @@ linters-settings: gocyclo: # lower this after refactoring - min-complexity: 70 + min-complexity: 49 funlen: # Checks the number of lines in a function. @@ -46,7 +46,7 @@ linters-settings: maintidx: # raise this after refactoring - under: 9 + under: 11 misspell: locale: US @@ -263,6 +263,10 @@ issues: - perfsprint text: "fmt.Sprintf can be replaced .*" + - linters: + - perfsprint + text: "fmt.Errorf can be replaced with errors.New" + # # Will fix, easy but some neurons required # diff --git a/cmd/crowdsec/api.go b/cmd/crowdsec/api.go index a1e933cba89..4ac5c3ce96f 100644 --- a/cmd/crowdsec/api.go +++ b/cmd/crowdsec/api.go @@ -56,7 +56,8 @@ func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) { return apiServer, nil } -func serveAPIServer(apiServer *apiserver.APIServer, apiReady chan bool) { +func serveAPIServer(apiServer *apiserver.APIServer) { + apiReady := make(chan bool, 1) apiTomb.Go(func() error { defer trace.CatchPanic("crowdsec/serveAPIServer") go func() { @@ -80,6 +81,7 @@ func serveAPIServer(apiServer *apiserver.APIServer, apiReady chan bool) { } return nil }) + <-apiReady } func hasPlugins(profiles []*csconfig.ProfileCfg) bool { diff --git a/cmd/crowdsec/crowdsec.go b/cmd/crowdsec/crowdsec.go index 774b9d381cf..d4cd2d3cf74 100644 --- a/cmd/crowdsec/crowdsec.go +++ b/cmd/crowdsec/crowdsec.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os" "path/filepath" @@ -13,8 +14,8 @@ import ( "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/acquisition" - "github.com/crowdsecurity/crowdsec/pkg/appsec" "github.com/crowdsecurity/crowdsec/pkg/alertcontext" + "github.com/crowdsecurity/crowdsec/pkg/appsec" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" leaky "github.com/crowdsecurity/crowdsec/pkg/leakybucket" @@ -56,63 +57,86 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H //start go-routines for parsing, buckets pour and outputs. 
parserWg := &sync.WaitGroup{} + parsersTomb.Go(func() error { parserWg.Add(1) + for i := 0; i < cConfig.Crowdsec.ParserRoutinesCount; i++ { parsersTomb.Go(func() error { defer trace.CatchPanic("crowdsec/runParse") + if err := runParse(inputLineChan, inputEventChan, *parsers.Ctx, parsers.Nodes); err != nil { //this error will never happen as parser.Parse is not able to return errors log.Fatalf("starting parse error : %s", err) return err } + return nil }) } parserWg.Done() + return nil }) parserWg.Wait() bucketWg := &sync.WaitGroup{} + bucketsTomb.Go(func() error { bucketWg.Add(1) /*restore previous state as well if present*/ if cConfig.Crowdsec.BucketStateFile != "" { log.Warningf("Restoring buckets state from %s", cConfig.Crowdsec.BucketStateFile) + if err := leaky.LoadBucketsState(cConfig.Crowdsec.BucketStateFile, buckets, holders); err != nil { - return fmt.Errorf("unable to restore buckets : %s", err) + return fmt.Errorf("unable to restore buckets: %w", err) } } for i := 0; i < cConfig.Crowdsec.BucketsRoutinesCount; i++ { bucketsTomb.Go(func() error { defer trace.CatchPanic("crowdsec/runPour") + if err := runPour(inputEventChan, holders, buckets, cConfig); err != nil { log.Fatalf("starting pour error : %s", err) return err } + return nil }) } bucketWg.Done() + return nil }) bucketWg.Wait() + apiClient, err := AuthenticatedLAPIClient(*cConfig.API.Client.Credentials, hub) + if err != nil { + return err + } + + log.Debugf("Starting HeartBeat service") + apiClient.HeartBeat.StartHeartBeat(context.Background(), &outputsTomb) + outputWg := &sync.WaitGroup{} + outputsTomb.Go(func() error { outputWg.Add(1) + for i := 0; i < cConfig.Crowdsec.OutputRoutinesCount; i++ { outputsTomb.Go(func() error { defer trace.CatchPanic("crowdsec/runOutput") - if err := runOutput(inputEventChan, outputEventChan, buckets, *parsers.Povfwctx, parsers.Povfwnodes, *cConfig.API.Client.Credentials, hub); err != nil { + + if err := runOutput(inputEventChan, outputEventChan, buckets, *parsers.Povfwctx, parsers.Povfwnodes, apiClient); err != nil { log.Fatalf("starting outputs error : %s", err) return err } + return nil }) } outputWg.Done() + return nil }) outputWg.Wait() @@ -122,16 +146,16 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H if cConfig.Prometheus.Level == "aggregated" { aggregated = true } + if err := acquisition.GetMetrics(dataSources, aggregated); err != nil { return fmt.Errorf("while fetching prometheus metrics for datasources: %w", err) } - } + log.Info("Starting processing data") if err := acquisition.StartAcquisition(dataSources, inputLineChan, &acquisTomb); err != nil { - log.Fatalf("starting acquisition error : %s", err) - return err + return fmt.Errorf("starting acquisition error: %w", err) } return nil @@ -140,11 +164,13 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H func serveCrowdsec(parsers *parser.Parsers, cConfig *csconfig.Config, hub *cwhub.Hub, agentReady chan bool) { crowdsecTomb.Go(func() error { defer trace.CatchPanic("crowdsec/serveCrowdsec") + go func() { defer trace.CatchPanic("crowdsec/runCrowdsec") // this logs every time, even at config reload log.Debugf("running agent after %s ms", time.Since(crowdsecT0)) agentReady <- true + if err := runCrowdsec(cConfig, parsers, hub); err != nil { log.Fatalf("unable to start crowdsec routines: %s", err) } @@ -156,16 +182,20 @@ func serveCrowdsec(parsers *parser.Parsers, cConfig *csconfig.Config, hub *cwhub */ waitOnTomb() log.Debugf("Shutting down crowdsec routines") + if 
err := ShutdownCrowdsecRoutines(); err != nil { log.Fatalf("unable to shutdown crowdsec routines: %s", err) } + log.Debugf("everything is dead, return crowdsecTomb") + if dumpStates { dumpParserState() dumpOverflowState() dumpBucketsPour() os.Exit(0) } + return nil }) } @@ -175,55 +205,65 @@ func dumpBucketsPour() { if err != nil { log.Fatalf("open: %s", err) } + out, err := yaml.Marshal(leaky.BucketPourCache) if err != nil { log.Fatalf("marshal: %s", err) } + b, err := fd.Write(out) if err != nil { log.Fatalf("write: %s", err) } + log.Tracef("wrote %d bytes", b) + if err := fd.Close(); err != nil { log.Fatalf(" close: %s", err) } } func dumpParserState() { - fd, err := os.OpenFile(filepath.Join(parser.DumpFolder, "parser-dump.yaml"), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666) if err != nil { log.Fatalf("open: %s", err) } + out, err := yaml.Marshal(parser.StageParseCache) if err != nil { log.Fatalf("marshal: %s", err) } + b, err := fd.Write(out) if err != nil { log.Fatalf("write: %s", err) } + log.Tracef("wrote %d bytes", b) + if err := fd.Close(); err != nil { log.Fatalf(" close: %s", err) } } func dumpOverflowState() { - fd, err := os.OpenFile(filepath.Join(parser.DumpFolder, "bucket-dump.yaml"), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666) if err != nil { log.Fatalf("open: %s", err) } + out, err := yaml.Marshal(bucketOverflows) if err != nil { log.Fatalf("marshal: %s", err) } + b, err := fd.Write(out) if err != nil { log.Fatalf("write: %s", err) } + log.Tracef("wrote %d bytes", b) + if err := fd.Close(); err != nil { log.Fatalf(" close: %s", err) } diff --git a/cmd/crowdsec/lapiclient.go b/cmd/crowdsec/lapiclient.go new file mode 100644 index 00000000000..fd29aa9d99b --- /dev/null +++ b/cmd/crowdsec/lapiclient.go @@ -0,0 +1,92 @@ +package main + +import ( + "context" + "fmt" + "net/url" + "time" + + "github.com/go-openapi/strfmt" + + "github.com/crowdsecurity/go-cs-lib/version" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) { + scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) + if err != nil { + return nil, fmt.Errorf("loading list of installed hub scenarios: %w", err) + } + + appsecRules, err := hub.GetInstalledItemNames(cwhub.APPSEC_RULES) + if err != nil { + return nil, fmt.Errorf("loading list of installed hub appsec rules: %w", err) + } + + installedScenariosAndAppsecRules := make([]string, 0, len(scenarios)+len(appsecRules)) + installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, scenarios...) + installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, appsecRules...) 
+ + apiURL, err := url.Parse(credentials.URL) + if err != nil { + return nil, fmt.Errorf("parsing api url ('%s'): %w", credentials.URL, err) + } + + papiURL, err := url.Parse(credentials.PapiURL) + if err != nil { + return nil, fmt.Errorf("parsing polling api url ('%s'): %w", credentials.PapiURL, err) + } + + password := strfmt.Password(credentials.Password) + + client, err := apiclient.NewClient(&apiclient.Config{ + MachineID: credentials.Login, + Password: password, + Scenarios: installedScenariosAndAppsecRules, + UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), + URL: apiURL, + PapiURL: papiURL, + VersionPrefix: "v1", + UpdateScenario: func() ([]string, error) { + scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) + if err != nil { + return nil, err + } + appsecRules, err := hub.GetInstalledItemNames(cwhub.APPSEC_RULES) + if err != nil { + return nil, err + } + ret := make([]string, 0, len(scenarios)+len(appsecRules)) + ret = append(ret, scenarios...) + ret = append(ret, appsecRules...) + + return ret, nil + }, + }) + if err != nil { + return nil, fmt.Errorf("new client api: %w", err) + } + + authResp, _, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + MachineID: &credentials.Login, + Password: &password, + Scenarios: installedScenariosAndAppsecRules, + }) + if err != nil { + return nil, fmt.Errorf("authenticate watcher (%s): %w", credentials.Login, err) + } + + var expiration time.Time + if err := expiration.UnmarshalText([]byte(authResp.Expire)); err != nil { + return nil, fmt.Errorf("unable to parse jwt expiration: %w", err) + } + + client.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token + client.GetClient().Transport.(*apiclient.JWTTransport).Expiration = expiration + + return client, nil +} diff --git a/cmd/crowdsec/metrics.go b/cmd/crowdsec/metrics.go index fa2d8d5de32..1199af0fe16 100644 --- a/cmd/crowdsec/metrics.go +++ b/cmd/crowdsec/metrics.go @@ -114,13 +114,17 @@ func computeDynamicMetrics(next http.Handler, dbClient *database.Client) http.Ha } decisionsFilters := make(map[string][]string, 0) + decisions, err := dbClient.QueryDecisionCountByScenario(decisionsFilters) if err != nil { log.Errorf("Error querying decisions for metrics: %v", err) next.ServeHTTP(w, r) + return } + globalActiveDecisions.Reset() + for _, d := range decisions { globalActiveDecisions.With(prometheus.Labels{"reason": d.Scenario, "origin": d.Origin, "action": d.Type}).Set(float64(d.Count)) } @@ -136,6 +140,7 @@ func computeDynamicMetrics(next http.Handler, dbClient *database.Client) http.Ha if err != nil { log.Errorf("Error querying alerts for metrics: %v", err) next.ServeHTTP(w, r) + return } @@ -173,11 +178,12 @@ func registerPrometheus(config *csconfig.PrometheusCfg) { globalActiveDecisions, globalAlerts, parser.NodesWlHitsOk, parser.NodesWlHits, cache.CacheMetrics, exprhelpers.RegexpCacheMetrics, ) - } } -func servePrometheus(config *csconfig.PrometheusCfg, dbClient *database.Client, apiReady chan bool, agentReady chan bool) { +func servePrometheus(config *csconfig.PrometheusCfg, dbClient *database.Client, agentReady chan bool) { + <-agentReady + if !config.Enabled { return } @@ -185,9 +191,8 @@ func servePrometheus(config *csconfig.PrometheusCfg, dbClient *database.Client, defer trace.CatchPanic("crowdsec/servePrometheus") http.Handle("/metrics", computeDynamicMetrics(promhttp.Handler(), dbClient)) - <-apiReady - <-agentReady log.Debugf("serving metrics after %s ms", time.Since(crowdsecT0)) + if err := 
http.ListenAndServe(fmt.Sprintf("%s:%d", config.ListenAddr, config.ListenPort), nil); err != nil { log.Warningf("prometheus: %s", err) } diff --git a/cmd/crowdsec/output.go b/cmd/crowdsec/output.go index ad53ce4c827..c4a2c0b6ac1 100644 --- a/cmd/crowdsec/output.go +++ b/cmd/crowdsec/output.go @@ -3,18 +3,12 @@ package main import ( "context" "fmt" - "net/url" "sync" "time" - "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" - "github.com/crowdsecurity/go-cs-lib/version" - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" leaky "github.com/crowdsecurity/crowdsec/pkg/leakybucket" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/parser" @@ -22,7 +16,6 @@ import ( ) func dedupAlerts(alerts []types.RuntimeAlert) ([]*models.Alert, error) { - var dedupCache []*models.Alert for idx, alert := range alerts { @@ -32,16 +25,21 @@ func dedupAlerts(alerts []types.RuntimeAlert) ([]*models.Alert, error) { dedupCache = append(dedupCache, alert.Alert) continue } + for k, src := range alert.Sources { refsrc := *alert.Alert //copy + log.Tracef("source[%s]", k) + refsrc.Source = &src dedupCache = append(dedupCache, &refsrc) } } + if len(dedupCache) != len(alerts) { log.Tracef("went from %d to %d alerts", len(alerts), len(dedupCache)) } + return dedupCache, nil } @@ -52,93 +50,25 @@ func PushAlerts(alerts []types.RuntimeAlert, client *apiclient.ApiClient) error if err != nil { return fmt.Errorf("failed to transform alerts for api: %w", err) } + _, _, err = client.Alerts.Add(ctx, alertsToPush) if err != nil { return fmt.Errorf("failed sending alert to LAPI: %w", err) } + return nil } var bucketOverflows []types.Event -func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky.Buckets, - postOverflowCTX parser.UnixParserCtx, postOverflowNodes []parser.Node, - apiConfig csconfig.ApiCredentialsCfg, hub *cwhub.Hub) error { +func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky.Buckets, postOverflowCTX parser.UnixParserCtx, + postOverflowNodes []parser.Node, client *apiclient.ApiClient) error { + var ( + cache []types.RuntimeAlert + cacheMutex sync.Mutex + ) - var err error ticker := time.NewTicker(1 * time.Second) - - var cache []types.RuntimeAlert - var cacheMutex sync.Mutex - - scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) - if err != nil { - return fmt.Errorf("loading list of installed hub scenarios: %w", err) - } - - appsecRules, err := hub.GetInstalledItemNames(cwhub.APPSEC_RULES) - if err != nil { - return fmt.Errorf("loading list of installed hub appsec rules: %w", err) - } - - installedScenariosAndAppsecRules := make([]string, 0, len(scenarios)+len(appsecRules)) - installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, scenarios...) - installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, appsecRules...) 
- - apiURL, err := url.Parse(apiConfig.URL) - if err != nil { - return fmt.Errorf("parsing api url ('%s'): %w", apiConfig.URL, err) - } - papiURL, err := url.Parse(apiConfig.PapiURL) - if err != nil { - return fmt.Errorf("parsing polling api url ('%s'): %w", apiConfig.PapiURL, err) - } - password := strfmt.Password(apiConfig.Password) - - Client, err := apiclient.NewClient(&apiclient.Config{ - MachineID: apiConfig.Login, - Password: password, - Scenarios: installedScenariosAndAppsecRules, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiURL, - PapiURL: papiURL, - VersionPrefix: "v1", - UpdateScenario: func() ([]string, error) { - scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) - if err != nil { - return nil, err - } - appsecRules, err := hub.GetInstalledItemNames(cwhub.APPSEC_RULES) - if err != nil { - return nil, err - } - ret := make([]string, 0, len(scenarios)+len(appsecRules)) - ret = append(ret, scenarios...) - ret = append(ret, appsecRules...) - return ret, nil - }, - }) - if err != nil { - return fmt.Errorf("new client api: %w", err) - } - authResp, _, err := Client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ - MachineID: &apiConfig.Login, - Password: &password, - Scenarios: installedScenariosAndAppsecRules, - }) - if err != nil { - return fmt.Errorf("authenticate watcher (%s): %w", apiConfig.Login, err) - } - - if err := Client.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil { - return fmt.Errorf("unable to parse jwt expiration: %w", err) - } - - Client.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token - - //start the heartbeat service - log.Debugf("Starting HeartBeat service") - Client.HeartBeat.StartHeartBeat(context.Background(), &outputsTomb) LOOP: for { select { @@ -149,7 +79,7 @@ LOOP: newcache := make([]types.RuntimeAlert, 0) cache = newcache cacheMutex.Unlock() - if err := PushAlerts(cachecopy, Client); err != nil { + if err := PushAlerts(cachecopy, client); err != nil { log.Errorf("while pushing to api : %s", err) //just push back the events to the queue cacheMutex.Lock() @@ -162,10 +92,11 @@ LOOP: cacheMutex.Lock() cachecopy := cache cacheMutex.Unlock() - if err := PushAlerts(cachecopy, Client); err != nil { + if err := PushAlerts(cachecopy, client); err != nil { log.Errorf("while pushing leftovers to api : %s", err) } } + break LOOP case event := <-overflow: /*if alert is empty and mapKey is present, the overflow is just to cleanup bucket*/ @@ -176,7 +107,7 @@ LOOP: /* process post overflow parser nodes */ event, err := parser.Parse(postOverflowCTX, event, postOverflowNodes) if err != nil { - return fmt.Errorf("postoverflow failed : %s", err) + return fmt.Errorf("postoverflow failed: %w", err) } log.Printf("%s", *event.Overflow.Alert.Message) //if the Alert is nil, it's to signal bucket is ready for GC, don't track this @@ -206,6 +137,6 @@ LOOP: } ticker.Stop() - return nil + return nil } diff --git a/cmd/crowdsec/run_in_svc.go b/cmd/crowdsec/run_in_svc.go index 2020537908d..5a8bc9a6cd3 100644 --- a/cmd/crowdsec/run_in_svc.go +++ b/cmd/crowdsec/run_in_svc.go @@ -33,7 +33,6 @@ func StartRunSvc() error { log.Infof("Crowdsec %s", version.String()) - apiReady := make(chan bool, 1) agentReady := make(chan bool, 1) // Enable profiling early @@ -46,14 +45,19 @@ func StartRunSvc() error { dbClient, err = database.NewClient(cConfig.DbConfig) if err != nil { - return fmt.Errorf("unable to create database client: %s", err) + return 
fmt.Errorf("unable to create database client: %w", err) } } registerPrometheus(cConfig.Prometheus) - go servePrometheus(cConfig.Prometheus, dbClient, apiReady, agentReady) + go servePrometheus(cConfig.Prometheus, dbClient, agentReady) + } else { + // avoid leaking the channel + go func() { + <-agentReady + }() } - return Serve(cConfig, apiReady, agentReady) + return Serve(cConfig, agentReady) } diff --git a/cmd/crowdsec/run_in_svc_windows.go b/cmd/crowdsec/run_in_svc_windows.go index 991f7ae4491..7845e9c58b5 100644 --- a/cmd/crowdsec/run_in_svc_windows.go +++ b/cmd/crowdsec/run_in_svc_windows.go @@ -73,7 +73,6 @@ func WindowsRun() error { log.Infof("Crowdsec %s", version.String()) - apiReady := make(chan bool, 1) agentReady := make(chan bool, 1) // Enable profiling early @@ -85,11 +84,11 @@ func WindowsRun() error { dbClient, err = database.NewClient(cConfig.DbConfig) if err != nil { - return fmt.Errorf("unable to create database client: %s", err) + return fmt.Errorf("unable to create database client: %w", err) } } registerPrometheus(cConfig.Prometheus) - go servePrometheus(cConfig.Prometheus, dbClient, apiReady, agentReady) + go servePrometheus(cConfig.Prometheus, dbClient, agentReady) } - return Serve(cConfig, apiReady, agentReady) + return Serve(cConfig, agentReady) } diff --git a/cmd/crowdsec/serve.go b/cmd/crowdsec/serve.go index a5c8e24cf3f..22f65b927a0 100644 --- a/cmd/crowdsec/serve.go +++ b/cmd/crowdsec/serve.go @@ -42,7 +42,9 @@ func debugHandler(sig os.Signal, cConfig *csconfig.Config) error { if err := leaky.ShutdownAllBuckets(buckets); err != nil { log.Warningf("Failed to shut down routines : %s", err) } + log.Printf("Shutdown is finished, buckets are in %s", tmpFile) + return nil } @@ -66,15 +68,16 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) { if !cConfig.DisableAPI { if flags.DisableCAPI { log.Warningf("Communication with CrowdSec Central API disabled from args") + cConfig.API.Server.OnlineClient = nil } + apiServer, err := initAPIServer(cConfig) if err != nil { return nil, fmt.Errorf("unable to init api server: %w", err) } - apiReady := make(chan bool, 1) - serveAPIServer(apiServer, apiReady) + serveAPIServer(apiServer) } if !cConfig.DisableAgent { @@ -110,6 +113,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) { log.Warningf("Failed to delete temp file (%s) : %s", tmpFile, err) } } + return cConfig, nil } @@ -117,10 +121,12 @@ func ShutdownCrowdsecRoutines() error { var reterr error log.Debugf("Shutting down crowdsec sub-routines") + if len(dataSources) > 0 { acquisTomb.Kill(nil) log.Debugf("waiting for acquisition to finish") drainChan(inputLineChan) + if err := acquisTomb.Wait(); err != nil { log.Warningf("Acquisition returned error : %s", err) reterr = err @@ -130,6 +136,7 @@ func ShutdownCrowdsecRoutines() error { log.Debugf("acquisition is finished, wait for parser/bucket/ouputs.") parsersTomb.Kill(nil) drainChan(inputEventChan) + if err := parsersTomb.Wait(); err != nil { log.Warningf("Parsers returned error : %s", err) reterr = err @@ -160,6 +167,7 @@ func ShutdownCrowdsecRoutines() error { log.Warningf("Outputs returned error : %s", err) reterr = err } + log.Debugf("outputs are done") case <-time.After(3 * time.Second): // this can happen if outputs are stuck in a http retry loop @@ -181,6 +189,7 @@ func shutdownAPI() error { } log.Debugf("done") + return nil } @@ -193,6 +202,7 @@ func shutdownCrowdsec() error { } log.Debugf("done") + return nil } @@ -292,10 +302,11 @@ func HandleSignals(cConfig *csconfig.Config) error { if err == 
nil { log.Warning("Crowdsec service shutting down") } + return err } -func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) error { +func Serve(cConfig *csconfig.Config, agentReady chan bool) error { acquisTomb = tomb.Tomb{} parsersTomb = tomb.Tomb{} bucketsTomb = tomb.Tomb{} @@ -325,6 +336,7 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e if cConfig.API.CTI != nil && *cConfig.API.CTI.Enabled { log.Infof("Crowdsec CTI helper enabled") + if err := exprhelpers.InitCrowdsecCTI(cConfig.API.CTI.Key, cConfig.API.CTI.CacheTimeout, cConfig.API.CTI.CacheSize, cConfig.API.CTI.LogLevel); err != nil { return fmt.Errorf("failed to init crowdsec cti: %w", err) } @@ -337,6 +349,7 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e if flags.DisableCAPI { log.Warningf("Communication with CrowdSec Central API disabled from args") + cConfig.API.Server.OnlineClient = nil } @@ -346,10 +359,8 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e } if !flags.TestMode { - serveAPIServer(apiServer, apiReady) + serveAPIServer(apiServer) } - } else { - apiReady <- true } if !cConfig.DisableAgent { @@ -366,6 +377,8 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e // if it's just linting, we're done if !flags.TestMode { serveCrowdsec(csParsers, cConfig, hub, agentReady) + } else { + agentReady <- true } } else { agentReady <- true @@ -395,6 +408,7 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e for _, ch := range waitChans { <-ch + switch ch { case apiTomb.Dead(): log.Infof("api shutdown") @@ -402,5 +416,6 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e log.Infof("crowdsec shutdown") } } + return nil } diff --git a/test/bats/01_crowdsec.bats b/test/bats/01_crowdsec.bats index be06ac9261a..a585930e34c 100644 --- a/test/bats/01_crowdsec.bats +++ b/test/bats/01_crowdsec.bats @@ -75,6 +75,9 @@ teardown() { rune -0 ./instance-crowdsec start-pid PID="$output" + + sleep .5 + assert_file_exists "$log_old" assert_file_contains "$log_old" "Starting processing data" diff --git a/test/bats/40_live-ban.bats b/test/bats/40_live-ban.bats index c6b8ddf1563..a544f67be18 100644 --- a/test/bats/40_live-ban.bats +++ b/test/bats/40_live-ban.bats @@ -41,10 +41,23 @@ teardown() { echo -e "---\nfilename: ${tmpfile}\nlabels:\n type: syslog\n" >>"${ACQUIS_YAML}" ./instance-crowdsec start + + sleep 0.2 + fake_log >>"${tmpfile}" - sleep 2 + + sleep 0.2 + rm -f -- "${tmpfile}" - rune -0 cscli decisions list -o json - rune -0 jq -r '.[].decisions[0].value' <(output) - assert_output '1.1.1.172' + + found=0 + # this may take some time in CI + for _ in $(seq 1 10); do + if cscli decisions list -o json | jq -r '.[].decisions[0].value' | grep -q '1.1.1.172'; then + found=1 + break + fi + sleep 0.2 + done + assert_equal 1 "${found}" } From 45571cea08591962b515ed903b0b00488a4f7c13 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Wed, 14 Feb 2024 09:47:12 +0100 Subject: [PATCH 09/20] use go 1.21.7 (#2830) --- .github/workflows/bats-hub.yml | 2 +- .github/workflows/bats-mysql.yml | 2 +- .github/workflows/bats-postgres.yml | 2 +- .github/workflows/bats-sqlite-coverage.yml | 2 +- .github/workflows/ci-windows-build-msi.yml | 2 +- .github/workflows/codeql-analysis.yml | 3 ++- .github/workflows/go-tests-windows.yml | 2 +- .github/workflows/go-tests.yml | 2 +- 
.github/workflows/publish-tarball-release.yml | 2 +- Dockerfile | 2 +- Dockerfile.debian | 2 +- azure-pipelines.yml | 2 +- 12 files changed, 13 insertions(+), 12 deletions(-) diff --git a/.github/workflows/bats-hub.yml b/.github/workflows/bats-hub.yml index 7764da84812..075480485ff 100644 --- a/.github/workflows/bats-hub.yml +++ b/.github/workflows/bats-hub.yml @@ -33,7 +33,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.21.7" - name: "Install bats dependencies" env: diff --git a/.github/workflows/bats-mysql.yml b/.github/workflows/bats-mysql.yml index 243da6eb25d..5c019933304 100644 --- a/.github/workflows/bats-mysql.yml +++ b/.github/workflows/bats-mysql.yml @@ -36,7 +36,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.21.7" - name: "Install bats dependencies" env: diff --git a/.github/workflows/bats-postgres.yml b/.github/workflows/bats-postgres.yml index 07d3cd8d2f1..0f3c69ccefa 100644 --- a/.github/workflows/bats-postgres.yml +++ b/.github/workflows/bats-postgres.yml @@ -45,7 +45,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.21.7" - name: "Install bats dependencies" env: diff --git a/.github/workflows/bats-sqlite-coverage.yml b/.github/workflows/bats-sqlite-coverage.yml index 46a5dd8bc86..436eb0f04a4 100644 --- a/.github/workflows/bats-sqlite-coverage.yml +++ b/.github/workflows/bats-sqlite-coverage.yml @@ -28,7 +28,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.21.7" - name: "Install bats dependencies" env: diff --git a/.github/workflows/ci-windows-build-msi.yml b/.github/workflows/ci-windows-build-msi.yml index 26c981143ad..7c6a6621de4 100644 --- a/.github/workflows/ci-windows-build-msi.yml +++ b/.github/workflows/ci-windows-build-msi.yml @@ -35,7 +35,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.21.7" - name: Build run: make windows_installer BUILD_RE2_WASM=1 diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 4b262f13d09..bdc16e650f6 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -48,10 +48,11 @@ jobs: with: # required to pick up tags for BUILD_VERSION fetch-depth: 0 + - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.21.7" cache-dependency-path: "**/go.sum" # Initializes the CodeQL tools for scanning. 
diff --git a/.github/workflows/go-tests-windows.yml b/.github/workflows/go-tests-windows.yml index 63781a7b25e..efe16ed66d9 100644 --- a/.github/workflows/go-tests-windows.yml +++ b/.github/workflows/go-tests-windows.yml @@ -34,7 +34,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.21.7" - name: Build run: | diff --git a/.github/workflows/go-tests.yml b/.github/workflows/go-tests.yml index e8840c07f4e..865b2782a63 100644 --- a/.github/workflows/go-tests.yml +++ b/.github/workflows/go-tests.yml @@ -126,7 +126,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.21.7" - name: Create localstack streams run: | diff --git a/.github/workflows/publish-tarball-release.yml b/.github/workflows/publish-tarball-release.yml index 202882791e7..d251677fd46 100644 --- a/.github/workflows/publish-tarball-release.yml +++ b/.github/workflows/publish-tarball-release.yml @@ -25,7 +25,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.21.7" - name: Build the binaries run: | diff --git a/Dockerfile b/Dockerfile index 2369c09dfa6..420c521fa58 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # vim: set ft=dockerfile: -FROM golang:1.21.6-alpine3.18 AS build +FROM golang:1.21.7-alpine3.18 AS build ARG BUILD_VERSION diff --git a/Dockerfile.debian b/Dockerfile.debian index ba0cd20fb43..48753e7acdb 100644 --- a/Dockerfile.debian +++ b/Dockerfile.debian @@ -1,5 +1,5 @@ # vim: set ft=dockerfile: -FROM golang:1.21.6-bookworm AS build +FROM golang:1.21.7-bookworm AS build ARG BUILD_VERSION diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 82caba42bae..791f41f50ba 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -27,7 +27,7 @@ stages: - task: GoTool@0 displayName: "Install Go" inputs: - version: '1.21.6' + version: '1.21.7' - pwsh: | choco install -y make From 2bbf0b4762ad58f5e50858132085ac4586502008 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Wed, 14 Feb 2024 11:19:13 +0100 Subject: [PATCH 10/20] re-generate ent code (#2844) --- go.sum | 2 + pkg/database/ent/alert.go | 28 +- pkg/database/ent/alert/alert.go | 203 +++ pkg/database/ent/alert/where.go | 1525 +++++---------------- pkg/database/ent/alert_create.go | 221 +-- pkg/database/ent/alert_delete.go | 51 +- pkg/database/ent/alert_query.go | 279 ++-- pkg/database/ent/alert_update.go | 738 ++-------- pkg/database/ent/bouncer.go | 22 +- pkg/database/ent/bouncer/bouncer.go | 65 + pkg/database/ent/bouncer/where.go | 687 ++-------- pkg/database/ent/bouncer_create.go | 129 +- pkg/database/ent/bouncer_delete.go | 51 +- pkg/database/ent/bouncer_query.go | 239 ++-- pkg/database/ent/bouncer_update.go | 286 +--- pkg/database/ent/client.go | 466 ++++++- pkg/database/ent/config.go | 65 - pkg/database/ent/configitem.go | 22 +- pkg/database/ent/configitem/configitem.go | 30 + pkg/database/ent/configitem/where.go | 299 +--- pkg/database/ent/configitem_create.go | 87 +- pkg/database/ent/configitem_delete.go | 51 +- pkg/database/ent/configitem_query.go | 239 ++-- pkg/database/ent/configitem_update.go | 162 +-- pkg/database/ent/context.go | 33 - pkg/database/ent/decision.go | 24 +- pkg/database/ent/decision/decision.go | 105 ++ pkg/database/ent/decision/where.go | 930 +++---------- pkg/database/ent/decision_create.go | 158 +-- pkg/database/ent/decision_delete.go | 51 +- pkg/database/ent/decision_query.go | 249 ++-- pkg/database/ent/decision_update.go | 444 +----- 
pkg/database/ent/ent.go | 233 +++- pkg/database/ent/event.go | 24 +- pkg/database/ent/event/event.go | 50 + pkg/database/ent/event/where.go | 322 +---- pkg/database/ent/event_create.go | 92 +- pkg/database/ent/event_delete.go | 51 +- pkg/database/ent/event_query.go | 249 ++-- pkg/database/ent/event_update.go | 196 +-- pkg/database/ent/hook/hook.go | 49 +- pkg/database/ent/machine.go | 24 +- pkg/database/ent/machine/machine.go | 92 ++ pkg/database/ent/machine/where.go | 766 +++-------- pkg/database/ent/machine_create.go | 140 +- pkg/database/ent/machine_delete.go | 51 +- pkg/database/ent/machine_query.go | 247 ++-- pkg/database/ent/machine_update.go | 352 +---- pkg/database/ent/meta.go | 24 +- pkg/database/ent/meta/meta.go | 50 + pkg/database/ent/meta/where.go | 342 +---- pkg/database/ent/meta_create.go | 92 +- pkg/database/ent/meta_delete.go | 51 +- pkg/database/ent/meta_query.go | 249 ++-- pkg/database/ent/meta_update.go | 196 +-- pkg/database/ent/mutation.go | 112 +- pkg/database/ent/runtime/runtime.go | 4 +- pkg/database/ent/tx.go | 36 +- 58 files changed, 4026 insertions(+), 8009 deletions(-) delete mode 100644 pkg/database/ent/config.go delete mode 100644 pkg/database/ent/context.go diff --git a/go.sum b/go.sum index 8fa2021316b..2daf22cc99c 100644 --- a/go.sum +++ b/go.sum @@ -542,6 +542,8 @@ github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/oklog/run v1.0.0 h1:Ru7dDtJNOyC66gQ5dQmaCa0qIsAUFY3sFpK1Xk8igrw= github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 h1:rc3tiVYb5z54aKaDfakKn0dDjIyPpTtszkjuMzyt7ec= diff --git a/pkg/database/ent/alert.go b/pkg/database/ent/alert.go index 2649923bf5e..5cb4d1a352c 100644 --- a/pkg/database/ent/alert.go +++ b/pkg/database/ent/alert.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" @@ -67,6 +68,7 @@ type Alert struct { // The values are being populated by the AlertQuery when eager-loading is set. Edges AlertEdges `json:"edges"` machine_alerts *int + selectValues sql.SelectValues } // AlertEdges holds the relations/edges for other nodes in the graph. @@ -142,7 +144,7 @@ func (*Alert) scanValues(columns []string) ([]any, error) { case alert.ForeignKeys[0]: // machine_alerts values[i] = new(sql.NullInt64) default: - return nil, fmt.Errorf("unexpected column %q for type Alert", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -309,36 +311,44 @@ func (a *Alert) assignValues(columns []string, values []any) error { a.machine_alerts = new(int) *a.machine_alerts = int(value.Int64) } + default: + a.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Alert. +// This includes values selected through modifiers, order, etc. 
+func (a *Alert) Value(name string) (ent.Value, error) { + return a.selectValues.Get(name) +} + // QueryOwner queries the "owner" edge of the Alert entity. func (a *Alert) QueryOwner() *MachineQuery { - return (&AlertClient{config: a.config}).QueryOwner(a) + return NewAlertClient(a.config).QueryOwner(a) } // QueryDecisions queries the "decisions" edge of the Alert entity. func (a *Alert) QueryDecisions() *DecisionQuery { - return (&AlertClient{config: a.config}).QueryDecisions(a) + return NewAlertClient(a.config).QueryDecisions(a) } // QueryEvents queries the "events" edge of the Alert entity. func (a *Alert) QueryEvents() *EventQuery { - return (&AlertClient{config: a.config}).QueryEvents(a) + return NewAlertClient(a.config).QueryEvents(a) } // QueryMetas queries the "metas" edge of the Alert entity. func (a *Alert) QueryMetas() *MetaQuery { - return (&AlertClient{config: a.config}).QueryMetas(a) + return NewAlertClient(a.config).QueryMetas(a) } // Update returns a builder for updating this Alert. // Note that you need to call Alert.Unwrap() before calling this method if this Alert // was returned from a transaction, and the transaction was committed or rolled back. func (a *Alert) Update() *AlertUpdateOne { - return (&AlertClient{config: a.config}).UpdateOne(a) + return NewAlertClient(a.config).UpdateOne(a) } // Unwrap unwraps the Alert entity that was returned from a transaction after it was closed, @@ -435,9 +445,3 @@ func (a *Alert) String() string { // Alerts is a parsable slice of Alert. type Alerts []*Alert - -func (a Alerts) config(cfg config) { - for _i := range a { - a[_i].config = cfg - } -} diff --git a/pkg/database/ent/alert/alert.go b/pkg/database/ent/alert/alert.go index abee13fb97a..eb9f1d10788 100644 --- a/pkg/database/ent/alert/alert.go +++ b/pkg/database/ent/alert/alert.go @@ -4,6 +4,9 @@ package alert import ( "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -168,3 +171,203 @@ var ( // DefaultSimulated holds the default value on creation for the "simulated" field. DefaultSimulated bool ) + +// OrderOption defines the ordering options for the Alert queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByScenario orders the results by the scenario field. +func ByScenario(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScenario, opts...).ToFunc() +} + +// ByBucketId orders the results by the bucketId field. +func ByBucketId(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBucketId, opts...).ToFunc() +} + +// ByMessage orders the results by the message field. +func ByMessage(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMessage, opts...).ToFunc() +} + +// ByEventsCountField orders the results by the eventsCount field. +func ByEventsCountField(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEventsCount, opts...).ToFunc() +} + +// ByStartedAt orders the results by the startedAt field. 
+func ByStartedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStartedAt, opts...).ToFunc() +} + +// ByStoppedAt orders the results by the stoppedAt field. +func ByStoppedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStoppedAt, opts...).ToFunc() +} + +// BySourceIp orders the results by the sourceIp field. +func BySourceIp(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceIp, opts...).ToFunc() +} + +// BySourceRange orders the results by the sourceRange field. +func BySourceRange(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceRange, opts...).ToFunc() +} + +// BySourceAsNumber orders the results by the sourceAsNumber field. +func BySourceAsNumber(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceAsNumber, opts...).ToFunc() +} + +// BySourceAsName orders the results by the sourceAsName field. +func BySourceAsName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceAsName, opts...).ToFunc() +} + +// BySourceCountry orders the results by the sourceCountry field. +func BySourceCountry(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceCountry, opts...).ToFunc() +} + +// BySourceLatitude orders the results by the sourceLatitude field. +func BySourceLatitude(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceLatitude, opts...).ToFunc() +} + +// BySourceLongitude orders the results by the sourceLongitude field. +func BySourceLongitude(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceLongitude, opts...).ToFunc() +} + +// BySourceScope orders the results by the sourceScope field. +func BySourceScope(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceScope, opts...).ToFunc() +} + +// BySourceValue orders the results by the sourceValue field. +func BySourceValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceValue, opts...).ToFunc() +} + +// ByCapacity orders the results by the capacity field. +func ByCapacity(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCapacity, opts...).ToFunc() +} + +// ByLeakSpeed orders the results by the leakSpeed field. +func ByLeakSpeed(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLeakSpeed, opts...).ToFunc() +} + +// ByScenarioVersion orders the results by the scenarioVersion field. +func ByScenarioVersion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScenarioVersion, opts...).ToFunc() +} + +// ByScenarioHash orders the results by the scenarioHash field. +func ByScenarioHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScenarioHash, opts...).ToFunc() +} + +// BySimulated orders the results by the simulated field. +func BySimulated(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSimulated, opts...).ToFunc() +} + +// ByUUID orders the results by the uuid field. +func ByUUID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUUID, opts...).ToFunc() +} + +// ByOwnerField orders the results by owner field. +func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...)) + } +} + +// ByDecisionsCount orders the results by decisions count. 
+func ByDecisionsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newDecisionsStep(), opts...) + } +} + +// ByDecisions orders the results by decisions terms. +func ByDecisions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newDecisionsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByEventsCount orders the results by events count. +func ByEventsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newEventsStep(), opts...) + } +} + +// ByEvents orders the results by events terms. +func ByEvents(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newEventsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByMetasCount orders the results by metas count. +func ByMetasCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newMetasStep(), opts...) + } +} + +// ByMetas orders the results by metas terms. +func ByMetas(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newMetasStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newOwnerStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(OwnerInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) +} +func newDecisionsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(DecisionsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, DecisionsTable, DecisionsColumn), + ) +} +func newEventsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(EventsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, EventsTable, EventsColumn), + ) +} +func newMetasStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(MetasInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, MetasTable, MetasColumn), + ) +} diff --git a/pkg/database/ent/alert/where.go b/pkg/database/ent/alert/where.go index ef5b89b615f..516ead50636 100644 --- a/pkg/database/ent/alert/where.go +++ b/pkg/database/ent/alert/where.go @@ -12,2440 +12,1612 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. 
func IDNotIn(ids ...int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldUpdatedAt, v)) } // Scenario applies equality check predicate on the "scenario" field. It's identical to ScenarioEQ. func Scenario(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenario, v)) } // BucketId applies equality check predicate on the "bucketId" field. It's identical to BucketIdEQ. func BucketId(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldBucketId, v)) } // Message applies equality check predicate on the "message" field. It's identical to MessageEQ. func Message(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldMessage, v)) } // EventsCount applies equality check predicate on the "eventsCount" field. It's identical to EventsCountEQ. func EventsCount(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldEventsCount, v)) } // StartedAt applies equality check predicate on the "startedAt" field. It's identical to StartedAtEQ. func StartedAt(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldStartedAt, v)) } // StoppedAt applies equality check predicate on the "stoppedAt" field. It's identical to StoppedAtEQ. 
func StoppedAt(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldStoppedAt, v)) } // SourceIp applies equality check predicate on the "sourceIp" field. It's identical to SourceIpEQ. func SourceIp(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceIp, v)) } // SourceRange applies equality check predicate on the "sourceRange" field. It's identical to SourceRangeEQ. func SourceRange(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceRange, v)) } // SourceAsNumber applies equality check predicate on the "sourceAsNumber" field. It's identical to SourceAsNumberEQ. func SourceAsNumber(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceAsNumber, v)) } // SourceAsName applies equality check predicate on the "sourceAsName" field. It's identical to SourceAsNameEQ. func SourceAsName(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceAsName, v)) } // SourceCountry applies equality check predicate on the "sourceCountry" field. It's identical to SourceCountryEQ. func SourceCountry(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceCountry, v)) } // SourceLatitude applies equality check predicate on the "sourceLatitude" field. It's identical to SourceLatitudeEQ. func SourceLatitude(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceLatitude), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceLatitude, v)) } // SourceLongitude applies equality check predicate on the "sourceLongitude" field. It's identical to SourceLongitudeEQ. func SourceLongitude(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceLongitude, v)) } // SourceScope applies equality check predicate on the "sourceScope" field. It's identical to SourceScopeEQ. func SourceScope(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceScope, v)) } // SourceValue applies equality check predicate on the "sourceValue" field. It's identical to SourceValueEQ. func SourceValue(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceValue, v)) } // Capacity applies equality check predicate on the "capacity" field. It's identical to CapacityEQ. func Capacity(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldCapacity, v)) } // LeakSpeed applies equality check predicate on the "leakSpeed" field. It's identical to LeakSpeedEQ. 
func LeakSpeed(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldLeakSpeed, v)) } // ScenarioVersion applies equality check predicate on the "scenarioVersion" field. It's identical to ScenarioVersionEQ. func ScenarioVersion(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenarioVersion, v)) } // ScenarioHash applies equality check predicate on the "scenarioHash" field. It's identical to ScenarioHashEQ. func ScenarioHash(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenarioHash, v)) } // Simulated applies equality check predicate on the "simulated" field. It's identical to SimulatedEQ. func Simulated(v bool) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSimulated), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSimulated, v)) } // UUID applies equality check predicate on the "uuid" field. It's identical to UUIDEQ. func UUID(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldUUID, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. 
func CreatedAtLT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldCreatedAt, v)) } // CreatedAtIsNil applies the IsNil predicate on the "created_at" field. func CreatedAtIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) + return predicate.Alert(sql.FieldIsNull(FieldCreatedAt)) } // CreatedAtNotNil applies the NotNil predicate on the "created_at" field. func CreatedAtNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Alert(sql.FieldNotNull(FieldCreatedAt)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldUpdatedAt, v)) } // UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. 
func UpdatedAtIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) + return predicate.Alert(sql.FieldIsNull(FieldUpdatedAt)) } // UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. func UpdatedAtNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Alert(sql.FieldNotNull(FieldUpdatedAt)) } // ScenarioEQ applies the EQ predicate on the "scenario" field. func ScenarioEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenario, v)) } // ScenarioNEQ applies the NEQ predicate on the "scenario" field. func ScenarioNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldScenario, v)) } // ScenarioIn applies the In predicate on the "scenario" field. func ScenarioIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScenario), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldScenario, vs...)) } // ScenarioNotIn applies the NotIn predicate on the "scenario" field. func ScenarioNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScenario), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldScenario, vs...)) } // ScenarioGT applies the GT predicate on the "scenario" field. func ScenarioGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldGT(FieldScenario, v)) } // ScenarioGTE applies the GTE predicate on the "scenario" field. func ScenarioGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldScenario, v)) } // ScenarioLT applies the LT predicate on the "scenario" field. func ScenarioLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldLT(FieldScenario, v)) } // ScenarioLTE applies the LTE predicate on the "scenario" field. func ScenarioLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldScenario, v)) } // ScenarioContains applies the Contains predicate on the "scenario" field. func ScenarioContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldContains(FieldScenario, v)) } // ScenarioHasPrefix applies the HasPrefix predicate on the "scenario" field. func ScenarioHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldScenario, v)) } // ScenarioHasSuffix applies the HasSuffix predicate on the "scenario" field. 
func ScenarioHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldScenario, v)) } // ScenarioEqualFold applies the EqualFold predicate on the "scenario" field. func ScenarioEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldScenario, v)) } // ScenarioContainsFold applies the ContainsFold predicate on the "scenario" field. func ScenarioContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldScenario, v)) } // BucketIdEQ applies the EQ predicate on the "bucketId" field. func BucketIdEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldBucketId, v)) } // BucketIdNEQ applies the NEQ predicate on the "bucketId" field. func BucketIdNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldBucketId, v)) } // BucketIdIn applies the In predicate on the "bucketId" field. func BucketIdIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldBucketId), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldBucketId, vs...)) } // BucketIdNotIn applies the NotIn predicate on the "bucketId" field. func BucketIdNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldBucketId), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldBucketId, vs...)) } // BucketIdGT applies the GT predicate on the "bucketId" field. func BucketIdGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldGT(FieldBucketId, v)) } // BucketIdGTE applies the GTE predicate on the "bucketId" field. func BucketIdGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldBucketId, v)) } // BucketIdLT applies the LT predicate on the "bucketId" field. func BucketIdLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldLT(FieldBucketId, v)) } // BucketIdLTE applies the LTE predicate on the "bucketId" field. func BucketIdLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldBucketId, v)) } // BucketIdContains applies the Contains predicate on the "bucketId" field. func BucketIdContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldContains(FieldBucketId, v)) } // BucketIdHasPrefix applies the HasPrefix predicate on the "bucketId" field. 
func BucketIdHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldBucketId, v)) } // BucketIdHasSuffix applies the HasSuffix predicate on the "bucketId" field. func BucketIdHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldBucketId, v)) } // BucketIdIsNil applies the IsNil predicate on the "bucketId" field. func BucketIdIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldBucketId))) - }) + return predicate.Alert(sql.FieldIsNull(FieldBucketId)) } // BucketIdNotNil applies the NotNil predicate on the "bucketId" field. func BucketIdNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldBucketId))) - }) + return predicate.Alert(sql.FieldNotNull(FieldBucketId)) } // BucketIdEqualFold applies the EqualFold predicate on the "bucketId" field. func BucketIdEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldBucketId, v)) } // BucketIdContainsFold applies the ContainsFold predicate on the "bucketId" field. func BucketIdContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldBucketId, v)) } // MessageEQ applies the EQ predicate on the "message" field. func MessageEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldMessage, v)) } // MessageNEQ applies the NEQ predicate on the "message" field. func MessageNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldMessage, v)) } // MessageIn applies the In predicate on the "message" field. func MessageIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldMessage), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldMessage, vs...)) } // MessageNotIn applies the NotIn predicate on the "message" field. func MessageNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldMessage), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldMessage, vs...)) } // MessageGT applies the GT predicate on the "message" field. func MessageGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldGT(FieldMessage, v)) } // MessageGTE applies the GTE predicate on the "message" field. func MessageGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldMessage, v)) } // MessageLT applies the LT predicate on the "message" field. 
func MessageLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldLT(FieldMessage, v)) } // MessageLTE applies the LTE predicate on the "message" field. func MessageLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldMessage, v)) } // MessageContains applies the Contains predicate on the "message" field. func MessageContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldContains(FieldMessage, v)) } // MessageHasPrefix applies the HasPrefix predicate on the "message" field. func MessageHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldMessage, v)) } // MessageHasSuffix applies the HasSuffix predicate on the "message" field. func MessageHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldMessage, v)) } // MessageIsNil applies the IsNil predicate on the "message" field. func MessageIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldMessage))) - }) + return predicate.Alert(sql.FieldIsNull(FieldMessage)) } // MessageNotNil applies the NotNil predicate on the "message" field. func MessageNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldMessage))) - }) + return predicate.Alert(sql.FieldNotNull(FieldMessage)) } // MessageEqualFold applies the EqualFold predicate on the "message" field. func MessageEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldMessage, v)) } // MessageContainsFold applies the ContainsFold predicate on the "message" field. func MessageContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldMessage, v)) } // EventsCountEQ applies the EQ predicate on the "eventsCount" field. func EventsCountEQ(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldEventsCount, v)) } // EventsCountNEQ applies the NEQ predicate on the "eventsCount" field. func EventsCountNEQ(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldEventsCount, v)) } // EventsCountIn applies the In predicate on the "eventsCount" field. func EventsCountIn(vs ...int32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldEventsCount), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldEventsCount, vs...)) } // EventsCountNotIn applies the NotIn predicate on the "eventsCount" field. 
func EventsCountNotIn(vs ...int32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldEventsCount), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldEventsCount, vs...)) } // EventsCountGT applies the GT predicate on the "eventsCount" field. func EventsCountGT(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldGT(FieldEventsCount, v)) } // EventsCountGTE applies the GTE predicate on the "eventsCount" field. func EventsCountGTE(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldEventsCount, v)) } // EventsCountLT applies the LT predicate on the "eventsCount" field. func EventsCountLT(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldLT(FieldEventsCount, v)) } // EventsCountLTE applies the LTE predicate on the "eventsCount" field. func EventsCountLTE(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldEventsCount, v)) } // EventsCountIsNil applies the IsNil predicate on the "eventsCount" field. func EventsCountIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldEventsCount))) - }) + return predicate.Alert(sql.FieldIsNull(FieldEventsCount)) } // EventsCountNotNil applies the NotNil predicate on the "eventsCount" field. func EventsCountNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldEventsCount))) - }) + return predicate.Alert(sql.FieldNotNull(FieldEventsCount)) } // StartedAtEQ applies the EQ predicate on the "startedAt" field. func StartedAtEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldStartedAt, v)) } // StartedAtNEQ applies the NEQ predicate on the "startedAt" field. func StartedAtNEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldStartedAt, v)) } // StartedAtIn applies the In predicate on the "startedAt" field. func StartedAtIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldStartedAt), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldStartedAt, vs...)) } // StartedAtNotIn applies the NotIn predicate on the "startedAt" field. func StartedAtNotIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldStartedAt), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldStartedAt, vs...)) } // StartedAtGT applies the GT predicate on the "startedAt" field. 
func StartedAtGT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldGT(FieldStartedAt, v)) } // StartedAtGTE applies the GTE predicate on the "startedAt" field. func StartedAtGTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldStartedAt, v)) } // StartedAtLT applies the LT predicate on the "startedAt" field. func StartedAtLT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldLT(FieldStartedAt, v)) } // StartedAtLTE applies the LTE predicate on the "startedAt" field. func StartedAtLTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldStartedAt, v)) } // StartedAtIsNil applies the IsNil predicate on the "startedAt" field. func StartedAtIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldStartedAt))) - }) + return predicate.Alert(sql.FieldIsNull(FieldStartedAt)) } // StartedAtNotNil applies the NotNil predicate on the "startedAt" field. func StartedAtNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldStartedAt))) - }) + return predicate.Alert(sql.FieldNotNull(FieldStartedAt)) } // StoppedAtEQ applies the EQ predicate on the "stoppedAt" field. func StoppedAtEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldStoppedAt, v)) } // StoppedAtNEQ applies the NEQ predicate on the "stoppedAt" field. func StoppedAtNEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldStoppedAt, v)) } // StoppedAtIn applies the In predicate on the "stoppedAt" field. func StoppedAtIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldStoppedAt), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldStoppedAt, vs...)) } // StoppedAtNotIn applies the NotIn predicate on the "stoppedAt" field. func StoppedAtNotIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldStoppedAt), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldStoppedAt, vs...)) } // StoppedAtGT applies the GT predicate on the "stoppedAt" field. func StoppedAtGT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldGT(FieldStoppedAt, v)) } // StoppedAtGTE applies the GTE predicate on the "stoppedAt" field. func StoppedAtGTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldStoppedAt, v)) } // StoppedAtLT applies the LT predicate on the "stoppedAt" field. 
func StoppedAtLT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldLT(FieldStoppedAt, v)) } // StoppedAtLTE applies the LTE predicate on the "stoppedAt" field. func StoppedAtLTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldStoppedAt, v)) } // StoppedAtIsNil applies the IsNil predicate on the "stoppedAt" field. func StoppedAtIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldStoppedAt))) - }) + return predicate.Alert(sql.FieldIsNull(FieldStoppedAt)) } // StoppedAtNotNil applies the NotNil predicate on the "stoppedAt" field. func StoppedAtNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldStoppedAt))) - }) + return predicate.Alert(sql.FieldNotNull(FieldStoppedAt)) } // SourceIpEQ applies the EQ predicate on the "sourceIp" field. func SourceIpEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceIp, v)) } // SourceIpNEQ applies the NEQ predicate on the "sourceIp" field. func SourceIpNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceIp, v)) } // SourceIpIn applies the In predicate on the "sourceIp" field. func SourceIpIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceIp), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceIp, vs...)) } // SourceIpNotIn applies the NotIn predicate on the "sourceIp" field. func SourceIpNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceIp), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceIp, vs...)) } // SourceIpGT applies the GT predicate on the "sourceIp" field. func SourceIpGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceIp, v)) } // SourceIpGTE applies the GTE predicate on the "sourceIp" field. func SourceIpGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceIp, v)) } // SourceIpLT applies the LT predicate on the "sourceIp" field. func SourceIpLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceIp, v)) } // SourceIpLTE applies the LTE predicate on the "sourceIp" field. func SourceIpLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceIp, v)) } // SourceIpContains applies the Contains predicate on the "sourceIp" field. 
func SourceIpContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceIp, v)) } // SourceIpHasPrefix applies the HasPrefix predicate on the "sourceIp" field. func SourceIpHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceIp, v)) } // SourceIpHasSuffix applies the HasSuffix predicate on the "sourceIp" field. func SourceIpHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceIp, v)) } // SourceIpIsNil applies the IsNil predicate on the "sourceIp" field. func SourceIpIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceIp))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceIp)) } // SourceIpNotNil applies the NotNil predicate on the "sourceIp" field. func SourceIpNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceIp))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceIp)) } // SourceIpEqualFold applies the EqualFold predicate on the "sourceIp" field. func SourceIpEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceIp, v)) } // SourceIpContainsFold applies the ContainsFold predicate on the "sourceIp" field. func SourceIpContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceIp, v)) } // SourceRangeEQ applies the EQ predicate on the "sourceRange" field. func SourceRangeEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceRange, v)) } // SourceRangeNEQ applies the NEQ predicate on the "sourceRange" field. func SourceRangeNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceRange, v)) } // SourceRangeIn applies the In predicate on the "sourceRange" field. func SourceRangeIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceRange), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceRange, vs...)) } // SourceRangeNotIn applies the NotIn predicate on the "sourceRange" field. func SourceRangeNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceRange), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceRange, vs...)) } // SourceRangeGT applies the GT predicate on the "sourceRange" field. 
func SourceRangeGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceRange, v)) } // SourceRangeGTE applies the GTE predicate on the "sourceRange" field. func SourceRangeGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceRange, v)) } // SourceRangeLT applies the LT predicate on the "sourceRange" field. func SourceRangeLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceRange, v)) } // SourceRangeLTE applies the LTE predicate on the "sourceRange" field. func SourceRangeLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceRange, v)) } // SourceRangeContains applies the Contains predicate on the "sourceRange" field. func SourceRangeContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceRange, v)) } // SourceRangeHasPrefix applies the HasPrefix predicate on the "sourceRange" field. func SourceRangeHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceRange, v)) } // SourceRangeHasSuffix applies the HasSuffix predicate on the "sourceRange" field. func SourceRangeHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceRange, v)) } // SourceRangeIsNil applies the IsNil predicate on the "sourceRange" field. func SourceRangeIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceRange))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceRange)) } // SourceRangeNotNil applies the NotNil predicate on the "sourceRange" field. func SourceRangeNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceRange))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceRange)) } // SourceRangeEqualFold applies the EqualFold predicate on the "sourceRange" field. func SourceRangeEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceRange, v)) } // SourceRangeContainsFold applies the ContainsFold predicate on the "sourceRange" field. func SourceRangeContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceRange, v)) } // SourceAsNumberEQ applies the EQ predicate on the "sourceAsNumber" field. func SourceAsNumberEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceAsNumber, v)) } // SourceAsNumberNEQ applies the NEQ predicate on the "sourceAsNumber" field. 
func SourceAsNumberNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceAsNumber, v)) } // SourceAsNumberIn applies the In predicate on the "sourceAsNumber" field. func SourceAsNumberIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceAsNumber), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceAsNumber, vs...)) } // SourceAsNumberNotIn applies the NotIn predicate on the "sourceAsNumber" field. func SourceAsNumberNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceAsNumber), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceAsNumber, vs...)) } // SourceAsNumberGT applies the GT predicate on the "sourceAsNumber" field. func SourceAsNumberGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceAsNumber, v)) } // SourceAsNumberGTE applies the GTE predicate on the "sourceAsNumber" field. func SourceAsNumberGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceAsNumber, v)) } // SourceAsNumberLT applies the LT predicate on the "sourceAsNumber" field. func SourceAsNumberLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceAsNumber, v)) } // SourceAsNumberLTE applies the LTE predicate on the "sourceAsNumber" field. func SourceAsNumberLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceAsNumber, v)) } // SourceAsNumberContains applies the Contains predicate on the "sourceAsNumber" field. func SourceAsNumberContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceAsNumber, v)) } // SourceAsNumberHasPrefix applies the HasPrefix predicate on the "sourceAsNumber" field. func SourceAsNumberHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceAsNumber, v)) } // SourceAsNumberHasSuffix applies the HasSuffix predicate on the "sourceAsNumber" field. func SourceAsNumberHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceAsNumber, v)) } // SourceAsNumberIsNil applies the IsNil predicate on the "sourceAsNumber" field. func SourceAsNumberIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceAsNumber))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceAsNumber)) } // SourceAsNumberNotNil applies the NotNil predicate on the "sourceAsNumber" field. 
func SourceAsNumberNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceAsNumber))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceAsNumber)) } // SourceAsNumberEqualFold applies the EqualFold predicate on the "sourceAsNumber" field. func SourceAsNumberEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceAsNumber, v)) } // SourceAsNumberContainsFold applies the ContainsFold predicate on the "sourceAsNumber" field. func SourceAsNumberContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceAsNumber, v)) } // SourceAsNameEQ applies the EQ predicate on the "sourceAsName" field. func SourceAsNameEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceAsName, v)) } // SourceAsNameNEQ applies the NEQ predicate on the "sourceAsName" field. func SourceAsNameNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceAsName, v)) } // SourceAsNameIn applies the In predicate on the "sourceAsName" field. func SourceAsNameIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceAsName), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceAsName, vs...)) } // SourceAsNameNotIn applies the NotIn predicate on the "sourceAsName" field. func SourceAsNameNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceAsName), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceAsName, vs...)) } // SourceAsNameGT applies the GT predicate on the "sourceAsName" field. func SourceAsNameGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceAsName, v)) } // SourceAsNameGTE applies the GTE predicate on the "sourceAsName" field. func SourceAsNameGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceAsName, v)) } // SourceAsNameLT applies the LT predicate on the "sourceAsName" field. func SourceAsNameLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceAsName, v)) } // SourceAsNameLTE applies the LTE predicate on the "sourceAsName" field. func SourceAsNameLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceAsName, v)) } // SourceAsNameContains applies the Contains predicate on the "sourceAsName" field. 
func SourceAsNameContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceAsName, v)) } // SourceAsNameHasPrefix applies the HasPrefix predicate on the "sourceAsName" field. func SourceAsNameHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceAsName, v)) } // SourceAsNameHasSuffix applies the HasSuffix predicate on the "sourceAsName" field. func SourceAsNameHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceAsName, v)) } // SourceAsNameIsNil applies the IsNil predicate on the "sourceAsName" field. func SourceAsNameIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceAsName))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceAsName)) } // SourceAsNameNotNil applies the NotNil predicate on the "sourceAsName" field. func SourceAsNameNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceAsName))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceAsName)) } // SourceAsNameEqualFold applies the EqualFold predicate on the "sourceAsName" field. func SourceAsNameEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceAsName, v)) } // SourceAsNameContainsFold applies the ContainsFold predicate on the "sourceAsName" field. func SourceAsNameContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceAsName, v)) } // SourceCountryEQ applies the EQ predicate on the "sourceCountry" field. func SourceCountryEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceCountry, v)) } // SourceCountryNEQ applies the NEQ predicate on the "sourceCountry" field. func SourceCountryNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceCountry, v)) } // SourceCountryIn applies the In predicate on the "sourceCountry" field. func SourceCountryIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceCountry), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceCountry, vs...)) } // SourceCountryNotIn applies the NotIn predicate on the "sourceCountry" field. func SourceCountryNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceCountry), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceCountry, vs...)) } // SourceCountryGT applies the GT predicate on the "sourceCountry" field. 
func SourceCountryGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceCountry, v)) } // SourceCountryGTE applies the GTE predicate on the "sourceCountry" field. func SourceCountryGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceCountry, v)) } // SourceCountryLT applies the LT predicate on the "sourceCountry" field. func SourceCountryLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceCountry, v)) } // SourceCountryLTE applies the LTE predicate on the "sourceCountry" field. func SourceCountryLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceCountry, v)) } // SourceCountryContains applies the Contains predicate on the "sourceCountry" field. func SourceCountryContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceCountry, v)) } // SourceCountryHasPrefix applies the HasPrefix predicate on the "sourceCountry" field. func SourceCountryHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceCountry, v)) } // SourceCountryHasSuffix applies the HasSuffix predicate on the "sourceCountry" field. func SourceCountryHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceCountry, v)) } // SourceCountryIsNil applies the IsNil predicate on the "sourceCountry" field. func SourceCountryIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceCountry))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceCountry)) } // SourceCountryNotNil applies the NotNil predicate on the "sourceCountry" field. func SourceCountryNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceCountry))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceCountry)) } // SourceCountryEqualFold applies the EqualFold predicate on the "sourceCountry" field. func SourceCountryEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceCountry, v)) } // SourceCountryContainsFold applies the ContainsFold predicate on the "sourceCountry" field. func SourceCountryContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceCountry, v)) } // SourceLatitudeEQ applies the EQ predicate on the "sourceLatitude" field. 
 func SourceLatitudeEQ(v float32) predicate.Alert {
-	return predicate.Alert(func(s *sql.Selector) {
-		s.Where(sql.EQ(s.C(FieldSourceLatitude), v))
-	})
+	return predicate.Alert(sql.FieldEQ(FieldSourceLatitude, v))
 }
 
 // SourceLatitudeNEQ applies the NEQ predicate on the "sourceLatitude" field.
 func SourceLatitudeNEQ(v float32) predicate.Alert {
-	return predicate.Alert(func(s *sql.Selector) {
-		s.Where(sql.NEQ(s.C(FieldSourceLatitude), v))
-	})
+	return predicate.Alert(sql.FieldNEQ(FieldSourceLatitude, v))
 }
 
 // SourceLatitudeIn applies the In predicate on the "sourceLatitude" field.
 func SourceLatitudeIn(vs ...float32) predicate.Alert {
-	v := make([]any, len(vs))
-	for i := range v {
-		v[i] = vs[i]
-	}
-	return predicate.Alert(func(s *sql.Selector) {
-		s.Where(sql.In(s.C(FieldSourceLatitude), v...))
-	})
+	return predicate.Alert(sql.FieldIn(FieldSourceLatitude, vs...))
 }
 
 // SourceLatitudeNotIn applies the NotIn predicate on the "sourceLatitude" field.
 func SourceLatitudeNotIn(vs ...float32) predicate.Alert {
-	v := make([]any, len(vs))
-	for i := range v {
-		v[i] = vs[i]
-	}
-	return predicate.Alert(func(s *sql.Selector) {
-		s.Where(sql.NotIn(s.C(FieldSourceLatitude), v...))
-	})
+	return predicate.Alert(sql.FieldNotIn(FieldSourceLatitude, vs...))
 }
 
 // SourceLatitudeGT applies the GT predicate on the "sourceLatitude" field.
 func SourceLatitudeGT(v float32) predicate.Alert {
-	return predicate.Alert(func(s *sql.Selector) {
-		s.Where(sql.GT(s.C(FieldSourceLatitude), v))
-	})
+	return predicate.Alert(sql.FieldGT(FieldSourceLatitude, v))
 }
 
 // SourceLatitudeGTE applies the GTE predicate on the "sourceLatitude" field.
 func SourceLatitudeGTE(v float32) predicate.Alert {
-	return predicate.Alert(func(s *sql.Selector) {
-		s.Where(sql.GTE(s.C(FieldSourceLatitude), v))
-	})
+	return predicate.Alert(sql.FieldGTE(FieldSourceLatitude, v))
 }
 
 // SourceLatitudeLT applies the LT predicate on the "sourceLatitude" field.
 func SourceLatitudeLT(v float32) predicate.Alert {
-	return predicate.Alert(func(s *sql.Selector) {
-		s.Where(sql.LT(s.C(FieldSourceLatitude), v))
-	})
+	return predicate.Alert(sql.FieldLT(FieldSourceLatitude, v))
 }
 
 // SourceLatitudeLTE applies the LTE predicate on the "sourceLatitude" field.
 func SourceLatitudeLTE(v float32) predicate.Alert {
-	return predicate.Alert(func(s *sql.Selector) {
-		s.Where(sql.LTE(s.C(FieldSourceLatitude), v))
-	})
+	return predicate.Alert(sql.FieldLTE(FieldSourceLatitude, v))
 }
 
 // SourceLatitudeIsNil applies the IsNil predicate on the "sourceLatitude" field.
 func SourceLatitudeIsNil() predicate.Alert {
-	return predicate.Alert(func(s *sql.Selector) {
-		s.Where(sql.IsNull(s.C(FieldSourceLatitude)))
-	})
+	return predicate.Alert(sql.FieldIsNull(FieldSourceLatitude))
 }
 
 // SourceLatitudeNotNil applies the NotNil predicate on the "sourceLatitude" field.
 func SourceLatitudeNotNil() predicate.Alert {
-	return predicate.Alert(func(s *sql.Selector) {
-		s.Where(sql.NotNull(s.C(FieldSourceLatitude)))
-	})
+	return predicate.Alert(sql.FieldNotNull(FieldSourceLatitude))
 }
 
 // SourceLongitudeEQ applies the EQ predicate on the "sourceLongitude" field.
 func SourceLongitudeEQ(v float32) predicate.Alert {
-	return predicate.Alert(func(s *sql.Selector) {
-		s.Where(sql.EQ(s.C(FieldSourceLongitude), v))
-	})
+	return predicate.Alert(sql.FieldEQ(FieldSourceLongitude, v))
 }
 
 // SourceLongitudeNEQ applies the NEQ predicate on the "sourceLongitude" field.
func SourceLongitudeNEQ(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceLongitude, v)) } // SourceLongitudeIn applies the In predicate on the "sourceLongitude" field. func SourceLongitudeIn(vs ...float32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceLongitude), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceLongitude, vs...)) } // SourceLongitudeNotIn applies the NotIn predicate on the "sourceLongitude" field. func SourceLongitudeNotIn(vs ...float32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceLongitude), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceLongitude, vs...)) } // SourceLongitudeGT applies the GT predicate on the "sourceLongitude" field. func SourceLongitudeGT(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceLongitude, v)) } // SourceLongitudeGTE applies the GTE predicate on the "sourceLongitude" field. func SourceLongitudeGTE(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceLongitude, v)) } // SourceLongitudeLT applies the LT predicate on the "sourceLongitude" field. func SourceLongitudeLT(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceLongitude, v)) } // SourceLongitudeLTE applies the LTE predicate on the "sourceLongitude" field. func SourceLongitudeLTE(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceLongitude, v)) } // SourceLongitudeIsNil applies the IsNil predicate on the "sourceLongitude" field. func SourceLongitudeIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceLongitude))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceLongitude)) } // SourceLongitudeNotNil applies the NotNil predicate on the "sourceLongitude" field. func SourceLongitudeNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceLongitude))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceLongitude)) } // SourceScopeEQ applies the EQ predicate on the "sourceScope" field. func SourceScopeEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceScope, v)) } // SourceScopeNEQ applies the NEQ predicate on the "sourceScope" field. func SourceScopeNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceScope, v)) } // SourceScopeIn applies the In predicate on the "sourceScope" field. 
func SourceScopeIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceScope), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceScope, vs...)) } // SourceScopeNotIn applies the NotIn predicate on the "sourceScope" field. func SourceScopeNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceScope), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceScope, vs...)) } // SourceScopeGT applies the GT predicate on the "sourceScope" field. func SourceScopeGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceScope, v)) } // SourceScopeGTE applies the GTE predicate on the "sourceScope" field. func SourceScopeGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceScope, v)) } // SourceScopeLT applies the LT predicate on the "sourceScope" field. func SourceScopeLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceScope, v)) } // SourceScopeLTE applies the LTE predicate on the "sourceScope" field. func SourceScopeLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceScope, v)) } // SourceScopeContains applies the Contains predicate on the "sourceScope" field. func SourceScopeContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceScope, v)) } // SourceScopeHasPrefix applies the HasPrefix predicate on the "sourceScope" field. func SourceScopeHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceScope, v)) } // SourceScopeHasSuffix applies the HasSuffix predicate on the "sourceScope" field. func SourceScopeHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceScope, v)) } // SourceScopeIsNil applies the IsNil predicate on the "sourceScope" field. func SourceScopeIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceScope))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceScope)) } // SourceScopeNotNil applies the NotNil predicate on the "sourceScope" field. func SourceScopeNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceScope))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceScope)) } // SourceScopeEqualFold applies the EqualFold predicate on the "sourceScope" field. 
func SourceScopeEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceScope, v)) } // SourceScopeContainsFold applies the ContainsFold predicate on the "sourceScope" field. func SourceScopeContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceScope, v)) } // SourceValueEQ applies the EQ predicate on the "sourceValue" field. func SourceValueEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceValue, v)) } // SourceValueNEQ applies the NEQ predicate on the "sourceValue" field. func SourceValueNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceValue, v)) } // SourceValueIn applies the In predicate on the "sourceValue" field. func SourceValueIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceValue), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceValue, vs...)) } // SourceValueNotIn applies the NotIn predicate on the "sourceValue" field. func SourceValueNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceValue), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceValue, vs...)) } // SourceValueGT applies the GT predicate on the "sourceValue" field. func SourceValueGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceValue, v)) } // SourceValueGTE applies the GTE predicate on the "sourceValue" field. func SourceValueGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceValue, v)) } // SourceValueLT applies the LT predicate on the "sourceValue" field. func SourceValueLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceValue, v)) } // SourceValueLTE applies the LTE predicate on the "sourceValue" field. func SourceValueLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceValue, v)) } // SourceValueContains applies the Contains predicate on the "sourceValue" field. func SourceValueContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceValue, v)) } // SourceValueHasPrefix applies the HasPrefix predicate on the "sourceValue" field. 
func SourceValueHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceValue, v)) } // SourceValueHasSuffix applies the HasSuffix predicate on the "sourceValue" field. func SourceValueHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceValue, v)) } // SourceValueIsNil applies the IsNil predicate on the "sourceValue" field. func SourceValueIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceValue))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceValue)) } // SourceValueNotNil applies the NotNil predicate on the "sourceValue" field. func SourceValueNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceValue))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceValue)) } // SourceValueEqualFold applies the EqualFold predicate on the "sourceValue" field. func SourceValueEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceValue, v)) } // SourceValueContainsFold applies the ContainsFold predicate on the "sourceValue" field. func SourceValueContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceValue, v)) } // CapacityEQ applies the EQ predicate on the "capacity" field. func CapacityEQ(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldCapacity, v)) } // CapacityNEQ applies the NEQ predicate on the "capacity" field. func CapacityNEQ(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldCapacity, v)) } // CapacityIn applies the In predicate on the "capacity" field. func CapacityIn(vs ...int32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCapacity), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldCapacity, vs...)) } // CapacityNotIn applies the NotIn predicate on the "capacity" field. func CapacityNotIn(vs ...int32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCapacity), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldCapacity, vs...)) } // CapacityGT applies the GT predicate on the "capacity" field. func CapacityGT(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldGT(FieldCapacity, v)) } // CapacityGTE applies the GTE predicate on the "capacity" field. func CapacityGTE(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldCapacity, v)) } // CapacityLT applies the LT predicate on the "capacity" field. 
func CapacityLT(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldLT(FieldCapacity, v)) } // CapacityLTE applies the LTE predicate on the "capacity" field. func CapacityLTE(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldCapacity, v)) } // CapacityIsNil applies the IsNil predicate on the "capacity" field. func CapacityIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCapacity))) - }) + return predicate.Alert(sql.FieldIsNull(FieldCapacity)) } // CapacityNotNil applies the NotNil predicate on the "capacity" field. func CapacityNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCapacity))) - }) + return predicate.Alert(sql.FieldNotNull(FieldCapacity)) } // LeakSpeedEQ applies the EQ predicate on the "leakSpeed" field. func LeakSpeedEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldLeakSpeed, v)) } // LeakSpeedNEQ applies the NEQ predicate on the "leakSpeed" field. func LeakSpeedNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldLeakSpeed, v)) } // LeakSpeedIn applies the In predicate on the "leakSpeed" field. func LeakSpeedIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldLeakSpeed), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldLeakSpeed, vs...)) } // LeakSpeedNotIn applies the NotIn predicate on the "leakSpeed" field. func LeakSpeedNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldLeakSpeed), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldLeakSpeed, vs...)) } // LeakSpeedGT applies the GT predicate on the "leakSpeed" field. func LeakSpeedGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldGT(FieldLeakSpeed, v)) } // LeakSpeedGTE applies the GTE predicate on the "leakSpeed" field. func LeakSpeedGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldLeakSpeed, v)) } // LeakSpeedLT applies the LT predicate on the "leakSpeed" field. func LeakSpeedLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldLT(FieldLeakSpeed, v)) } // LeakSpeedLTE applies the LTE predicate on the "leakSpeed" field. func LeakSpeedLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldLeakSpeed, v)) } // LeakSpeedContains applies the Contains predicate on the "leakSpeed" field. 
func LeakSpeedContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldContains(FieldLeakSpeed, v)) } // LeakSpeedHasPrefix applies the HasPrefix predicate on the "leakSpeed" field. func LeakSpeedHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldLeakSpeed, v)) } // LeakSpeedHasSuffix applies the HasSuffix predicate on the "leakSpeed" field. func LeakSpeedHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldLeakSpeed, v)) } // LeakSpeedIsNil applies the IsNil predicate on the "leakSpeed" field. func LeakSpeedIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldLeakSpeed))) - }) + return predicate.Alert(sql.FieldIsNull(FieldLeakSpeed)) } // LeakSpeedNotNil applies the NotNil predicate on the "leakSpeed" field. func LeakSpeedNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldLeakSpeed))) - }) + return predicate.Alert(sql.FieldNotNull(FieldLeakSpeed)) } // LeakSpeedEqualFold applies the EqualFold predicate on the "leakSpeed" field. func LeakSpeedEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldLeakSpeed, v)) } // LeakSpeedContainsFold applies the ContainsFold predicate on the "leakSpeed" field. func LeakSpeedContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldLeakSpeed, v)) } // ScenarioVersionEQ applies the EQ predicate on the "scenarioVersion" field. func ScenarioVersionEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenarioVersion, v)) } // ScenarioVersionNEQ applies the NEQ predicate on the "scenarioVersion" field. func ScenarioVersionNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldScenarioVersion, v)) } // ScenarioVersionIn applies the In predicate on the "scenarioVersion" field. func ScenarioVersionIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScenarioVersion), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldScenarioVersion, vs...)) } // ScenarioVersionNotIn applies the NotIn predicate on the "scenarioVersion" field. func ScenarioVersionNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScenarioVersion), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldScenarioVersion, vs...)) } // ScenarioVersionGT applies the GT predicate on the "scenarioVersion" field. 
func ScenarioVersionGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldGT(FieldScenarioVersion, v)) } // ScenarioVersionGTE applies the GTE predicate on the "scenarioVersion" field. func ScenarioVersionGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldScenarioVersion, v)) } // ScenarioVersionLT applies the LT predicate on the "scenarioVersion" field. func ScenarioVersionLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldLT(FieldScenarioVersion, v)) } // ScenarioVersionLTE applies the LTE predicate on the "scenarioVersion" field. func ScenarioVersionLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldScenarioVersion, v)) } // ScenarioVersionContains applies the Contains predicate on the "scenarioVersion" field. func ScenarioVersionContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldContains(FieldScenarioVersion, v)) } // ScenarioVersionHasPrefix applies the HasPrefix predicate on the "scenarioVersion" field. func ScenarioVersionHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldScenarioVersion, v)) } // ScenarioVersionHasSuffix applies the HasSuffix predicate on the "scenarioVersion" field. func ScenarioVersionHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldScenarioVersion, v)) } // ScenarioVersionIsNil applies the IsNil predicate on the "scenarioVersion" field. func ScenarioVersionIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldScenarioVersion))) - }) + return predicate.Alert(sql.FieldIsNull(FieldScenarioVersion)) } // ScenarioVersionNotNil applies the NotNil predicate on the "scenarioVersion" field. func ScenarioVersionNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldScenarioVersion))) - }) + return predicate.Alert(sql.FieldNotNull(FieldScenarioVersion)) } // ScenarioVersionEqualFold applies the EqualFold predicate on the "scenarioVersion" field. func ScenarioVersionEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldScenarioVersion, v)) } // ScenarioVersionContainsFold applies the ContainsFold predicate on the "scenarioVersion" field. func ScenarioVersionContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldScenarioVersion, v)) } // ScenarioHashEQ applies the EQ predicate on the "scenarioHash" field. 
func ScenarioHashEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenarioHash, v)) } // ScenarioHashNEQ applies the NEQ predicate on the "scenarioHash" field. func ScenarioHashNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldScenarioHash, v)) } // ScenarioHashIn applies the In predicate on the "scenarioHash" field. func ScenarioHashIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScenarioHash), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldScenarioHash, vs...)) } // ScenarioHashNotIn applies the NotIn predicate on the "scenarioHash" field. func ScenarioHashNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScenarioHash), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldScenarioHash, vs...)) } // ScenarioHashGT applies the GT predicate on the "scenarioHash" field. func ScenarioHashGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldGT(FieldScenarioHash, v)) } // ScenarioHashGTE applies the GTE predicate on the "scenarioHash" field. func ScenarioHashGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldScenarioHash, v)) } // ScenarioHashLT applies the LT predicate on the "scenarioHash" field. func ScenarioHashLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldLT(FieldScenarioHash, v)) } // ScenarioHashLTE applies the LTE predicate on the "scenarioHash" field. func ScenarioHashLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldScenarioHash, v)) } // ScenarioHashContains applies the Contains predicate on the "scenarioHash" field. func ScenarioHashContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldContains(FieldScenarioHash, v)) } // ScenarioHashHasPrefix applies the HasPrefix predicate on the "scenarioHash" field. func ScenarioHashHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldScenarioHash, v)) } // ScenarioHashHasSuffix applies the HasSuffix predicate on the "scenarioHash" field. func ScenarioHashHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldScenarioHash, v)) } // ScenarioHashIsNil applies the IsNil predicate on the "scenarioHash" field. 
func ScenarioHashIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldScenarioHash))) - }) + return predicate.Alert(sql.FieldIsNull(FieldScenarioHash)) } // ScenarioHashNotNil applies the NotNil predicate on the "scenarioHash" field. func ScenarioHashNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldScenarioHash))) - }) + return predicate.Alert(sql.FieldNotNull(FieldScenarioHash)) } // ScenarioHashEqualFold applies the EqualFold predicate on the "scenarioHash" field. func ScenarioHashEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldScenarioHash, v)) } // ScenarioHashContainsFold applies the ContainsFold predicate on the "scenarioHash" field. func ScenarioHashContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldScenarioHash, v)) } // SimulatedEQ applies the EQ predicate on the "simulated" field. func SimulatedEQ(v bool) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSimulated), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSimulated, v)) } // SimulatedNEQ applies the NEQ predicate on the "simulated" field. func SimulatedNEQ(v bool) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSimulated), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSimulated, v)) } // UUIDEQ applies the EQ predicate on the "uuid" field. func UUIDEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldUUID, v)) } // UUIDNEQ applies the NEQ predicate on the "uuid" field. func UUIDNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldUUID, v)) } // UUIDIn applies the In predicate on the "uuid" field. func UUIDIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUUID), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldUUID, vs...)) } // UUIDNotIn applies the NotIn predicate on the "uuid" field. func UUIDNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUUID), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldUUID, vs...)) } // UUIDGT applies the GT predicate on the "uuid" field. func UUIDGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldGT(FieldUUID, v)) } // UUIDGTE applies the GTE predicate on the "uuid" field. func UUIDGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldUUID, v)) } // UUIDLT applies the LT predicate on the "uuid" field. 
func UUIDLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldLT(FieldUUID, v)) } // UUIDLTE applies the LTE predicate on the "uuid" field. func UUIDLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldUUID, v)) } // UUIDContains applies the Contains predicate on the "uuid" field. func UUIDContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldContains(FieldUUID, v)) } // UUIDHasPrefix applies the HasPrefix predicate on the "uuid" field. func UUIDHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldUUID, v)) } // UUIDHasSuffix applies the HasSuffix predicate on the "uuid" field. func UUIDHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldUUID, v)) } // UUIDIsNil applies the IsNil predicate on the "uuid" field. func UUIDIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUUID))) - }) + return predicate.Alert(sql.FieldIsNull(FieldUUID)) } // UUIDNotNil applies the NotNil predicate on the "uuid" field. func UUIDNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUUID))) - }) + return predicate.Alert(sql.FieldNotNull(FieldUUID)) } // UUIDEqualFold applies the EqualFold predicate on the "uuid" field. func UUIDEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldUUID, v)) } // UUIDContainsFold applies the ContainsFold predicate on the "uuid" field. func UUIDContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldUUID, v)) } // HasOwner applies the HasEdge predicate on the "owner" edge. @@ -2453,7 +1625,6 @@ func HasOwner() predicate.Alert { return predicate.Alert(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), ) sqlgraph.HasNeighbors(s, step) @@ -2463,11 +1634,7 @@ func HasOwner() predicate.Alert { // HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates). 
func HasOwnerWith(preds ...predicate.Machine) predicate.Alert { return predicate.Alert(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), - ) + step := newOwnerStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -2481,7 +1648,6 @@ func HasDecisions() predicate.Alert { return predicate.Alert(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(DecisionsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, DecisionsTable, DecisionsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -2491,11 +1657,7 @@ func HasDecisions() predicate.Alert { // HasDecisionsWith applies the HasEdge predicate on the "decisions" edge with a given conditions (other predicates). func HasDecisionsWith(preds ...predicate.Decision) predicate.Alert { return predicate.Alert(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(DecisionsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, DecisionsTable, DecisionsColumn), - ) + step := newDecisionsStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -2509,7 +1671,6 @@ func HasEvents() predicate.Alert { return predicate.Alert(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(EventsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, EventsTable, EventsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -2519,11 +1680,7 @@ func HasEvents() predicate.Alert { // HasEventsWith applies the HasEdge predicate on the "events" edge with a given conditions (other predicates). func HasEventsWith(preds ...predicate.Event) predicate.Alert { return predicate.Alert(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(EventsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, EventsTable, EventsColumn), - ) + step := newEventsStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -2537,7 +1694,6 @@ func HasMetas() predicate.Alert { return predicate.Alert(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(MetasTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, MetasTable, MetasColumn), ) sqlgraph.HasNeighbors(s, step) @@ -2547,11 +1703,7 @@ func HasMetas() predicate.Alert { // HasMetasWith applies the HasEdge predicate on the "metas" edge with a given conditions (other predicates). func HasMetasWith(preds ...predicate.Meta) predicate.Alert { return predicate.Alert(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(MetasInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, MetasTable, MetasColumn), - ) + step := newMetasStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -2562,32 +1714,15 @@ func HasMetasWith(preds ...predicate.Meta) predicate.Alert { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Alert) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Alert(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. 
func Or(predicates ...predicate.Alert) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Alert(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Alert) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Alert(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/alert_create.go b/pkg/database/ent/alert_create.go index 42da5b137ba..c7498442c06 100644 --- a/pkg/database/ent/alert_create.go +++ b/pkg/database/ent/alert_create.go @@ -409,50 +409,8 @@ func (ac *AlertCreate) Mutation() *AlertMutation { // Save creates the Alert in the database. func (ac *AlertCreate) Save(ctx context.Context) (*Alert, error) { - var ( - err error - node *Alert - ) ac.defaults() - if len(ac.hooks) == 0 { - if err = ac.check(); err != nil { - return nil, err - } - node, err = ac.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AlertMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = ac.check(); err != nil { - return nil, err - } - ac.mutation = mutation - if node, err = ac.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(ac.hooks) - 1; i >= 0; i-- { - if ac.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ac.hooks[i](mut) - } - v, err := mut.Mutate(ctx, ac.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Alert) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from AlertMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, ac.sqlSave, ac.mutation, ac.hooks) } // SaveX calls Save and panics if Save returns an error. 
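The hunk above is representative of the rest of this regeneration: the hand-rolled hook loop inside each generated Save/Exec method collapses into a single call to the withHooks helper. As a rough illustration of the pattern only (not ent's actual implementation; the types below are simplified placeholders), a hook chain is ordinary middleware composition around the final sqlSave call:

package main

import (
	"context"
	"fmt"
)

// Simplified stand-ins for ent's Mutation/Mutator/Hook types.
type Mutation interface{ Op() string }

type Mutator interface {
	Mutate(ctx context.Context, m Mutation) (any, error)
}

type MutateFunc func(ctx context.Context, m Mutation) (any, error)

func (f MutateFunc) Mutate(ctx context.Context, m Mutation) (any, error) { return f(ctx, m) }

// Hook wraps a Mutator the way HTTP middleware wraps a handler.
type Hook func(Mutator) Mutator

// withHooks drives exec (think sqlSave) through the hook chain; with no
// hooks registered it degenerates to a direct call, which is what the old
// "if len(hooks) == 0" branch spelled out by hand.
func withHooks(ctx context.Context, exec func(context.Context) (any, error), m Mutation, hooks []Hook) (any, error) {
	var mut Mutator = MutateFunc(func(ctx context.Context, _ Mutation) (any, error) {
		return exec(ctx)
	})
	for i := len(hooks) - 1; i >= 0; i-- { // reverse, so hooks[0] runs outermost
		mut = hooks[i](mut)
	}
	return mut.Mutate(ctx, m)
}

type alertCreateMutation struct{}

func (alertCreateMutation) Op() string { return "OpCreate" }

func main() {
	logging := func(next Mutator) Mutator {
		return MutateFunc(func(ctx context.Context, m Mutation) (any, error) {
			fmt.Println("before", m.Op())
			defer fmt.Println("after", m.Op())
			return next.Mutate(ctx, m)
		})
	}

	node, err := withHooks(context.Background(),
		func(ctx context.Context) (any, error) { return "saved *Alert", nil },
		alertCreateMutation{},
		[]Hook{logging},
	)
	fmt.Println(node, err)
}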
@@ -525,6 +483,9 @@ func (ac *AlertCreate) check() error { } func (ac *AlertCreate) sqlSave(ctx context.Context) (*Alert, error) { + if err := ac.check(); err != nil { + return nil, err + } _node, _spec := ac.createSpec() if err := sqlgraph.CreateNode(ctx, ac.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -534,202 +495,106 @@ func (ac *AlertCreate) sqlSave(ctx context.Context) (*Alert, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + ac.mutation.id = &_node.ID + ac.mutation.done = true return _node, nil } func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { var ( _node = &Alert{config: ac.config} - _spec = &sqlgraph.CreateSpec{ - Table: alert.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(alert.Table, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt)) ) if value, ok := ac.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldCreatedAt, - }) + _spec.SetField(alert.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = &value } if value, ok := ac.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldUpdatedAt, - }) + _spec.SetField(alert.FieldUpdatedAt, field.TypeTime, value) _node.UpdatedAt = &value } if value, ok := ac.mutation.Scenario(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenario, - }) + _spec.SetField(alert.FieldScenario, field.TypeString, value) _node.Scenario = value } if value, ok := ac.mutation.BucketId(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldBucketId, - }) + _spec.SetField(alert.FieldBucketId, field.TypeString, value) _node.BucketId = value } if value, ok := ac.mutation.Message(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldMessage, - }) + _spec.SetField(alert.FieldMessage, field.TypeString, value) _node.Message = value } if value, ok := ac.mutation.EventsCount(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldEventsCount, - }) + _spec.SetField(alert.FieldEventsCount, field.TypeInt32, value) _node.EventsCount = value } if value, ok := ac.mutation.StartedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStartedAt, - }) + _spec.SetField(alert.FieldStartedAt, field.TypeTime, value) _node.StartedAt = value } if value, ok := ac.mutation.StoppedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStoppedAt, - }) + _spec.SetField(alert.FieldStoppedAt, field.TypeTime, value) _node.StoppedAt = value } if value, ok := ac.mutation.SourceIp(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceIp, - }) + _spec.SetField(alert.FieldSourceIp, field.TypeString, value) _node.SourceIp = value } if value, ok := ac.mutation.SourceRange(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceRange, - }) + 
_spec.SetField(alert.FieldSourceRange, field.TypeString, value) _node.SourceRange = value } if value, ok := ac.mutation.SourceAsNumber(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsNumber, - }) + _spec.SetField(alert.FieldSourceAsNumber, field.TypeString, value) _node.SourceAsNumber = value } if value, ok := ac.mutation.SourceAsName(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsName, - }) + _spec.SetField(alert.FieldSourceAsName, field.TypeString, value) _node.SourceAsName = value } if value, ok := ac.mutation.SourceCountry(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceCountry, - }) + _spec.SetField(alert.FieldSourceCountry, field.TypeString, value) _node.SourceCountry = value } if value, ok := ac.mutation.SourceLatitude(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLatitude, - }) + _spec.SetField(alert.FieldSourceLatitude, field.TypeFloat32, value) _node.SourceLatitude = value } if value, ok := ac.mutation.SourceLongitude(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLongitude, - }) + _spec.SetField(alert.FieldSourceLongitude, field.TypeFloat32, value) _node.SourceLongitude = value } if value, ok := ac.mutation.SourceScope(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceScope, - }) + _spec.SetField(alert.FieldSourceScope, field.TypeString, value) _node.SourceScope = value } if value, ok := ac.mutation.SourceValue(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceValue, - }) + _spec.SetField(alert.FieldSourceValue, field.TypeString, value) _node.SourceValue = value } if value, ok := ac.mutation.Capacity(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldCapacity, - }) + _spec.SetField(alert.FieldCapacity, field.TypeInt32, value) _node.Capacity = value } if value, ok := ac.mutation.LeakSpeed(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldLeakSpeed, - }) + _spec.SetField(alert.FieldLeakSpeed, field.TypeString, value) _node.LeakSpeed = value } if value, ok := ac.mutation.ScenarioVersion(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioVersion, - }) + _spec.SetField(alert.FieldScenarioVersion, field.TypeString, value) _node.ScenarioVersion = value } if value, ok := ac.mutation.ScenarioHash(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioHash, - }) + _spec.SetField(alert.FieldScenarioHash, field.TypeString, value) _node.ScenarioHash = value } if value, ok := ac.mutation.Simulated(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: alert.FieldSimulated, - }) + _spec.SetField(alert.FieldSimulated, field.TypeBool, value) _node.Simulated = value } if value, ok := ac.mutation.UUID(); 
ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldUUID, - }) + _spec.SetField(alert.FieldUUID, field.TypeString, value) _node.UUID = value } if nodes := ac.mutation.OwnerIDs(); len(nodes) > 0 { @@ -740,10 +605,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { Columns: []string{alert.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -760,10 +622,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -779,10 +638,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -798,10 +654,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -815,11 +668,15 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { // AlertCreateBulk is the builder for creating many Alert entities in bulk. type AlertCreateBulk struct { config + err error builders []*AlertCreate } // Save creates the Alert entities in the database. func (acb *AlertCreateBulk) Save(ctx context.Context) ([]*Alert, error) { + if acb.err != nil { + return nil, acb.err + } specs := make([]*sqlgraph.CreateSpec, len(acb.builders)) nodes := make([]*Alert, len(acb.builders)) mutators := make([]Mutator, len(acb.builders)) @@ -836,8 +693,8 @@ func (acb *AlertCreateBulk) Save(ctx context.Context) ([]*Alert, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, acb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/alert_delete.go b/pkg/database/ent/alert_delete.go index 014bcc2e0c6..15b3a4c822a 100644 --- a/pkg/database/ent/alert_delete.go +++ b/pkg/database/ent/alert_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (ad *AlertDelete) Where(ps ...predicate.Alert) *AlertDelete { // Exec executes the deletion query and returns how many vertices were deleted. 
func (ad *AlertDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(ad.hooks) == 0 { - affected, err = ad.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AlertMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - ad.mutation = mutation - affected, err = ad.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(ad.hooks) - 1; i >= 0; i-- { - if ad.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ad.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, ad.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, ad.sqlExec, ad.mutation, ad.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (ad *AlertDelete) ExecX(ctx context.Context) int { } func (ad *AlertDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: alert.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(alert.Table, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt)) if ps := ad.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (ad *AlertDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + ad.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type AlertDeleteOne struct { ad *AlertDelete } +// Where appends a list predicates to the AlertDelete builder. +func (ado *AlertDeleteOne) Where(ps ...predicate.Alert) *AlertDeleteOne { + ado.ad.mutation.Where(ps...) + return ado +} + // Exec executes the deletion query. func (ado *AlertDeleteOne) Exec(ctx context.Context) error { n, err := ado.ad.Exec(ctx) @@ -111,5 +82,7 @@ func (ado *AlertDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (ado *AlertDeleteOne) ExecX(ctx context.Context) { - ado.ad.ExecX(ctx) + if err := ado.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/alert_query.go b/pkg/database/ent/alert_query.go index 68789196d24..7eddb6ce024 100644 --- a/pkg/database/ent/alert_query.go +++ b/pkg/database/ent/alert_query.go @@ -22,11 +22,9 @@ import ( // AlertQuery is the builder for querying Alert entities. type AlertQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []alert.OrderOption + inters []Interceptor predicates []predicate.Alert withOwner *MachineQuery withDecisions *DecisionQuery @@ -44,34 +42,34 @@ func (aq *AlertQuery) Where(ps ...predicate.Alert) *AlertQuery { return aq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (aq *AlertQuery) Limit(limit int) *AlertQuery { - aq.limit = &limit + aq.ctx.Limit = &limit return aq } -// Offset adds an offset step to the query. +// Offset to start from. func (aq *AlertQuery) Offset(offset int) *AlertQuery { - aq.offset = &offset + aq.ctx.Offset = &offset return aq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. 
func (aq *AlertQuery) Unique(unique bool) *AlertQuery { - aq.unique = &unique + aq.ctx.Unique = &unique return aq } -// Order adds an order step to the query. -func (aq *AlertQuery) Order(o ...OrderFunc) *AlertQuery { +// Order specifies how the records should be ordered. +func (aq *AlertQuery) Order(o ...alert.OrderOption) *AlertQuery { aq.order = append(aq.order, o...) return aq } // QueryOwner chains the current query on the "owner" edge. func (aq *AlertQuery) QueryOwner() *MachineQuery { - query := &MachineQuery{config: aq.config} + query := (&MachineClient{config: aq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := aq.prepareQuery(ctx); err != nil { return nil, err @@ -93,7 +91,7 @@ func (aq *AlertQuery) QueryOwner() *MachineQuery { // QueryDecisions chains the current query on the "decisions" edge. func (aq *AlertQuery) QueryDecisions() *DecisionQuery { - query := &DecisionQuery{config: aq.config} + query := (&DecisionClient{config: aq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := aq.prepareQuery(ctx); err != nil { return nil, err @@ -115,7 +113,7 @@ func (aq *AlertQuery) QueryDecisions() *DecisionQuery { // QueryEvents chains the current query on the "events" edge. func (aq *AlertQuery) QueryEvents() *EventQuery { - query := &EventQuery{config: aq.config} + query := (&EventClient{config: aq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := aq.prepareQuery(ctx); err != nil { return nil, err @@ -137,7 +135,7 @@ func (aq *AlertQuery) QueryEvents() *EventQuery { // QueryMetas chains the current query on the "metas" edge. func (aq *AlertQuery) QueryMetas() *MetaQuery { - query := &MetaQuery{config: aq.config} + query := (&MetaClient{config: aq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := aq.prepareQuery(ctx); err != nil { return nil, err @@ -160,7 +158,7 @@ func (aq *AlertQuery) QueryMetas() *MetaQuery { // First returns the first Alert entity from the query. // Returns a *NotFoundError when no Alert was found. func (aq *AlertQuery) First(ctx context.Context) (*Alert, error) { - nodes, err := aq.Limit(1).All(ctx) + nodes, err := aq.Limit(1).All(setContextOp(ctx, aq.ctx, "First")) if err != nil { return nil, err } @@ -183,7 +181,7 @@ func (aq *AlertQuery) FirstX(ctx context.Context) *Alert { // Returns a *NotFoundError when no Alert ID was found. func (aq *AlertQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = aq.Limit(1).IDs(ctx); err != nil { + if ids, err = aq.Limit(1).IDs(setContextOp(ctx, aq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -206,7 +204,7 @@ func (aq *AlertQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Alert entity is found. // Returns a *NotFoundError when no Alert entities are found. func (aq *AlertQuery) Only(ctx context.Context) (*Alert, error) { - nodes, err := aq.Limit(2).All(ctx) + nodes, err := aq.Limit(2).All(setContextOp(ctx, aq.ctx, "Only")) if err != nil { return nil, err } @@ -234,7 +232,7 @@ func (aq *AlertQuery) OnlyX(ctx context.Context) *Alert { // Returns a *NotFoundError when no entities are found. 
func (aq *AlertQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = aq.Limit(2).IDs(ctx); err != nil { + if ids, err = aq.Limit(2).IDs(setContextOp(ctx, aq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -259,10 +257,12 @@ func (aq *AlertQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Alerts. func (aq *AlertQuery) All(ctx context.Context) ([]*Alert, error) { + ctx = setContextOp(ctx, aq.ctx, "All") if err := aq.prepareQuery(ctx); err != nil { return nil, err } - return aq.sqlAll(ctx) + qr := querierAll[[]*Alert, *AlertQuery]() + return withInterceptors[[]*Alert](ctx, aq, qr, aq.inters) } // AllX is like All, but panics if an error occurs. @@ -275,9 +275,12 @@ func (aq *AlertQuery) AllX(ctx context.Context) []*Alert { } // IDs executes the query and returns a list of Alert IDs. -func (aq *AlertQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := aq.Select(alert.FieldID).Scan(ctx, &ids); err != nil { +func (aq *AlertQuery) IDs(ctx context.Context) (ids []int, err error) { + if aq.ctx.Unique == nil && aq.path != nil { + aq.Unique(true) + } + ctx = setContextOp(ctx, aq.ctx, "IDs") + if err = aq.Select(alert.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -294,10 +297,11 @@ func (aq *AlertQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (aq *AlertQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, aq.ctx, "Count") if err := aq.prepareQuery(ctx); err != nil { return 0, err } - return aq.sqlCount(ctx) + return withInterceptors[int](ctx, aq, querierCount[*AlertQuery](), aq.inters) } // CountX is like Count, but panics if an error occurs. @@ -311,10 +315,15 @@ func (aq *AlertQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (aq *AlertQuery) Exist(ctx context.Context) (bool, error) { - if err := aq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, aq.ctx, "Exist") + switch _, err := aq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return aq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -334,25 +343,24 @@ func (aq *AlertQuery) Clone() *AlertQuery { } return &AlertQuery{ config: aq.config, - limit: aq.limit, - offset: aq.offset, - order: append([]OrderFunc{}, aq.order...), + ctx: aq.ctx.Clone(), + order: append([]alert.OrderOption{}, aq.order...), + inters: append([]Interceptor{}, aq.inters...), predicates: append([]predicate.Alert{}, aq.predicates...), withOwner: aq.withOwner.Clone(), withDecisions: aq.withDecisions.Clone(), withEvents: aq.withEvents.Clone(), withMetas: aq.withMetas.Clone(), // clone intermediate query. - sql: aq.sql.Clone(), - path: aq.path, - unique: aq.unique, + sql: aq.sql.Clone(), + path: aq.path, } } // WithOwner tells the query-builder to eager-load the nodes that are connected to // the "owner" edge. The optional arguments are used to configure the query builder of the edge. 
func (aq *AlertQuery) WithOwner(opts ...func(*MachineQuery)) *AlertQuery { - query := &MachineQuery{config: aq.config} + query := (&MachineClient{config: aq.config}).Query() for _, opt := range opts { opt(query) } @@ -363,7 +371,7 @@ func (aq *AlertQuery) WithOwner(opts ...func(*MachineQuery)) *AlertQuery { // WithDecisions tells the query-builder to eager-load the nodes that are connected to // the "decisions" edge. The optional arguments are used to configure the query builder of the edge. func (aq *AlertQuery) WithDecisions(opts ...func(*DecisionQuery)) *AlertQuery { - query := &DecisionQuery{config: aq.config} + query := (&DecisionClient{config: aq.config}).Query() for _, opt := range opts { opt(query) } @@ -374,7 +382,7 @@ func (aq *AlertQuery) WithDecisions(opts ...func(*DecisionQuery)) *AlertQuery { // WithEvents tells the query-builder to eager-load the nodes that are connected to // the "events" edge. The optional arguments are used to configure the query builder of the edge. func (aq *AlertQuery) WithEvents(opts ...func(*EventQuery)) *AlertQuery { - query := &EventQuery{config: aq.config} + query := (&EventClient{config: aq.config}).Query() for _, opt := range opts { opt(query) } @@ -385,7 +393,7 @@ func (aq *AlertQuery) WithEvents(opts ...func(*EventQuery)) *AlertQuery { // WithMetas tells the query-builder to eager-load the nodes that are connected to // the "metas" edge. The optional arguments are used to configure the query builder of the edge. func (aq *AlertQuery) WithMetas(opts ...func(*MetaQuery)) *AlertQuery { - query := &MetaQuery{config: aq.config} + query := (&MetaClient{config: aq.config}).Query() for _, opt := range opts { opt(query) } @@ -408,16 +416,11 @@ func (aq *AlertQuery) WithMetas(opts ...func(*MetaQuery)) *AlertQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (aq *AlertQuery) GroupBy(field string, fields ...string) *AlertGroupBy { - grbuild := &AlertGroupBy{config: aq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := aq.prepareQuery(ctx); err != nil { - return nil, err - } - return aq.sqlQuery(ctx), nil - } + aq.ctx.Fields = append([]string{field}, fields...) + grbuild := &AlertGroupBy{build: aq} + grbuild.flds = &aq.ctx.Fields grbuild.label = alert.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -434,15 +437,30 @@ func (aq *AlertQuery) GroupBy(field string, fields ...string) *AlertGroupBy { // Select(alert.FieldCreatedAt). // Scan(ctx, &v) func (aq *AlertQuery) Select(fields ...string) *AlertSelect { - aq.fields = append(aq.fields, fields...) - selbuild := &AlertSelect{AlertQuery: aq} - selbuild.label = alert.Label - selbuild.flds, selbuild.scan = &aq.fields, selbuild.Scan - return selbuild + aq.ctx.Fields = append(aq.ctx.Fields, fields...) + sbuild := &AlertSelect{AlertQuery: aq} + sbuild.label = alert.Label + sbuild.flds, sbuild.scan = &aq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AlertSelect configured with the given aggregations. +func (aq *AlertQuery) Aggregate(fns ...AggregateFunc) *AlertSelect { + return aq.Select().Aggregate(fns...) 
} func (aq *AlertQuery) prepareQuery(ctx context.Context) error { - for _, f := range aq.fields { + for _, inter := range aq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, aq); err != nil { + return err + } + } + } + for _, f := range aq.ctx.Fields { if !alert.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -536,6 +554,9 @@ func (aq *AlertQuery) loadOwner(ctx context.Context, query *MachineQuery, nodes } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(machine.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -562,8 +583,11 @@ func (aq *AlertQuery) loadDecisions(ctx context.Context, query *DecisionQuery, n init(nodes[i]) } } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(decision.FieldAlertDecisions) + } query.Where(predicate.Decision(func(s *sql.Selector) { - s.Where(sql.InValues(alert.DecisionsColumn, fks...)) + s.Where(sql.InValues(s.C(alert.DecisionsColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -573,7 +597,7 @@ func (aq *AlertQuery) loadDecisions(ctx context.Context, query *DecisionQuery, n fk := n.AlertDecisions node, ok := nodeids[fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "alert_decisions" returned %v for node %v`, fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "alert_decisions" returned %v for node %v`, fk, n.ID) } assign(node, n) } @@ -589,8 +613,11 @@ func (aq *AlertQuery) loadEvents(ctx context.Context, query *EventQuery, nodes [ init(nodes[i]) } } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(event.FieldAlertEvents) + } query.Where(predicate.Event(func(s *sql.Selector) { - s.Where(sql.InValues(alert.EventsColumn, fks...)) + s.Where(sql.InValues(s.C(alert.EventsColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -600,7 +627,7 @@ func (aq *AlertQuery) loadEvents(ctx context.Context, query *EventQuery, nodes [ fk := n.AlertEvents node, ok := nodeids[fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "alert_events" returned %v for node %v`, fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "alert_events" returned %v for node %v`, fk, n.ID) } assign(node, n) } @@ -616,8 +643,11 @@ func (aq *AlertQuery) loadMetas(ctx context.Context, query *MetaQuery, nodes []* init(nodes[i]) } } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(meta.FieldAlertMetas) + } query.Where(predicate.Meta(func(s *sql.Selector) { - s.Where(sql.InValues(alert.MetasColumn, fks...)) + s.Where(sql.InValues(s.C(alert.MetasColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -627,7 +657,7 @@ func (aq *AlertQuery) loadMetas(ctx context.Context, query *MetaQuery, nodes []* fk := n.AlertMetas node, ok := nodeids[fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "alert_metas" returned %v for node %v`, fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "alert_metas" returned %v for node %v`, fk, n.ID) } assign(node, n) } @@ -636,41 +666,22 @@ func (aq *AlertQuery) loadMetas(ctx context.Context, query *MetaQuery, nodes []* func (aq *AlertQuery) sqlCount(ctx context.Context) (int, error) { _spec := aq.querySpec() - _spec.Node.Columns = aq.fields - if len(aq.fields) > 0 { - _spec.Unique = aq.unique != nil && *aq.unique + _spec.Node.Columns = aq.ctx.Fields + if len(aq.ctx.Fields) 
> 0 { + _spec.Unique = aq.ctx.Unique != nil && *aq.ctx.Unique } return sqlgraph.CountNodes(ctx, aq.driver, _spec) } -func (aq *AlertQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := aq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (aq *AlertQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: alert.Table, - Columns: alert.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, - }, - From: aq.sql, - Unique: true, - } - if unique := aq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(alert.Table, alert.Columns, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt)) + _spec.From = aq.sql + if unique := aq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if aq.path != nil { + _spec.Unique = true } - if fields := aq.fields; len(fields) > 0 { + if fields := aq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, alert.FieldID) for i := range fields { @@ -686,10 +697,10 @@ func (aq *AlertQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := aq.limit; limit != nil { + if limit := aq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := aq.offset; offset != nil { + if offset := aq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := aq.order; len(ps) > 0 { @@ -705,7 +716,7 @@ func (aq *AlertQuery) querySpec() *sqlgraph.QuerySpec { func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(aq.driver.Dialect()) t1 := builder.Table(alert.Table) - columns := aq.fields + columns := aq.ctx.Fields if len(columns) == 0 { columns = alert.Columns } @@ -714,7 +725,7 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = aq.sql selector.Select(selector.Columns(columns...)...) } - if aq.unique != nil && *aq.unique { + if aq.ctx.Unique != nil && *aq.ctx.Unique { selector.Distinct() } for _, p := range aq.predicates { @@ -723,12 +734,12 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range aq.order { p(selector) } - if offset := aq.offset; offset != nil { + if offset := aq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := aq.limit; limit != nil { + if limit := aq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -736,13 +747,8 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector { // AlertGroupBy is the group-by builder for Alert entities. type AlertGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *AlertQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -751,74 +757,77 @@ func (agb *AlertGroupBy) Aggregate(fns ...AggregateFunc) *AlertGroupBy { return agb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. 
func (agb *AlertGroupBy) Scan(ctx context.Context, v any) error { - query, err := agb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, agb.build.ctx, "GroupBy") + if err := agb.build.prepareQuery(ctx); err != nil { return err } - agb.sql = query - return agb.sqlScan(ctx, v) + return scanWithInterceptors[*AlertQuery, *AlertGroupBy](ctx, agb.build, agb, agb.build.inters, v) } -func (agb *AlertGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range agb.fields { - if !alert.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (agb *AlertGroupBy) sqlScan(ctx context.Context, root *AlertQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(agb.fns)) + for _, fn := range agb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*agb.flds)+len(agb.fns)) + for _, f := range *agb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := agb.sqlQuery() + selector.GroupBy(selector.Columns(*agb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := agb.driver.Query(ctx, query, args, rows); err != nil { + if err := agb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (agb *AlertGroupBy) sqlQuery() *sql.Selector { - selector := agb.sql.Select() - aggregation := make([]string, 0, len(agb.fns)) - for _, fn := range agb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(agb.fields)+len(agb.fns)) - for _, f := range agb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(agb.fields...)...) -} - // AlertSelect is the builder for selecting fields of Alert entities. type AlertSelect struct { *AlertQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (as *AlertSelect) Aggregate(fns ...AggregateFunc) *AlertSelect { + as.fns = append(as.fns, fns...) + return as } // Scan applies the selector query and scans the result into the given value. func (as *AlertSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, as.ctx, "Select") if err := as.prepareQuery(ctx); err != nil { return err } - as.sql = as.AlertQuery.sqlQuery(ctx) - return as.sqlScan(ctx, v) + return scanWithInterceptors[*AlertQuery, *AlertSelect](ctx, as.AlertQuery, as, as.inters, v) } -func (as *AlertSelect) sqlScan(ctx context.Context, v any) error { +func (as *AlertSelect) sqlScan(ctx context.Context, root *AlertQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(as.fns)) + for _, fn := range as.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*as.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) 
+ } rows := &sql.Rows{} - query, args := as.sql.Query() + query, args := selector.Query() if err := as.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/alert_update.go b/pkg/database/ent/alert_update.go index aaa12ef20a3..0e41ba18109 100644 --- a/pkg/database/ent/alert_update.go +++ b/pkg/database/ent/alert_update.go @@ -624,35 +624,8 @@ func (au *AlertUpdate) RemoveMetas(m ...*Meta) *AlertUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (au *AlertUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) au.defaults() - if len(au.hooks) == 0 { - affected, err = au.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AlertMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - au.mutation = mutation - affected, err = au.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(au.hooks) - 1; i >= 0; i-- { - if au.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = au.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, au.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, au.sqlSave, au.mutation, au.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -690,16 +663,7 @@ func (au *AlertUpdate) defaults() { } func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: alert.Table, - Columns: alert.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(alert.Table, alert.Columns, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt)) if ps := au.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -708,319 +672,148 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := au.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldCreatedAt, - }) + _spec.SetField(alert.FieldCreatedAt, field.TypeTime, value) } if au.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldCreatedAt, - }) + _spec.ClearField(alert.FieldCreatedAt, field.TypeTime) } if value, ok := au.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldUpdatedAt, - }) + _spec.SetField(alert.FieldUpdatedAt, field.TypeTime, value) } if au.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldUpdatedAt, - }) + _spec.ClearField(alert.FieldUpdatedAt, field.TypeTime) } if value, ok := au.mutation.Scenario(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenario, - }) + _spec.SetField(alert.FieldScenario, field.TypeString, value) } if value, ok := au.mutation.BucketId(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldBucketId, - }) + _spec.SetField(alert.FieldBucketId, 
field.TypeString, value) } if au.mutation.BucketIdCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldBucketId, - }) + _spec.ClearField(alert.FieldBucketId, field.TypeString) } if value, ok := au.mutation.Message(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldMessage, - }) + _spec.SetField(alert.FieldMessage, field.TypeString, value) } if au.mutation.MessageCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldMessage, - }) + _spec.ClearField(alert.FieldMessage, field.TypeString) } if value, ok := au.mutation.EventsCount(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldEventsCount, - }) + _spec.SetField(alert.FieldEventsCount, field.TypeInt32, value) } if value, ok := au.mutation.AddedEventsCount(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldEventsCount, - }) + _spec.AddField(alert.FieldEventsCount, field.TypeInt32, value) } if au.mutation.EventsCountCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Column: alert.FieldEventsCount, - }) + _spec.ClearField(alert.FieldEventsCount, field.TypeInt32) } if value, ok := au.mutation.StartedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStartedAt, - }) + _spec.SetField(alert.FieldStartedAt, field.TypeTime, value) } if au.mutation.StartedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldStartedAt, - }) + _spec.ClearField(alert.FieldStartedAt, field.TypeTime) } if value, ok := au.mutation.StoppedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStoppedAt, - }) + _spec.SetField(alert.FieldStoppedAt, field.TypeTime, value) } if au.mutation.StoppedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldStoppedAt, - }) + _spec.ClearField(alert.FieldStoppedAt, field.TypeTime) } if value, ok := au.mutation.SourceIp(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceIp, - }) + _spec.SetField(alert.FieldSourceIp, field.TypeString, value) } if au.mutation.SourceIpCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceIp, - }) + _spec.ClearField(alert.FieldSourceIp, field.TypeString) } if value, ok := au.mutation.SourceRange(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceRange, - }) + _spec.SetField(alert.FieldSourceRange, field.TypeString, value) } if au.mutation.SourceRangeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceRange, - }) + _spec.ClearField(alert.FieldSourceRange, field.TypeString) } if value, ok := au.mutation.SourceAsNumber(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, 
&sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsNumber, - }) + _spec.SetField(alert.FieldSourceAsNumber, field.TypeString, value) } if au.mutation.SourceAsNumberCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceAsNumber, - }) + _spec.ClearField(alert.FieldSourceAsNumber, field.TypeString) } if value, ok := au.mutation.SourceAsName(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsName, - }) + _spec.SetField(alert.FieldSourceAsName, field.TypeString, value) } if au.mutation.SourceAsNameCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceAsName, - }) + _spec.ClearField(alert.FieldSourceAsName, field.TypeString) } if value, ok := au.mutation.SourceCountry(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceCountry, - }) + _spec.SetField(alert.FieldSourceCountry, field.TypeString, value) } if au.mutation.SourceCountryCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceCountry, - }) + _spec.ClearField(alert.FieldSourceCountry, field.TypeString) } if value, ok := au.mutation.SourceLatitude(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLatitude, - }) + _spec.SetField(alert.FieldSourceLatitude, field.TypeFloat32, value) } if value, ok := au.mutation.AddedSourceLatitude(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLatitude, - }) + _spec.AddField(alert.FieldSourceLatitude, field.TypeFloat32, value) } if au.mutation.SourceLatitudeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Column: alert.FieldSourceLatitude, - }) + _spec.ClearField(alert.FieldSourceLatitude, field.TypeFloat32) } if value, ok := au.mutation.SourceLongitude(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLongitude, - }) + _spec.SetField(alert.FieldSourceLongitude, field.TypeFloat32, value) } if value, ok := au.mutation.AddedSourceLongitude(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLongitude, - }) + _spec.AddField(alert.FieldSourceLongitude, field.TypeFloat32, value) } if au.mutation.SourceLongitudeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Column: alert.FieldSourceLongitude, - }) + _spec.ClearField(alert.FieldSourceLongitude, field.TypeFloat32) } if value, ok := au.mutation.SourceScope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceScope, - }) + _spec.SetField(alert.FieldSourceScope, field.TypeString, value) } if au.mutation.SourceScopeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceScope, - }) + _spec.ClearField(alert.FieldSourceScope, 
field.TypeString) } if value, ok := au.mutation.SourceValue(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceValue, - }) + _spec.SetField(alert.FieldSourceValue, field.TypeString, value) } if au.mutation.SourceValueCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceValue, - }) + _spec.ClearField(alert.FieldSourceValue, field.TypeString) } if value, ok := au.mutation.Capacity(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldCapacity, - }) + _spec.SetField(alert.FieldCapacity, field.TypeInt32, value) } if value, ok := au.mutation.AddedCapacity(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldCapacity, - }) + _spec.AddField(alert.FieldCapacity, field.TypeInt32, value) } if au.mutation.CapacityCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Column: alert.FieldCapacity, - }) + _spec.ClearField(alert.FieldCapacity, field.TypeInt32) } if value, ok := au.mutation.LeakSpeed(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldLeakSpeed, - }) + _spec.SetField(alert.FieldLeakSpeed, field.TypeString, value) } if au.mutation.LeakSpeedCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldLeakSpeed, - }) + _spec.ClearField(alert.FieldLeakSpeed, field.TypeString) } if value, ok := au.mutation.ScenarioVersion(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioVersion, - }) + _spec.SetField(alert.FieldScenarioVersion, field.TypeString, value) } if au.mutation.ScenarioVersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldScenarioVersion, - }) + _spec.ClearField(alert.FieldScenarioVersion, field.TypeString) } if value, ok := au.mutation.ScenarioHash(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioHash, - }) + _spec.SetField(alert.FieldScenarioHash, field.TypeString, value) } if au.mutation.ScenarioHashCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldScenarioHash, - }) + _spec.ClearField(alert.FieldScenarioHash, field.TypeString) } if value, ok := au.mutation.Simulated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: alert.FieldSimulated, - }) + _spec.SetField(alert.FieldSimulated, field.TypeBool, value) } if value, ok := au.mutation.UUID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldUUID, - }) + _spec.SetField(alert.FieldUUID, field.TypeString, value) } if au.mutation.UUIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldUUID, - }) + _spec.ClearField(alert.FieldUUID, field.TypeString) } if au.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ 
-1030,10 +823,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1046,10 +836,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1065,10 +852,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1081,10 +865,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1100,10 +881,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1119,10 +897,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1135,10 +910,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1154,10 +926,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1173,10 +942,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1189,10 +955,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: 
field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1208,10 +971,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1227,6 +987,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + au.mutation.done = true return n, nil } @@ -1828,6 +1589,12 @@ func (auo *AlertUpdateOne) RemoveMetas(m ...*Meta) *AlertUpdateOne { return auo.RemoveMetaIDs(ids...) } +// Where appends a list predicates to the AlertUpdate builder. +func (auo *AlertUpdateOne) Where(ps ...predicate.Alert) *AlertUpdateOne { + auo.mutation.Where(ps...) + return auo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (auo *AlertUpdateOne) Select(field string, fields ...string) *AlertUpdateOne { @@ -1837,41 +1604,8 @@ func (auo *AlertUpdateOne) Select(field string, fields ...string) *AlertUpdateOn // Save executes the query and returns the updated Alert entity. func (auo *AlertUpdateOne) Save(ctx context.Context) (*Alert, error) { - var ( - err error - node *Alert - ) auo.defaults() - if len(auo.hooks) == 0 { - node, err = auo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AlertMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - auo.mutation = mutation - node, err = auo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(auo.hooks) - 1; i >= 0; i-- { - if auo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = auo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, auo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Alert) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from AlertMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, auo.sqlSave, auo.mutation, auo.hooks) } // SaveX is like Save, but panics if an error occurs. 
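The hunks above and below apply one mechanical refactor across the generated ent code: hand-built *sqlgraph.FieldSpec literals appended to _spec.Fields.Set/Add/Clear are replaced by the helper API (sqlgraph.NewUpdateSpec, sqlgraph.NewFieldSpec, SetField, AddField, ClearField), and the per-builder hook plumbing in Save collapses into the generated withHooks helper. As an illustrative aside rather than part of the patch, a minimal sketch of the new pattern, using only identifiers that appear in this diff (the import path for the alert package is assumed by analogy with the bouncer import shown further down):

package example

import (
	"entgo.io/ent/dialect/sql/sqlgraph"
	"entgo.io/ent/schema/field"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
)

// buildSpec builds an update spec with the helper calls this patch switches to:
// one NewUpdateSpec call plus SetField/AddField/ClearField, instead of
// appending hand-written *sqlgraph.FieldSpec values to _spec.Fields.
func buildSpec(scenario string, addCapacity int32) *sqlgraph.UpdateSpec {
	_spec := sqlgraph.NewUpdateSpec(alert.Table, alert.Columns,
		sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt))

	_spec.SetField(alert.FieldScenario, field.TypeString, scenario)   // was: append to _spec.Fields.Set
	_spec.AddField(alert.FieldCapacity, field.TypeInt32, addCapacity) // was: append to _spec.Fields.Add
	_spec.ClearField(alert.FieldLeakSpeed, field.TypeString)          // was: append to _spec.Fields.Clear

	return _spec
}

The predicate rewrites in pkg/database/ent/bouncer/where.go later in this patch follow the same idea: closures over *sql.Selector give way to the sql.FieldEQ/FieldIn/FieldContains helpers, which keep column typing in one place and strip the repetitive boilerplate from the generated files.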
@@ -1909,16 +1643,7 @@ func (auo *AlertUpdateOne) defaults() { } func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: alert.Table, - Columns: alert.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(alert.Table, alert.Columns, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt)) id, ok := auo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Alert.id" for update`)} @@ -1944,319 +1669,148 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error } } if value, ok := auo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldCreatedAt, - }) + _spec.SetField(alert.FieldCreatedAt, field.TypeTime, value) } if auo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldCreatedAt, - }) + _spec.ClearField(alert.FieldCreatedAt, field.TypeTime) } if value, ok := auo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldUpdatedAt, - }) + _spec.SetField(alert.FieldUpdatedAt, field.TypeTime, value) } if auo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldUpdatedAt, - }) + _spec.ClearField(alert.FieldUpdatedAt, field.TypeTime) } if value, ok := auo.mutation.Scenario(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenario, - }) + _spec.SetField(alert.FieldScenario, field.TypeString, value) } if value, ok := auo.mutation.BucketId(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldBucketId, - }) + _spec.SetField(alert.FieldBucketId, field.TypeString, value) } if auo.mutation.BucketIdCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldBucketId, - }) + _spec.ClearField(alert.FieldBucketId, field.TypeString) } if value, ok := auo.mutation.Message(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldMessage, - }) + _spec.SetField(alert.FieldMessage, field.TypeString, value) } if auo.mutation.MessageCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldMessage, - }) + _spec.ClearField(alert.FieldMessage, field.TypeString) } if value, ok := auo.mutation.EventsCount(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldEventsCount, - }) + _spec.SetField(alert.FieldEventsCount, field.TypeInt32, value) } if value, ok := auo.mutation.AddedEventsCount(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldEventsCount, - }) + _spec.AddField(alert.FieldEventsCount, field.TypeInt32, value) } if auo.mutation.EventsCountCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: 
field.TypeInt32, - Column: alert.FieldEventsCount, - }) + _spec.ClearField(alert.FieldEventsCount, field.TypeInt32) } if value, ok := auo.mutation.StartedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStartedAt, - }) + _spec.SetField(alert.FieldStartedAt, field.TypeTime, value) } if auo.mutation.StartedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldStartedAt, - }) + _spec.ClearField(alert.FieldStartedAt, field.TypeTime) } if value, ok := auo.mutation.StoppedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStoppedAt, - }) + _spec.SetField(alert.FieldStoppedAt, field.TypeTime, value) } if auo.mutation.StoppedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldStoppedAt, - }) + _spec.ClearField(alert.FieldStoppedAt, field.TypeTime) } if value, ok := auo.mutation.SourceIp(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceIp, - }) + _spec.SetField(alert.FieldSourceIp, field.TypeString, value) } if auo.mutation.SourceIpCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceIp, - }) + _spec.ClearField(alert.FieldSourceIp, field.TypeString) } if value, ok := auo.mutation.SourceRange(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceRange, - }) + _spec.SetField(alert.FieldSourceRange, field.TypeString, value) } if auo.mutation.SourceRangeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceRange, - }) + _spec.ClearField(alert.FieldSourceRange, field.TypeString) } if value, ok := auo.mutation.SourceAsNumber(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsNumber, - }) + _spec.SetField(alert.FieldSourceAsNumber, field.TypeString, value) } if auo.mutation.SourceAsNumberCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceAsNumber, - }) + _spec.ClearField(alert.FieldSourceAsNumber, field.TypeString) } if value, ok := auo.mutation.SourceAsName(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsName, - }) + _spec.SetField(alert.FieldSourceAsName, field.TypeString, value) } if auo.mutation.SourceAsNameCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceAsName, - }) + _spec.ClearField(alert.FieldSourceAsName, field.TypeString) } if value, ok := auo.mutation.SourceCountry(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceCountry, - }) + _spec.SetField(alert.FieldSourceCountry, field.TypeString, value) } if auo.mutation.SourceCountryCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceCountry, - }) + 
_spec.ClearField(alert.FieldSourceCountry, field.TypeString) } if value, ok := auo.mutation.SourceLatitude(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLatitude, - }) + _spec.SetField(alert.FieldSourceLatitude, field.TypeFloat32, value) } if value, ok := auo.mutation.AddedSourceLatitude(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLatitude, - }) + _spec.AddField(alert.FieldSourceLatitude, field.TypeFloat32, value) } if auo.mutation.SourceLatitudeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Column: alert.FieldSourceLatitude, - }) + _spec.ClearField(alert.FieldSourceLatitude, field.TypeFloat32) } if value, ok := auo.mutation.SourceLongitude(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLongitude, - }) + _spec.SetField(alert.FieldSourceLongitude, field.TypeFloat32, value) } if value, ok := auo.mutation.AddedSourceLongitude(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLongitude, - }) + _spec.AddField(alert.FieldSourceLongitude, field.TypeFloat32, value) } if auo.mutation.SourceLongitudeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Column: alert.FieldSourceLongitude, - }) + _spec.ClearField(alert.FieldSourceLongitude, field.TypeFloat32) } if value, ok := auo.mutation.SourceScope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceScope, - }) + _spec.SetField(alert.FieldSourceScope, field.TypeString, value) } if auo.mutation.SourceScopeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceScope, - }) + _spec.ClearField(alert.FieldSourceScope, field.TypeString) } if value, ok := auo.mutation.SourceValue(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceValue, - }) + _spec.SetField(alert.FieldSourceValue, field.TypeString, value) } if auo.mutation.SourceValueCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceValue, - }) + _spec.ClearField(alert.FieldSourceValue, field.TypeString) } if value, ok := auo.mutation.Capacity(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldCapacity, - }) + _spec.SetField(alert.FieldCapacity, field.TypeInt32, value) } if value, ok := auo.mutation.AddedCapacity(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldCapacity, - }) + _spec.AddField(alert.FieldCapacity, field.TypeInt32, value) } if auo.mutation.CapacityCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Column: alert.FieldCapacity, - }) + _spec.ClearField(alert.FieldCapacity, field.TypeInt32) } if value, ok := auo.mutation.LeakSpeed(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: 
field.TypeString, - Value: value, - Column: alert.FieldLeakSpeed, - }) + _spec.SetField(alert.FieldLeakSpeed, field.TypeString, value) } if auo.mutation.LeakSpeedCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldLeakSpeed, - }) + _spec.ClearField(alert.FieldLeakSpeed, field.TypeString) } if value, ok := auo.mutation.ScenarioVersion(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioVersion, - }) + _spec.SetField(alert.FieldScenarioVersion, field.TypeString, value) } if auo.mutation.ScenarioVersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldScenarioVersion, - }) + _spec.ClearField(alert.FieldScenarioVersion, field.TypeString) } if value, ok := auo.mutation.ScenarioHash(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioHash, - }) + _spec.SetField(alert.FieldScenarioHash, field.TypeString, value) } if auo.mutation.ScenarioHashCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldScenarioHash, - }) + _spec.ClearField(alert.FieldScenarioHash, field.TypeString) } if value, ok := auo.mutation.Simulated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: alert.FieldSimulated, - }) + _spec.SetField(alert.FieldSimulated, field.TypeBool, value) } if value, ok := auo.mutation.UUID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldUUID, - }) + _spec.SetField(alert.FieldUUID, field.TypeString, value) } if auo.mutation.UUIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldUUID, - }) + _spec.ClearField(alert.FieldUUID, field.TypeString) } if auo.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -2266,10 +1820,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -2282,10 +1833,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2301,10 +1849,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -2317,10 +1862,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: 
&sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2336,10 +1878,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2355,10 +1894,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -2371,10 +1907,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2390,10 +1923,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2409,10 +1939,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -2425,10 +1952,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2444,10 +1968,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2466,5 +1987,6 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error } return nil, err } + auo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/bouncer.go b/pkg/database/ent/bouncer.go index 068fc6c6713..fe189c3817e 100644 --- a/pkg/database/ent/bouncer.go +++ b/pkg/database/ent/bouncer.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" ) @@ -37,7 +38,8 @@ type Bouncer struct { // LastPull holds the value of the "last_pull" field. LastPull time.Time `json:"last_pull"` // AuthType holds the value of the "auth_type" field. 
- AuthType string `json:"auth_type"` + AuthType string `json:"auth_type"` + selectValues sql.SelectValues } // scanValues returns the types for scanning values from sql.Rows. @@ -54,7 +56,7 @@ func (*Bouncer) scanValues(columns []string) ([]any, error) { case bouncer.FieldCreatedAt, bouncer.FieldUpdatedAt, bouncer.FieldUntil, bouncer.FieldLastPull: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type Bouncer", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -142,16 +144,24 @@ func (b *Bouncer) assignValues(columns []string, values []any) error { } else if value.Valid { b.AuthType = value.String } + default: + b.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Bouncer. +// This includes values selected through modifiers, order, etc. +func (b *Bouncer) Value(name string) (ent.Value, error) { + return b.selectValues.Get(name) +} + // Update returns a builder for updating this Bouncer. // Note that you need to call Bouncer.Unwrap() before calling this method if this Bouncer // was returned from a transaction, and the transaction was committed or rolled back. func (b *Bouncer) Update() *BouncerUpdateOne { - return (&BouncerClient{config: b.config}).UpdateOne(b) + return NewBouncerClient(b.config).UpdateOne(b) } // Unwrap unwraps the Bouncer entity that was returned from a transaction after it was closed, @@ -212,9 +222,3 @@ func (b *Bouncer) String() string { // Bouncers is a parsable slice of Bouncer. type Bouncers []*Bouncer - -func (b Bouncers) config(cfg config) { - for _i := range b { - b[_i].config = cfg - } -} diff --git a/pkg/database/ent/bouncer/bouncer.go b/pkg/database/ent/bouncer/bouncer.go index b688594ece4..24d230d3b54 100644 --- a/pkg/database/ent/bouncer/bouncer.go +++ b/pkg/database/ent/bouncer/bouncer.go @@ -4,6 +4,8 @@ package bouncer import ( "time" + + "entgo.io/ent/dialect/sql" ) const ( @@ -81,3 +83,66 @@ var ( // DefaultAuthType holds the default value on creation for the "auth_type" field. DefaultAuthType string ) + +// OrderOption defines the ordering options for the Bouncer queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByAPIKey orders the results by the api_key field. +func ByAPIKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAPIKey, opts...).ToFunc() +} + +// ByRevoked orders the results by the revoked field. +func ByRevoked(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRevoked, opts...).ToFunc() +} + +// ByIPAddress orders the results by the ip_address field. +func ByIPAddress(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIPAddress, opts...).ToFunc() +} + +// ByType orders the results by the type field. 
+func ByType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldType, opts...).ToFunc() +} + +// ByVersion orders the results by the version field. +func ByVersion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldVersion, opts...).ToFunc() +} + +// ByUntil orders the results by the until field. +func ByUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUntil, opts...).ToFunc() +} + +// ByLastPull orders the results by the last_pull field. +func ByLastPull(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastPull, opts...).ToFunc() +} + +// ByAuthType orders the results by the auth_type field. +func ByAuthType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAuthType, opts...).ToFunc() +} diff --git a/pkg/database/ent/bouncer/where.go b/pkg/database/ent/bouncer/where.go index 03a543f6d4f..5bf721dbf51 100644 --- a/pkg/database/ent/bouncer/where.go +++ b/pkg/database/ent/bouncer/where.go @@ -11,1128 +11,735 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. 
func CreatedAt(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldUpdatedAt, v)) } // Name applies equality check predicate on the "name" field. It's identical to NameEQ. func Name(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldName, v)) } // APIKey applies equality check predicate on the "api_key" field. It's identical to APIKeyEQ. func APIKey(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldAPIKey, v)) } // Revoked applies equality check predicate on the "revoked" field. It's identical to RevokedEQ. func Revoked(v bool) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRevoked), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldRevoked, v)) } // IPAddress applies equality check predicate on the "ip_address" field. It's identical to IPAddressEQ. func IPAddress(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldIPAddress, v)) } // Type applies equality check predicate on the "type" field. It's identical to TypeEQ. func Type(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldType, v)) } // Version applies equality check predicate on the "version" field. It's identical to VersionEQ. func Version(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldVersion, v)) } // Until applies equality check predicate on the "until" field. It's identical to UntilEQ. func Until(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUntil), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldUntil, v)) } // LastPull applies equality check predicate on the "last_pull" field. It's identical to LastPullEQ. func LastPull(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastPull), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldLastPull, v)) } // AuthType applies equality check predicate on the "auth_type" field. It's identical to AuthTypeEQ. func AuthType(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldAuthType, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. 
func CreatedAtNEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldCreatedAt, v)) } // CreatedAtIsNil applies the IsNil predicate on the "created_at" field. func CreatedAtIsNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) + return predicate.Bouncer(sql.FieldIsNull(FieldCreatedAt)) } // CreatedAtNotNil applies the NotNil predicate on the "created_at" field. func CreatedAtNotNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Bouncer(sql.FieldNotNull(FieldCreatedAt)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. 
func UpdatedAtIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldUpdatedAt, v)) } // UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. func UpdatedAtIsNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) + return predicate.Bouncer(sql.FieldIsNull(FieldUpdatedAt)) } // UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. func UpdatedAtNotNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Bouncer(sql.FieldNotNull(FieldUpdatedAt)) } // NameEQ applies the EQ predicate on the "name" field. func NameEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldName, v)) } // NameNEQ applies the NEQ predicate on the "name" field. func NameNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldName, v)) } // NameIn applies the In predicate on the "name" field. func NameIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldName), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldName, vs...)) } // NameNotIn applies the NotIn predicate on the "name" field. func NameNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldName), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldName, vs...)) } // NameGT applies the GT predicate on the "name" field. 
func NameGT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldName, v)) } // NameGTE applies the GTE predicate on the "name" field. func NameGTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldName, v)) } // NameLT applies the LT predicate on the "name" field. func NameLT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldName, v)) } // NameLTE applies the LTE predicate on the "name" field. func NameLTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldName, v)) } // NameContains applies the Contains predicate on the "name" field. func NameContains(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldContains(FieldName, v)) } // NameHasPrefix applies the HasPrefix predicate on the "name" field. func NameHasPrefix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldHasPrefix(FieldName, v)) } // NameHasSuffix applies the HasSuffix predicate on the "name" field. func NameHasSuffix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldHasSuffix(FieldName, v)) } // NameEqualFold applies the EqualFold predicate on the "name" field. func NameEqualFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldEqualFold(FieldName, v)) } // NameContainsFold applies the ContainsFold predicate on the "name" field. func NameContainsFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldContainsFold(FieldName, v)) } // APIKeyEQ applies the EQ predicate on the "api_key" field. func APIKeyEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldAPIKey, v)) } // APIKeyNEQ applies the NEQ predicate on the "api_key" field. func APIKeyNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldAPIKey, v)) } // APIKeyIn applies the In predicate on the "api_key" field. func APIKeyIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAPIKey), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldAPIKey, vs...)) } // APIKeyNotIn applies the NotIn predicate on the "api_key" field. 
func APIKeyNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAPIKey), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldAPIKey, vs...)) } // APIKeyGT applies the GT predicate on the "api_key" field. func APIKeyGT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldAPIKey, v)) } // APIKeyGTE applies the GTE predicate on the "api_key" field. func APIKeyGTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldAPIKey, v)) } // APIKeyLT applies the LT predicate on the "api_key" field. func APIKeyLT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldAPIKey, v)) } // APIKeyLTE applies the LTE predicate on the "api_key" field. func APIKeyLTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldAPIKey, v)) } // APIKeyContains applies the Contains predicate on the "api_key" field. func APIKeyContains(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldContains(FieldAPIKey, v)) } // APIKeyHasPrefix applies the HasPrefix predicate on the "api_key" field. func APIKeyHasPrefix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldHasPrefix(FieldAPIKey, v)) } // APIKeyHasSuffix applies the HasSuffix predicate on the "api_key" field. func APIKeyHasSuffix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldHasSuffix(FieldAPIKey, v)) } // APIKeyEqualFold applies the EqualFold predicate on the "api_key" field. func APIKeyEqualFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldEqualFold(FieldAPIKey, v)) } // APIKeyContainsFold applies the ContainsFold predicate on the "api_key" field. func APIKeyContainsFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldContainsFold(FieldAPIKey, v)) } // RevokedEQ applies the EQ predicate on the "revoked" field. func RevokedEQ(v bool) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRevoked), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldRevoked, v)) } // RevokedNEQ applies the NEQ predicate on the "revoked" field. func RevokedNEQ(v bool) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldRevoked), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldRevoked, v)) } // IPAddressEQ applies the EQ predicate on the "ip_address" field. 
func IPAddressEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldIPAddress, v)) } // IPAddressNEQ applies the NEQ predicate on the "ip_address" field. func IPAddressNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldIPAddress, v)) } // IPAddressIn applies the In predicate on the "ip_address" field. func IPAddressIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldIPAddress), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldIPAddress, vs...)) } // IPAddressNotIn applies the NotIn predicate on the "ip_address" field. func IPAddressNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldIPAddress), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldIPAddress, vs...)) } // IPAddressGT applies the GT predicate on the "ip_address" field. func IPAddressGT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldIPAddress, v)) } // IPAddressGTE applies the GTE predicate on the "ip_address" field. func IPAddressGTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldIPAddress, v)) } // IPAddressLT applies the LT predicate on the "ip_address" field. func IPAddressLT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldIPAddress, v)) } // IPAddressLTE applies the LTE predicate on the "ip_address" field. func IPAddressLTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldIPAddress, v)) } // IPAddressContains applies the Contains predicate on the "ip_address" field. func IPAddressContains(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldContains(FieldIPAddress, v)) } // IPAddressHasPrefix applies the HasPrefix predicate on the "ip_address" field. func IPAddressHasPrefix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldHasPrefix(FieldIPAddress, v)) } // IPAddressHasSuffix applies the HasSuffix predicate on the "ip_address" field. func IPAddressHasSuffix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldHasSuffix(FieldIPAddress, v)) } // IPAddressIsNil applies the IsNil predicate on the "ip_address" field. 
func IPAddressIsNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldIPAddress))) - }) + return predicate.Bouncer(sql.FieldIsNull(FieldIPAddress)) } // IPAddressNotNil applies the NotNil predicate on the "ip_address" field. func IPAddressNotNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldIPAddress))) - }) + return predicate.Bouncer(sql.FieldNotNull(FieldIPAddress)) } // IPAddressEqualFold applies the EqualFold predicate on the "ip_address" field. func IPAddressEqualFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldEqualFold(FieldIPAddress, v)) } // IPAddressContainsFold applies the ContainsFold predicate on the "ip_address" field. func IPAddressContainsFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldContainsFold(FieldIPAddress, v)) } // TypeEQ applies the EQ predicate on the "type" field. func TypeEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldType, v)) } // TypeNEQ applies the NEQ predicate on the "type" field. func TypeNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldType, v)) } // TypeIn applies the In predicate on the "type" field. func TypeIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldType), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldType, vs...)) } // TypeNotIn applies the NotIn predicate on the "type" field. func TypeNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldType), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldType, vs...)) } // TypeGT applies the GT predicate on the "type" field. func TypeGT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldType, v)) } // TypeGTE applies the GTE predicate on the "type" field. func TypeGTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldType, v)) } // TypeLT applies the LT predicate on the "type" field. func TypeLT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldType, v)) } // TypeLTE applies the LTE predicate on the "type" field. func TypeLTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldType, v)) } // TypeContains applies the Contains predicate on the "type" field. 
func TypeContains(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldContains(FieldType, v)) } // TypeHasPrefix applies the HasPrefix predicate on the "type" field. func TypeHasPrefix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldHasPrefix(FieldType, v)) } // TypeHasSuffix applies the HasSuffix predicate on the "type" field. func TypeHasSuffix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldHasSuffix(FieldType, v)) } // TypeIsNil applies the IsNil predicate on the "type" field. func TypeIsNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldType))) - }) + return predicate.Bouncer(sql.FieldIsNull(FieldType)) } // TypeNotNil applies the NotNil predicate on the "type" field. func TypeNotNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldType))) - }) + return predicate.Bouncer(sql.FieldNotNull(FieldType)) } // TypeEqualFold applies the EqualFold predicate on the "type" field. func TypeEqualFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldEqualFold(FieldType, v)) } // TypeContainsFold applies the ContainsFold predicate on the "type" field. func TypeContainsFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldContainsFold(FieldType, v)) } // VersionEQ applies the EQ predicate on the "version" field. func VersionEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldVersion, v)) } // VersionNEQ applies the NEQ predicate on the "version" field. func VersionNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldVersion, v)) } // VersionIn applies the In predicate on the "version" field. func VersionIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldVersion), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldVersion, vs...)) } // VersionNotIn applies the NotIn predicate on the "version" field. func VersionNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldVersion), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldVersion, vs...)) } // VersionGT applies the GT predicate on the "version" field. func VersionGT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldVersion, v)) } // VersionGTE applies the GTE predicate on the "version" field. 
func VersionGTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldVersion, v)) } // VersionLT applies the LT predicate on the "version" field. func VersionLT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldVersion, v)) } // VersionLTE applies the LTE predicate on the "version" field. func VersionLTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldVersion, v)) } // VersionContains applies the Contains predicate on the "version" field. func VersionContains(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldContains(FieldVersion, v)) } // VersionHasPrefix applies the HasPrefix predicate on the "version" field. func VersionHasPrefix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldHasPrefix(FieldVersion, v)) } // VersionHasSuffix applies the HasSuffix predicate on the "version" field. func VersionHasSuffix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldHasSuffix(FieldVersion, v)) } // VersionIsNil applies the IsNil predicate on the "version" field. func VersionIsNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldVersion))) - }) + return predicate.Bouncer(sql.FieldIsNull(FieldVersion)) } // VersionNotNil applies the NotNil predicate on the "version" field. func VersionNotNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldVersion))) - }) + return predicate.Bouncer(sql.FieldNotNull(FieldVersion)) } // VersionEqualFold applies the EqualFold predicate on the "version" field. func VersionEqualFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldEqualFold(FieldVersion, v)) } // VersionContainsFold applies the ContainsFold predicate on the "version" field. func VersionContainsFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldContainsFold(FieldVersion, v)) } // UntilEQ applies the EQ predicate on the "until" field. func UntilEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUntil), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldUntil, v)) } // UntilNEQ applies the NEQ predicate on the "until" field. func UntilNEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUntil), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldUntil, v)) } // UntilIn applies the In predicate on the "until" field. 
func UntilIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUntil), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldUntil, vs...)) } // UntilNotIn applies the NotIn predicate on the "until" field. func UntilNotIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUntil), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldUntil, vs...)) } // UntilGT applies the GT predicate on the "until" field. func UntilGT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUntil), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldUntil, v)) } // UntilGTE applies the GTE predicate on the "until" field. func UntilGTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUntil), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldUntil, v)) } // UntilLT applies the LT predicate on the "until" field. func UntilLT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUntil), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldUntil, v)) } // UntilLTE applies the LTE predicate on the "until" field. func UntilLTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUntil), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldUntil, v)) } // UntilIsNil applies the IsNil predicate on the "until" field. func UntilIsNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUntil))) - }) + return predicate.Bouncer(sql.FieldIsNull(FieldUntil)) } // UntilNotNil applies the NotNil predicate on the "until" field. func UntilNotNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUntil))) - }) + return predicate.Bouncer(sql.FieldNotNull(FieldUntil)) } // LastPullEQ applies the EQ predicate on the "last_pull" field. func LastPullEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastPull), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldLastPull, v)) } // LastPullNEQ applies the NEQ predicate on the "last_pull" field. func LastPullNEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldLastPull), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldLastPull, v)) } // LastPullIn applies the In predicate on the "last_pull" field. func LastPullIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldLastPull), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldLastPull, vs...)) } // LastPullNotIn applies the NotIn predicate on the "last_pull" field. func LastPullNotIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldLastPull), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldLastPull, vs...)) } // LastPullGT applies the GT predicate on the "last_pull" field. 
func LastPullGT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldLastPull), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldLastPull, v)) } // LastPullGTE applies the GTE predicate on the "last_pull" field. func LastPullGTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldLastPull), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldLastPull, v)) } // LastPullLT applies the LT predicate on the "last_pull" field. func LastPullLT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldLastPull), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldLastPull, v)) } // LastPullLTE applies the LTE predicate on the "last_pull" field. func LastPullLTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldLastPull), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldLastPull, v)) } // AuthTypeEQ applies the EQ predicate on the "auth_type" field. func AuthTypeEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldAuthType, v)) } // AuthTypeNEQ applies the NEQ predicate on the "auth_type" field. func AuthTypeNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldAuthType, v)) } // AuthTypeIn applies the In predicate on the "auth_type" field. func AuthTypeIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAuthType), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldAuthType, vs...)) } // AuthTypeNotIn applies the NotIn predicate on the "auth_type" field. func AuthTypeNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAuthType), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldAuthType, vs...)) } // AuthTypeGT applies the GT predicate on the "auth_type" field. func AuthTypeGT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldAuthType, v)) } // AuthTypeGTE applies the GTE predicate on the "auth_type" field. func AuthTypeGTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldAuthType, v)) } // AuthTypeLT applies the LT predicate on the "auth_type" field. func AuthTypeLT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldAuthType, v)) } // AuthTypeLTE applies the LTE predicate on the "auth_type" field. func AuthTypeLTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldAuthType, v)) } // AuthTypeContains applies the Contains predicate on the "auth_type" field. 
func AuthTypeContains(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldContains(FieldAuthType, v)) } // AuthTypeHasPrefix applies the HasPrefix predicate on the "auth_type" field. func AuthTypeHasPrefix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldHasPrefix(FieldAuthType, v)) } // AuthTypeHasSuffix applies the HasSuffix predicate on the "auth_type" field. func AuthTypeHasSuffix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldHasSuffix(FieldAuthType, v)) } // AuthTypeEqualFold applies the EqualFold predicate on the "auth_type" field. func AuthTypeEqualFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldEqualFold(FieldAuthType, v)) } // AuthTypeContainsFold applies the ContainsFold predicate on the "auth_type" field. func AuthTypeContainsFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldContainsFold(FieldAuthType, v)) } // And groups predicates with the AND operator between them. func And(predicates ...predicate.Bouncer) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Bouncer(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Bouncer) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Bouncer(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Bouncer) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Bouncer(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/bouncer_create.go b/pkg/database/ent/bouncer_create.go index 685ce089d1e..3d08277dcfb 100644 --- a/pkg/database/ent/bouncer_create.go +++ b/pkg/database/ent/bouncer_create.go @@ -157,50 +157,8 @@ func (bc *BouncerCreate) Mutation() *BouncerMutation { // Save creates the Bouncer in the database. 
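The predicate rewrite above is purely mechanical: the hand-rolled *sql.Selector closures are replaced by ent's field-level helpers (sql.FieldEQ, sql.FieldContains, sql.FieldIsNull, sql.AndPredicates, ...), so callers keep composing the generated bouncer predicates exactly as before and the resulting SQL should be unchanged. A minimal usage sketch, illustrative only and not taken from this patch; it assumes an initialized *ent.Client named client, and the helper name is invented for the example:

package example

import (
	"context"
	"fmt"
	"time"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer"
)

// listNonExpiredBouncers combines several of the regenerated predicates the
// same way the old closure-based ones were combined.
func listNonExpiredBouncers(ctx context.Context, client *ent.Client) ([]*ent.Bouncer, error) {
	bouncers, err := client.Bouncer.Query().
		Where(
			bouncer.TypeContains("firewall"), // substring match on "type"
			bouncer.Or(
				bouncer.UntilIsNil(),        // no expiry recorded
				bouncer.UntilGT(time.Now()), // or expiry still in the future
			),
			bouncer.Not(bouncer.AuthTypeEQ("api-key")),
		).
		All(ctx)
	if err != nil {
		return nil, fmt.Errorf("query bouncers: %w", err)
	}
	return bouncers, nil
}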
func (bc *BouncerCreate) Save(ctx context.Context) (*Bouncer, error) { - var ( - err error - node *Bouncer - ) bc.defaults() - if len(bc.hooks) == 0 { - if err = bc.check(); err != nil { - return nil, err - } - node, err = bc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*BouncerMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = bc.check(); err != nil { - return nil, err - } - bc.mutation = mutation - if node, err = bc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(bc.hooks) - 1; i >= 0; i-- { - if bc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = bc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, bc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Bouncer) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from BouncerMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, bc.sqlSave, bc.mutation, bc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -274,6 +232,9 @@ func (bc *BouncerCreate) check() error { } func (bc *BouncerCreate) sqlSave(ctx context.Context) (*Bouncer, error) { + if err := bc.check(); err != nil { + return nil, err + } _node, _spec := bc.createSpec() if err := sqlgraph.CreateNode(ctx, bc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -283,106 +244,58 @@ func (bc *BouncerCreate) sqlSave(ctx context.Context) (*Bouncer, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + bc.mutation.id = &_node.ID + bc.mutation.done = true return _node, nil } func (bc *BouncerCreate) createSpec() (*Bouncer, *sqlgraph.CreateSpec) { var ( _node = &Bouncer{config: bc.config} - _spec = &sqlgraph.CreateSpec{ - Table: bouncer.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: bouncer.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(bouncer.Table, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt)) ) if value, ok := bc.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldCreatedAt, - }) + _spec.SetField(bouncer.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = &value } if value, ok := bc.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUpdatedAt, - }) + _spec.SetField(bouncer.FieldUpdatedAt, field.TypeTime, value) _node.UpdatedAt = &value } if value, ok := bc.mutation.Name(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldName, - }) + _spec.SetField(bouncer.FieldName, field.TypeString, value) _node.Name = value } if value, ok := bc.mutation.APIKey(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAPIKey, - }) + _spec.SetField(bouncer.FieldAPIKey, field.TypeString, value) _node.APIKey = value } if value, ok := bc.mutation.Revoked(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: bouncer.FieldRevoked, - }) + _spec.SetField(bouncer.FieldRevoked, field.TypeBool, value) _node.Revoked = value } if value, ok := bc.mutation.IPAddress(); ok { - _spec.Fields = 
append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldIPAddress, - }) + _spec.SetField(bouncer.FieldIPAddress, field.TypeString, value) _node.IPAddress = value } if value, ok := bc.mutation.GetType(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldType, - }) + _spec.SetField(bouncer.FieldType, field.TypeString, value) _node.Type = value } if value, ok := bc.mutation.Version(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldVersion, - }) + _spec.SetField(bouncer.FieldVersion, field.TypeString, value) _node.Version = value } if value, ok := bc.mutation.Until(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUntil, - }) + _spec.SetField(bouncer.FieldUntil, field.TypeTime, value) _node.Until = value } if value, ok := bc.mutation.LastPull(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldLastPull, - }) + _spec.SetField(bouncer.FieldLastPull, field.TypeTime, value) _node.LastPull = value } if value, ok := bc.mutation.AuthType(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAuthType, - }) + _spec.SetField(bouncer.FieldAuthType, field.TypeString, value) _node.AuthType = value } return _node, _spec @@ -391,11 +304,15 @@ func (bc *BouncerCreate) createSpec() (*Bouncer, *sqlgraph.CreateSpec) { // BouncerCreateBulk is the builder for creating many Bouncer entities in bulk. type BouncerCreateBulk struct { config + err error builders []*BouncerCreate } // Save creates the Bouncer entities in the database. func (bcb *BouncerCreateBulk) Save(ctx context.Context) ([]*Bouncer, error) { + if bcb.err != nil { + return nil, bcb.err + } specs := make([]*sqlgraph.CreateSpec, len(bcb.builders)) nodes := make([]*Bouncer, len(bcb.builders)) mutators := make([]Mutator, len(bcb.builders)) @@ -412,8 +329,8 @@ func (bcb *BouncerCreateBulk) Save(ctx context.Context) ([]*Bouncer, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, bcb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/bouncer_delete.go b/pkg/database/ent/bouncer_delete.go index 6bfb9459190..bf459e77e28 100644 --- a/pkg/database/ent/bouncer_delete.go +++ b/pkg/database/ent/bouncer_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (bd *BouncerDelete) Where(ps ...predicate.Bouncer) *BouncerDelete { // Exec executes the deletion query and returns how many vertices were deleted. 
func (bd *BouncerDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(bd.hooks) == 0 { - affected, err = bd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*BouncerMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - bd.mutation = mutation - affected, err = bd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(bd.hooks) - 1; i >= 0; i-- { - if bd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = bd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, bd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, bd.sqlExec, bd.mutation, bd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (bd *BouncerDelete) ExecX(ctx context.Context) int { } func (bd *BouncerDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: bouncer.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: bouncer.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(bouncer.Table, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt)) if ps := bd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (bd *BouncerDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + bd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type BouncerDeleteOne struct { bd *BouncerDelete } +// Where appends a list predicates to the BouncerDelete builder. +func (bdo *BouncerDeleteOne) Where(ps ...predicate.Bouncer) *BouncerDeleteOne { + bdo.bd.mutation.Where(ps...) + return bdo +} + // Exec executes the deletion query. func (bdo *BouncerDeleteOne) Exec(ctx context.Context) error { n, err := bdo.bd.Exec(ctx) @@ -111,5 +82,7 @@ func (bdo *BouncerDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (bdo *BouncerDeleteOne) ExecX(ctx context.Context) { - bdo.bd.ExecX(ctx) + if err := bdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/bouncer_query.go b/pkg/database/ent/bouncer_query.go index 2747a3e0b3a..ea2b7495733 100644 --- a/pkg/database/ent/bouncer_query.go +++ b/pkg/database/ent/bouncer_query.go @@ -17,11 +17,9 @@ import ( // BouncerQuery is the builder for querying Bouncer entities. type BouncerQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []bouncer.OrderOption + inters []Interceptor predicates []predicate.Bouncer // intermediate query (i.e. traversal path). sql *sql.Selector @@ -34,27 +32,27 @@ func (bq *BouncerQuery) Where(ps ...predicate.Bouncer) *BouncerQuery { return bq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (bq *BouncerQuery) Limit(limit int) *BouncerQuery { - bq.limit = &limit + bq.ctx.Limit = &limit return bq } -// Offset adds an offset step to the query. +// Offset to start from. func (bq *BouncerQuery) Offset(offset int) *BouncerQuery { - bq.offset = &offset + bq.ctx.Offset = &offset return bq } // Unique configures the query builder to filter duplicate records on query. 
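Save on the create builder and Exec on the delete builder now delegate to the shared withHooks helper instead of re-implementing the mutation/hook chain inline, and BouncerDeleteOne gains a Where method. Hooks registered through the typed clients keep working unchanged; a short sketch of such a registration, illustrative only (the helper name is invented; Mutator, MutateFunc, Mutation and Value are the aliases the generated package already defines):

package example

import (
	"context"
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)

// registerBouncerLogging attaches a logging hook to the Bouncer client; the
// consolidated withHooks path still runs it for every create, update and delete.
func registerBouncerLogging(client *ent.Client) {
	client.Bouncer.Use(func(next ent.Mutator) ent.Mutator {
		return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
			log.Printf("bouncer mutation: op=%s", m.Op())
			return next.Mutate(ctx, m)
		})
	})
}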
// By default, unique is set to true, and can be disabled using this method. func (bq *BouncerQuery) Unique(unique bool) *BouncerQuery { - bq.unique = &unique + bq.ctx.Unique = &unique return bq } -// Order adds an order step to the query. -func (bq *BouncerQuery) Order(o ...OrderFunc) *BouncerQuery { +// Order specifies how the records should be ordered. +func (bq *BouncerQuery) Order(o ...bouncer.OrderOption) *BouncerQuery { bq.order = append(bq.order, o...) return bq } @@ -62,7 +60,7 @@ func (bq *BouncerQuery) Order(o ...OrderFunc) *BouncerQuery { // First returns the first Bouncer entity from the query. // Returns a *NotFoundError when no Bouncer was found. func (bq *BouncerQuery) First(ctx context.Context) (*Bouncer, error) { - nodes, err := bq.Limit(1).All(ctx) + nodes, err := bq.Limit(1).All(setContextOp(ctx, bq.ctx, "First")) if err != nil { return nil, err } @@ -85,7 +83,7 @@ func (bq *BouncerQuery) FirstX(ctx context.Context) *Bouncer { // Returns a *NotFoundError when no Bouncer ID was found. func (bq *BouncerQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = bq.Limit(1).IDs(ctx); err != nil { + if ids, err = bq.Limit(1).IDs(setContextOp(ctx, bq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -108,7 +106,7 @@ func (bq *BouncerQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Bouncer entity is found. // Returns a *NotFoundError when no Bouncer entities are found. func (bq *BouncerQuery) Only(ctx context.Context) (*Bouncer, error) { - nodes, err := bq.Limit(2).All(ctx) + nodes, err := bq.Limit(2).All(setContextOp(ctx, bq.ctx, "Only")) if err != nil { return nil, err } @@ -136,7 +134,7 @@ func (bq *BouncerQuery) OnlyX(ctx context.Context) *Bouncer { // Returns a *NotFoundError when no entities are found. func (bq *BouncerQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = bq.Limit(2).IDs(ctx); err != nil { + if ids, err = bq.Limit(2).IDs(setContextOp(ctx, bq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -161,10 +159,12 @@ func (bq *BouncerQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Bouncers. func (bq *BouncerQuery) All(ctx context.Context) ([]*Bouncer, error) { + ctx = setContextOp(ctx, bq.ctx, "All") if err := bq.prepareQuery(ctx); err != nil { return nil, err } - return bq.sqlAll(ctx) + qr := querierAll[[]*Bouncer, *BouncerQuery]() + return withInterceptors[[]*Bouncer](ctx, bq, qr, bq.inters) } // AllX is like All, but panics if an error occurs. @@ -177,9 +177,12 @@ func (bq *BouncerQuery) AllX(ctx context.Context) []*Bouncer { } // IDs executes the query and returns a list of Bouncer IDs. -func (bq *BouncerQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := bq.Select(bouncer.FieldID).Scan(ctx, &ids); err != nil { +func (bq *BouncerQuery) IDs(ctx context.Context) (ids []int, err error) { + if bq.ctx.Unique == nil && bq.path != nil { + bq.Unique(true) + } + ctx = setContextOp(ctx, bq.ctx, "IDs") + if err = bq.Select(bouncer.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -196,10 +199,11 @@ func (bq *BouncerQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. 
func (bq *BouncerQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, bq.ctx, "Count") if err := bq.prepareQuery(ctx); err != nil { return 0, err } - return bq.sqlCount(ctx) + return withInterceptors[int](ctx, bq, querierCount[*BouncerQuery](), bq.inters) } // CountX is like Count, but panics if an error occurs. @@ -213,10 +217,15 @@ func (bq *BouncerQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (bq *BouncerQuery) Exist(ctx context.Context) (bool, error) { - if err := bq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, bq.ctx, "Exist") + switch _, err := bq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return bq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -236,14 +245,13 @@ func (bq *BouncerQuery) Clone() *BouncerQuery { } return &BouncerQuery{ config: bq.config, - limit: bq.limit, - offset: bq.offset, - order: append([]OrderFunc{}, bq.order...), + ctx: bq.ctx.Clone(), + order: append([]bouncer.OrderOption{}, bq.order...), + inters: append([]Interceptor{}, bq.inters...), predicates: append([]predicate.Bouncer{}, bq.predicates...), // clone intermediate query. - sql: bq.sql.Clone(), - path: bq.path, - unique: bq.unique, + sql: bq.sql.Clone(), + path: bq.path, } } @@ -262,16 +270,11 @@ func (bq *BouncerQuery) Clone() *BouncerQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (bq *BouncerQuery) GroupBy(field string, fields ...string) *BouncerGroupBy { - grbuild := &BouncerGroupBy{config: bq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := bq.prepareQuery(ctx); err != nil { - return nil, err - } - return bq.sqlQuery(ctx), nil - } + bq.ctx.Fields = append([]string{field}, fields...) + grbuild := &BouncerGroupBy{build: bq} + grbuild.flds = &bq.ctx.Fields grbuild.label = bouncer.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -288,15 +291,30 @@ func (bq *BouncerQuery) GroupBy(field string, fields ...string) *BouncerGroupBy // Select(bouncer.FieldCreatedAt). // Scan(ctx, &v) func (bq *BouncerQuery) Select(fields ...string) *BouncerSelect { - bq.fields = append(bq.fields, fields...) - selbuild := &BouncerSelect{BouncerQuery: bq} - selbuild.label = bouncer.Label - selbuild.flds, selbuild.scan = &bq.fields, selbuild.Scan - return selbuild + bq.ctx.Fields = append(bq.ctx.Fields, fields...) + sbuild := &BouncerSelect{BouncerQuery: bq} + sbuild.label = bouncer.Label + sbuild.flds, sbuild.scan = &bq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a BouncerSelect configured with the given aggregations. +func (bq *BouncerQuery) Aggregate(fns ...AggregateFunc) *BouncerSelect { + return bq.Select().Aggregate(fns...) 
} func (bq *BouncerQuery) prepareQuery(ctx context.Context) error { - for _, f := range bq.fields { + for _, inter := range bq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, bq); err != nil { + return err + } + } + } + for _, f := range bq.ctx.Fields { if !bouncer.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -338,41 +356,22 @@ func (bq *BouncerQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Boun func (bq *BouncerQuery) sqlCount(ctx context.Context) (int, error) { _spec := bq.querySpec() - _spec.Node.Columns = bq.fields - if len(bq.fields) > 0 { - _spec.Unique = bq.unique != nil && *bq.unique + _spec.Node.Columns = bq.ctx.Fields + if len(bq.ctx.Fields) > 0 { + _spec.Unique = bq.ctx.Unique != nil && *bq.ctx.Unique } return sqlgraph.CountNodes(ctx, bq.driver, _spec) } -func (bq *BouncerQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := bq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (bq *BouncerQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: bouncer.Table, - Columns: bouncer.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: bouncer.FieldID, - }, - }, - From: bq.sql, - Unique: true, - } - if unique := bq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(bouncer.Table, bouncer.Columns, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt)) + _spec.From = bq.sql + if unique := bq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if bq.path != nil { + _spec.Unique = true } - if fields := bq.fields; len(fields) > 0 { + if fields := bq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, bouncer.FieldID) for i := range fields { @@ -388,10 +387,10 @@ func (bq *BouncerQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := bq.limit; limit != nil { + if limit := bq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := bq.offset; offset != nil { + if offset := bq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := bq.order; len(ps) > 0 { @@ -407,7 +406,7 @@ func (bq *BouncerQuery) querySpec() *sqlgraph.QuerySpec { func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(bq.driver.Dialect()) t1 := builder.Table(bouncer.Table) - columns := bq.fields + columns := bq.ctx.Fields if len(columns) == 0 { columns = bouncer.Columns } @@ -416,7 +415,7 @@ func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = bq.sql selector.Select(selector.Columns(columns...)...) } - if bq.unique != nil && *bq.unique { + if bq.ctx.Unique != nil && *bq.ctx.Unique { selector.Distinct() } for _, p := range bq.predicates { @@ -425,12 +424,12 @@ func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range bq.order { p(selector) } - if offset := bq.offset; offset != nil { + if offset := bq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. 
selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := bq.limit; limit != nil { + if limit := bq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -438,13 +437,8 @@ func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector { // BouncerGroupBy is the group-by builder for Bouncer entities. type BouncerGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *BouncerQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -453,74 +447,77 @@ func (bgb *BouncerGroupBy) Aggregate(fns ...AggregateFunc) *BouncerGroupBy { return bgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (bgb *BouncerGroupBy) Scan(ctx context.Context, v any) error { - query, err := bgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, bgb.build.ctx, "GroupBy") + if err := bgb.build.prepareQuery(ctx); err != nil { return err } - bgb.sql = query - return bgb.sqlScan(ctx, v) + return scanWithInterceptors[*BouncerQuery, *BouncerGroupBy](ctx, bgb.build, bgb, bgb.build.inters, v) } -func (bgb *BouncerGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range bgb.fields { - if !bouncer.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (bgb *BouncerGroupBy) sqlScan(ctx context.Context, root *BouncerQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(bgb.fns)) + for _, fn := range bgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*bgb.flds)+len(bgb.fns)) + for _, f := range *bgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := bgb.sqlQuery() + selector.GroupBy(selector.Columns(*bgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := bgb.driver.Query(ctx, query, args, rows); err != nil { + if err := bgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (bgb *BouncerGroupBy) sqlQuery() *sql.Selector { - selector := bgb.sql.Select() - aggregation := make([]string, 0, len(bgb.fns)) - for _, fn := range bgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(bgb.fields)+len(bgb.fns)) - for _, f := range bgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(bgb.fields...)...) -} - // BouncerSelect is the builder for selecting fields of Bouncer entities. type BouncerSelect struct { *BouncerQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (bs *BouncerSelect) Aggregate(fns ...AggregateFunc) *BouncerSelect { + bs.fns = append(bs.fns, fns...) 
+ return bs } // Scan applies the selector query and scans the result into the given value. func (bs *BouncerSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, bs.ctx, "Select") if err := bs.prepareQuery(ctx); err != nil { return err } - bs.sql = bs.BouncerQuery.sqlQuery(ctx) - return bs.sqlScan(ctx, v) + return scanWithInterceptors[*BouncerQuery, *BouncerSelect](ctx, bs.BouncerQuery, bs, bs.inters, v) } -func (bs *BouncerSelect) sqlScan(ctx context.Context, v any) error { +func (bs *BouncerSelect) sqlScan(ctx context.Context, root *BouncerQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(bs.fns)) + for _, fn := range bs.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*bs.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := bs.sql.Query() + query, args := selector.Query() if err := bs.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/bouncer_update.go b/pkg/database/ent/bouncer_update.go index acf48dedeec..f7e71eb315e 100644 --- a/pkg/database/ent/bouncer_update.go +++ b/pkg/database/ent/bouncer_update.go @@ -185,35 +185,8 @@ func (bu *BouncerUpdate) Mutation() *BouncerMutation { // Save executes the query and returns the number of nodes affected by the update operation. func (bu *BouncerUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) bu.defaults() - if len(bu.hooks) == 0 { - affected, err = bu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*BouncerMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - bu.mutation = mutation - affected, err = bu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(bu.hooks) - 1; i >= 0; i-- { - if bu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = bu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, bu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, bu.sqlSave, bu.mutation, bu.hooks) } // SaveX is like Save, but panics if an error occurs. 
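BouncerSelect gains an Aggregate method, and both GroupBy and Select now scan through scanWithInterceptors. The grouping pattern from the generated GroupBy documentation still applies; a hedged sketch follows, where the helper name and the json-tagged scan target are assumptions that mirror the doc-comment example:

package example

import (
	"context"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer"
)

// countBouncersPerAuthType groups bouncers on auth_type and aggregates a count,
// following the pattern shown in the generated GroupBy documentation.
func countBouncersPerAuthType(ctx context.Context, client *ent.Client) (map[string]int, error) {
	var rows []struct {
		AuthType string `json:"auth_type"`
		Count    int    `json:"count"`
	}
	err := client.Bouncer.Query().
		GroupBy(bouncer.FieldAuthType).
		Aggregate(ent.Count()).
		Scan(ctx, &rows)
	if err != nil {
		return nil, err
	}
	out := make(map[string]int, len(rows))
	for _, r := range rows {
		out[r.AuthType] = r.Count
	}
	return out, nil
}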
@@ -251,16 +224,7 @@ func (bu *BouncerUpdate) defaults() { } func (bu *BouncerUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: bouncer.Table, - Columns: bouncer.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: bouncer.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(bouncer.Table, bouncer.Columns, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt)) if ps := bu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -269,117 +233,55 @@ func (bu *BouncerUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := bu.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldCreatedAt, - }) + _spec.SetField(bouncer.FieldCreatedAt, field.TypeTime, value) } if bu.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldCreatedAt, - }) + _spec.ClearField(bouncer.FieldCreatedAt, field.TypeTime) } if value, ok := bu.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUpdatedAt, - }) + _spec.SetField(bouncer.FieldUpdatedAt, field.TypeTime, value) } if bu.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldUpdatedAt, - }) + _spec.ClearField(bouncer.FieldUpdatedAt, field.TypeTime) } if value, ok := bu.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldName, - }) + _spec.SetField(bouncer.FieldName, field.TypeString, value) } if value, ok := bu.mutation.APIKey(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAPIKey, - }) + _spec.SetField(bouncer.FieldAPIKey, field.TypeString, value) } if value, ok := bu.mutation.Revoked(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: bouncer.FieldRevoked, - }) + _spec.SetField(bouncer.FieldRevoked, field.TypeBool, value) } if value, ok := bu.mutation.IPAddress(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldIPAddress, - }) + _spec.SetField(bouncer.FieldIPAddress, field.TypeString, value) } if bu.mutation.IPAddressCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldIPAddress, - }) + _spec.ClearField(bouncer.FieldIPAddress, field.TypeString) } if value, ok := bu.mutation.GetType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldType, - }) + _spec.SetField(bouncer.FieldType, field.TypeString, value) } if bu.mutation.TypeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldType, - }) + _spec.ClearField(bouncer.FieldType, field.TypeString) } if value, ok := bu.mutation.Version(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: 
bouncer.FieldVersion, - }) + _spec.SetField(bouncer.FieldVersion, field.TypeString, value) } if bu.mutation.VersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldVersion, - }) + _spec.ClearField(bouncer.FieldVersion, field.TypeString) } if value, ok := bu.mutation.Until(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUntil, - }) + _spec.SetField(bouncer.FieldUntil, field.TypeTime, value) } if bu.mutation.UntilCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldUntil, - }) + _spec.ClearField(bouncer.FieldUntil, field.TypeTime) } if value, ok := bu.mutation.LastPull(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldLastPull, - }) + _spec.SetField(bouncer.FieldLastPull, field.TypeTime, value) } if value, ok := bu.mutation.AuthType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAuthType, - }) + _spec.SetField(bouncer.FieldAuthType, field.TypeString, value) } if n, err = sqlgraph.UpdateNodes(ctx, bu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { @@ -389,6 +291,7 @@ func (bu *BouncerUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + bu.mutation.done = true return n, nil } @@ -555,6 +458,12 @@ func (buo *BouncerUpdateOne) Mutation() *BouncerMutation { return buo.mutation } +// Where appends a list predicates to the BouncerUpdate builder. +func (buo *BouncerUpdateOne) Where(ps ...predicate.Bouncer) *BouncerUpdateOne { + buo.mutation.Where(ps...) + return buo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (buo *BouncerUpdateOne) Select(field string, fields ...string) *BouncerUpdateOne { @@ -564,41 +473,8 @@ func (buo *BouncerUpdateOne) Select(field string, fields ...string) *BouncerUpda // Save executes the query and returns the updated Bouncer entity. func (buo *BouncerUpdateOne) Save(ctx context.Context) (*Bouncer, error) { - var ( - err error - node *Bouncer - ) buo.defaults() - if len(buo.hooks) == 0 { - node, err = buo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*BouncerMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - buo.mutation = mutation - node, err = buo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(buo.hooks) - 1; i >= 0; i-- { - if buo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = buo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, buo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Bouncer) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from BouncerMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, buo.sqlSave, buo.mutation, buo.hooks) } // SaveX is like Save, but panics if an error occurs. 
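BouncerUpdateOne also gains a Where method, so a single-row update can be guarded by extra predicates; if the predicates filter the row out, Save should surface a not-found error rather than touch the row. A hedged sketch, assuming the usual generated UpdateOneID constructor and SetVersion setter (neither appears in this hunk):

package example

import (
	"context"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer"
)

// bumpBouncerVersion updates one bouncer by ID, but only while it still
// authenticates with an API key, using the new BouncerUpdateOne.Where support.
func bumpBouncerVersion(ctx context.Context, client *ent.Client, id int, version string) (*ent.Bouncer, error) {
	return client.Bouncer.UpdateOneID(id).
		Where(bouncer.AuthTypeEQ("api-key")).
		SetVersion(version).
		Save(ctx)
}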
@@ -636,16 +512,7 @@ func (buo *BouncerUpdateOne) defaults() { } func (buo *BouncerUpdateOne) sqlSave(ctx context.Context) (_node *Bouncer, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: bouncer.Table, - Columns: bouncer.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: bouncer.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(bouncer.Table, bouncer.Columns, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt)) id, ok := buo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Bouncer.id" for update`)} @@ -671,117 +538,55 @@ func (buo *BouncerUpdateOne) sqlSave(ctx context.Context) (_node *Bouncer, err e } } if value, ok := buo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldCreatedAt, - }) + _spec.SetField(bouncer.FieldCreatedAt, field.TypeTime, value) } if buo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldCreatedAt, - }) + _spec.ClearField(bouncer.FieldCreatedAt, field.TypeTime) } if value, ok := buo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUpdatedAt, - }) + _spec.SetField(bouncer.FieldUpdatedAt, field.TypeTime, value) } if buo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldUpdatedAt, - }) + _spec.ClearField(bouncer.FieldUpdatedAt, field.TypeTime) } if value, ok := buo.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldName, - }) + _spec.SetField(bouncer.FieldName, field.TypeString, value) } if value, ok := buo.mutation.APIKey(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAPIKey, - }) + _spec.SetField(bouncer.FieldAPIKey, field.TypeString, value) } if value, ok := buo.mutation.Revoked(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: bouncer.FieldRevoked, - }) + _spec.SetField(bouncer.FieldRevoked, field.TypeBool, value) } if value, ok := buo.mutation.IPAddress(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldIPAddress, - }) + _spec.SetField(bouncer.FieldIPAddress, field.TypeString, value) } if buo.mutation.IPAddressCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldIPAddress, - }) + _spec.ClearField(bouncer.FieldIPAddress, field.TypeString) } if value, ok := buo.mutation.GetType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldType, - }) + _spec.SetField(bouncer.FieldType, field.TypeString, value) } if buo.mutation.TypeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldType, - }) + _spec.ClearField(bouncer.FieldType, field.TypeString) } if value, ok := buo.mutation.Version(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: 
field.TypeString, - Value: value, - Column: bouncer.FieldVersion, - }) + _spec.SetField(bouncer.FieldVersion, field.TypeString, value) } if buo.mutation.VersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldVersion, - }) + _spec.ClearField(bouncer.FieldVersion, field.TypeString) } if value, ok := buo.mutation.Until(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUntil, - }) + _spec.SetField(bouncer.FieldUntil, field.TypeTime, value) } if buo.mutation.UntilCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldUntil, - }) + _spec.ClearField(bouncer.FieldUntil, field.TypeTime) } if value, ok := buo.mutation.LastPull(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldLastPull, - }) + _spec.SetField(bouncer.FieldLastPull, field.TypeTime, value) } if value, ok := buo.mutation.AuthType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAuthType, - }) + _spec.SetField(bouncer.FieldAuthType, field.TypeString, value) } _node = &Bouncer{config: buo.config} _spec.Assign = _node.assignValues @@ -794,5 +599,6 @@ func (buo *BouncerUpdateOne) sqlSave(ctx context.Context) (_node *Bouncer, err e } return nil, err } + buo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/client.go b/pkg/database/ent/client.go index 815b1df6d16..2761ff088b5 100644 --- a/pkg/database/ent/client.go +++ b/pkg/database/ent/client.go @@ -7,9 +7,14 @@ import ( "errors" "fmt" "log" + "reflect" "github.com/crowdsecurity/crowdsec/pkg/database/ent/migrate" + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" @@ -17,10 +22,6 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" - - "entgo.io/ent/dialect" - "entgo.io/ent/dialect/sql" - "entgo.io/ent/dialect/sql/sqlgraph" ) // Client is the client that holds all ent builders. @@ -46,7 +47,7 @@ type Client struct { // NewClient creates a new client configured with the given options. func NewClient(opts ...Option) *Client { - cfg := config{log: log.Println, hooks: &hooks{}} + cfg := config{log: log.Println, hooks: &hooks{}, inters: &inters{}} cfg.options(opts...) client := &Client{config: cfg} client.init() @@ -64,6 +65,55 @@ func (c *Client) init() { c.Meta = NewMetaClient(c.config) } +type ( + // config is the configuration for the client and its builder. + config struct { + // driver used for executing database requests. + driver dialect.Driver + // debug enable a debug logging. + debug bool + // log used for logging on debug mode. + log func(...any) + // hooks to execute on mutations. + hooks *hooks + // interceptors to execute on queries. + inters *inters + } + // Option function to configure the client. + Option func(*config) +) + +// options applies the options on the config object. 
+func (c *config) options(opts ...Option) { + for _, opt := range opts { + opt(c) + } + if c.debug { + c.driver = dialect.Debug(c.driver, c.log) + } +} + +// Debug enables debug logging on the ent.Driver. +func Debug() Option { + return func(c *config) { + c.debug = true + } +} + +// Log sets the logging function for debug mode. +func Log(fn func(...any)) Option { + return func(c *config) { + c.log = fn + } +} + +// Driver configures the client driver. +func Driver(driver dialect.Driver) Option { + return func(c *config) { + c.driver = driver + } +} + // Open opens a database/sql.DB specified by the driver name and // the data source name, and returns a new client attached to it. // Optional parameters can be added for configuring the client. @@ -80,11 +130,14 @@ func Open(driverName, dataSourceName string, options ...Option) (*Client, error) } } +// ErrTxStarted is returned when trying to start a new transaction from a transactional client. +var ErrTxStarted = errors.New("ent: cannot start a transaction within a transaction") + // Tx returns a new transactional client. The provided context // is used until the transaction is committed or rolled back. func (c *Client) Tx(ctx context.Context) (*Tx, error) { if _, ok := c.driver.(*txDriver); ok { - return nil, errors.New("ent: cannot start a transaction within a transaction") + return nil, ErrTxStarted } tx, err := newTx(ctx, c.driver) if err != nil { @@ -156,13 +209,43 @@ func (c *Client) Close() error { // Use adds the mutation hooks to all the entity clients. // In order to add hooks to a specific client, call: `client.Node.Use(...)`. func (c *Client) Use(hooks ...Hook) { - c.Alert.Use(hooks...) - c.Bouncer.Use(hooks...) - c.ConfigItem.Use(hooks...) - c.Decision.Use(hooks...) - c.Event.Use(hooks...) - c.Machine.Use(hooks...) - c.Meta.Use(hooks...) + for _, n := range []interface{ Use(...Hook) }{ + c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Machine, c.Meta, + } { + n.Use(hooks...) + } +} + +// Intercept adds the query interceptors to all the entity clients. +// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`. +func (c *Client) Intercept(interceptors ...Interceptor) { + for _, n := range []interface{ Intercept(...Interceptor) }{ + c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Machine, c.Meta, + } { + n.Intercept(interceptors...) + } +} + +// Mutate implements the ent.Mutator interface. +func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { + switch m := m.(type) { + case *AlertMutation: + return c.Alert.mutate(ctx, m) + case *BouncerMutation: + return c.Bouncer.mutate(ctx, m) + case *ConfigItemMutation: + return c.ConfigItem.mutate(ctx, m) + case *DecisionMutation: + return c.Decision.mutate(ctx, m) + case *EventMutation: + return c.Event.mutate(ctx, m) + case *MachineMutation: + return c.Machine.mutate(ctx, m) + case *MetaMutation: + return c.Meta.mutate(ctx, m) + default: + return nil, fmt.Errorf("ent: unknown mutation type %T", m) + } } // AlertClient is a client for the Alert schema. @@ -181,6 +264,12 @@ func (c *AlertClient) Use(hooks ...Hook) { c.hooks.Alert = append(c.hooks.Alert, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `alert.Intercept(f(g(h())))`. +func (c *AlertClient) Intercept(interceptors ...Interceptor) { + c.inters.Alert = append(c.inters.Alert, interceptors...) +} + // Create returns a builder for creating a Alert entity. 
func (c *AlertClient) Create() *AlertCreate { mutation := newAlertMutation(c.config, OpCreate) @@ -192,6 +281,21 @@ func (c *AlertClient) CreateBulk(builders ...*AlertCreate) *AlertCreateBulk { return &AlertCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AlertClient) MapCreateBulk(slice any, setFunc func(*AlertCreate, int)) *AlertCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AlertCreateBulk{err: fmt.Errorf("calling to AlertClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AlertCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AlertCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Alert. func (c *AlertClient) Update() *AlertUpdate { mutation := newAlertMutation(c.config, OpUpdate) @@ -221,7 +325,7 @@ func (c *AlertClient) DeleteOne(a *Alert) *AlertDeleteOne { return c.DeleteOneID(a.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *AlertClient) DeleteOneID(id int) *AlertDeleteOne { builder := c.Delete().Where(alert.ID(id)) builder.mutation.id = &id @@ -233,6 +337,8 @@ func (c *AlertClient) DeleteOneID(id int) *AlertDeleteOne { func (c *AlertClient) Query() *AlertQuery { return &AlertQuery{ config: c.config, + ctx: &QueryContext{Type: TypeAlert}, + inters: c.Interceptors(), } } @@ -252,8 +358,8 @@ func (c *AlertClient) GetX(ctx context.Context, id int) *Alert { // QueryOwner queries the owner edge of a Alert. func (c *AlertClient) QueryOwner(a *Alert) *MachineQuery { - query := &MachineQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MachineClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := a.ID step := sqlgraph.NewStep( sqlgraph.From(alert.Table, alert.FieldID, id), @@ -268,8 +374,8 @@ func (c *AlertClient) QueryOwner(a *Alert) *MachineQuery { // QueryDecisions queries the decisions edge of a Alert. func (c *AlertClient) QueryDecisions(a *Alert) *DecisionQuery { - query := &DecisionQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&DecisionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := a.ID step := sqlgraph.NewStep( sqlgraph.From(alert.Table, alert.FieldID, id), @@ -284,8 +390,8 @@ func (c *AlertClient) QueryDecisions(a *Alert) *DecisionQuery { // QueryEvents queries the events edge of a Alert. func (c *AlertClient) QueryEvents(a *Alert) *EventQuery { - query := &EventQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&EventClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := a.ID step := sqlgraph.NewStep( sqlgraph.From(alert.Table, alert.FieldID, id), @@ -300,8 +406,8 @@ func (c *AlertClient) QueryEvents(a *Alert) *EventQuery { // QueryMetas queries the metas edge of a Alert. 
func (c *AlertClient) QueryMetas(a *Alert) *MetaQuery { - query := &MetaQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MetaClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := a.ID step := sqlgraph.NewStep( sqlgraph.From(alert.Table, alert.FieldID, id), @@ -319,6 +425,26 @@ func (c *AlertClient) Hooks() []Hook { return c.hooks.Alert } +// Interceptors returns the client interceptors. +func (c *AlertClient) Interceptors() []Interceptor { + return c.inters.Alert +} + +func (c *AlertClient) mutate(ctx context.Context, m *AlertMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AlertCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AlertUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AlertUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AlertDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Alert mutation op: %q", m.Op()) + } +} + // BouncerClient is a client for the Bouncer schema. type BouncerClient struct { config @@ -335,6 +461,12 @@ func (c *BouncerClient) Use(hooks ...Hook) { c.hooks.Bouncer = append(c.hooks.Bouncer, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `bouncer.Intercept(f(g(h())))`. +func (c *BouncerClient) Intercept(interceptors ...Interceptor) { + c.inters.Bouncer = append(c.inters.Bouncer, interceptors...) +} + // Create returns a builder for creating a Bouncer entity. func (c *BouncerClient) Create() *BouncerCreate { mutation := newBouncerMutation(c.config, OpCreate) @@ -346,6 +478,21 @@ func (c *BouncerClient) CreateBulk(builders ...*BouncerCreate) *BouncerCreateBul return &BouncerCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *BouncerClient) MapCreateBulk(slice any, setFunc func(*BouncerCreate, int)) *BouncerCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &BouncerCreateBulk{err: fmt.Errorf("calling to BouncerClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*BouncerCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &BouncerCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Bouncer. func (c *BouncerClient) Update() *BouncerUpdate { mutation := newBouncerMutation(c.config, OpUpdate) @@ -375,7 +522,7 @@ func (c *BouncerClient) DeleteOne(b *Bouncer) *BouncerDeleteOne { return c.DeleteOneID(b.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. 
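client.go now threads an inters struct through every typed client and exposes Intercept alongside Use, plus a per-client mutate dispatcher. Interceptors are registered much like hooks; a hedged sketch follows, assuming the standard ent v0.12 aliases (InterceptFunc, QuerierFunc, Querier, Query) are generated in this package next to Interceptor:

package example

import (
	"context"
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)

// registerQueryLogging logs every query that goes through the client, using
// the interceptor support added in this regeneration.
func registerQueryLogging(client *ent.Client) {
	client.Intercept(ent.InterceptFunc(func(next ent.Querier) ent.Querier {
		return ent.QuerierFunc(func(ctx context.Context, q ent.Query) (ent.Value, error) {
			log.Printf("ent query: %T", q)
			return next.Query(ctx, q)
		})
	}))
}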
func (c *BouncerClient) DeleteOneID(id int) *BouncerDeleteOne { builder := c.Delete().Where(bouncer.ID(id)) builder.mutation.id = &id @@ -387,6 +534,8 @@ func (c *BouncerClient) DeleteOneID(id int) *BouncerDeleteOne { func (c *BouncerClient) Query() *BouncerQuery { return &BouncerQuery{ config: c.config, + ctx: &QueryContext{Type: TypeBouncer}, + inters: c.Interceptors(), } } @@ -409,6 +558,26 @@ func (c *BouncerClient) Hooks() []Hook { return c.hooks.Bouncer } +// Interceptors returns the client interceptors. +func (c *BouncerClient) Interceptors() []Interceptor { + return c.inters.Bouncer +} + +func (c *BouncerClient) mutate(ctx context.Context, m *BouncerMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&BouncerCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&BouncerUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&BouncerUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&BouncerDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Bouncer mutation op: %q", m.Op()) + } +} + // ConfigItemClient is a client for the ConfigItem schema. type ConfigItemClient struct { config @@ -425,6 +594,12 @@ func (c *ConfigItemClient) Use(hooks ...Hook) { c.hooks.ConfigItem = append(c.hooks.ConfigItem, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `configitem.Intercept(f(g(h())))`. +func (c *ConfigItemClient) Intercept(interceptors ...Interceptor) { + c.inters.ConfigItem = append(c.inters.ConfigItem, interceptors...) +} + // Create returns a builder for creating a ConfigItem entity. func (c *ConfigItemClient) Create() *ConfigItemCreate { mutation := newConfigItemMutation(c.config, OpCreate) @@ -436,6 +611,21 @@ func (c *ConfigItemClient) CreateBulk(builders ...*ConfigItemCreate) *ConfigItem return &ConfigItemCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *ConfigItemClient) MapCreateBulk(slice any, setFunc func(*ConfigItemCreate, int)) *ConfigItemCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &ConfigItemCreateBulk{err: fmt.Errorf("calling to ConfigItemClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*ConfigItemCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &ConfigItemCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for ConfigItem. func (c *ConfigItemClient) Update() *ConfigItemUpdate { mutation := newConfigItemMutation(c.config, OpUpdate) @@ -465,7 +655,7 @@ func (c *ConfigItemClient) DeleteOne(ci *ConfigItem) *ConfigItemDeleteOne { return c.DeleteOneID(ci.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. 
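Editor's note: Intercept and Interceptors are the read-path counterpart of Use and Hooks: every query built by the client now runs through the registered interceptors. A sketch of wiring one up on the bouncer client; the timing body is illustrative, and the InterceptFunc/QuerierFunc adapters are assumed to come from the entgo.io/ent base package, as in stock v0.12 generated code.

package examples

import (
	"context"
	"log"
	"time"

	entgo "entgo.io/ent"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)

// logBouncerQueries logs how long each BouncerQuery takes. Only reads are affected;
// mutations still go through hooks.
func logBouncerQueries(client *ent.Client) {
	client.Bouncer.Intercept(entgo.InterceptFunc(func(next entgo.Querier) entgo.Querier {
		return entgo.QuerierFunc(func(ctx context.Context, q entgo.Query) (entgo.Value, error) {
			start := time.Now()
			v, err := next.Query(ctx, q)
			log.Printf("bouncer query finished in %s", time.Since(start))
			return v, err
		})
	}))
}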
func (c *ConfigItemClient) DeleteOneID(id int) *ConfigItemDeleteOne { builder := c.Delete().Where(configitem.ID(id)) builder.mutation.id = &id @@ -477,6 +667,8 @@ func (c *ConfigItemClient) DeleteOneID(id int) *ConfigItemDeleteOne { func (c *ConfigItemClient) Query() *ConfigItemQuery { return &ConfigItemQuery{ config: c.config, + ctx: &QueryContext{Type: TypeConfigItem}, + inters: c.Interceptors(), } } @@ -499,6 +691,26 @@ func (c *ConfigItemClient) Hooks() []Hook { return c.hooks.ConfigItem } +// Interceptors returns the client interceptors. +func (c *ConfigItemClient) Interceptors() []Interceptor { + return c.inters.ConfigItem +} + +func (c *ConfigItemClient) mutate(ctx context.Context, m *ConfigItemMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&ConfigItemCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&ConfigItemUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&ConfigItemUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&ConfigItemDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown ConfigItem mutation op: %q", m.Op()) + } +} + // DecisionClient is a client for the Decision schema. type DecisionClient struct { config @@ -515,6 +727,12 @@ func (c *DecisionClient) Use(hooks ...Hook) { c.hooks.Decision = append(c.hooks.Decision, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `decision.Intercept(f(g(h())))`. +func (c *DecisionClient) Intercept(interceptors ...Interceptor) { + c.inters.Decision = append(c.inters.Decision, interceptors...) +} + // Create returns a builder for creating a Decision entity. func (c *DecisionClient) Create() *DecisionCreate { mutation := newDecisionMutation(c.config, OpCreate) @@ -526,6 +744,21 @@ func (c *DecisionClient) CreateBulk(builders ...*DecisionCreate) *DecisionCreate return &DecisionCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *DecisionClient) MapCreateBulk(slice any, setFunc func(*DecisionCreate, int)) *DecisionCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &DecisionCreateBulk{err: fmt.Errorf("calling to DecisionClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*DecisionCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &DecisionCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Decision. func (c *DecisionClient) Update() *DecisionUpdate { mutation := newDecisionMutation(c.config, OpUpdate) @@ -555,7 +788,7 @@ func (c *DecisionClient) DeleteOne(d *Decision) *DecisionDeleteOne { return c.DeleteOneID(d.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. 
func (c *DecisionClient) DeleteOneID(id int) *DecisionDeleteOne { builder := c.Delete().Where(decision.ID(id)) builder.mutation.id = &id @@ -567,6 +800,8 @@ func (c *DecisionClient) DeleteOneID(id int) *DecisionDeleteOne { func (c *DecisionClient) Query() *DecisionQuery { return &DecisionQuery{ config: c.config, + ctx: &QueryContext{Type: TypeDecision}, + inters: c.Interceptors(), } } @@ -586,8 +821,8 @@ func (c *DecisionClient) GetX(ctx context.Context, id int) *Decision { // QueryOwner queries the owner edge of a Decision. func (c *DecisionClient) QueryOwner(d *Decision) *AlertQuery { - query := &AlertQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&AlertClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := d.ID step := sqlgraph.NewStep( sqlgraph.From(decision.Table, decision.FieldID, id), @@ -605,6 +840,26 @@ func (c *DecisionClient) Hooks() []Hook { return c.hooks.Decision } +// Interceptors returns the client interceptors. +func (c *DecisionClient) Interceptors() []Interceptor { + return c.inters.Decision +} + +func (c *DecisionClient) mutate(ctx context.Context, m *DecisionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&DecisionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&DecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&DecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&DecisionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Decision mutation op: %q", m.Op()) + } +} + // EventClient is a client for the Event schema. type EventClient struct { config @@ -621,6 +876,12 @@ func (c *EventClient) Use(hooks ...Hook) { c.hooks.Event = append(c.hooks.Event, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `event.Intercept(f(g(h())))`. +func (c *EventClient) Intercept(interceptors ...Interceptor) { + c.inters.Event = append(c.inters.Event, interceptors...) +} + // Create returns a builder for creating a Event entity. func (c *EventClient) Create() *EventCreate { mutation := newEventMutation(c.config, OpCreate) @@ -632,6 +893,21 @@ func (c *EventClient) CreateBulk(builders ...*EventCreate) *EventCreateBulk { return &EventCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *EventClient) MapCreateBulk(slice any, setFunc func(*EventCreate, int)) *EventCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &EventCreateBulk{err: fmt.Errorf("calling to EventClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*EventCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &EventCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Event. func (c *EventClient) Update() *EventUpdate { mutation := newEventMutation(c.config, OpUpdate) @@ -661,7 +937,7 @@ func (c *EventClient) DeleteOne(e *Event) *EventDeleteOne { return c.DeleteOneID(e.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. 
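Editor's note: edge traversals such as QueryOwner are now built through the target client's Query() instead of a bare query struct, so they pick up the QueryContext and any interceptors registered on that type. The call site itself is unchanged, as this sketch shows; the helper name and package are assumptions.

package examples

import (
	"context"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)

// ownerOfDecision resolves the owner edge of a Decision. Because the edge query is
// now created via the AlertClient, Alert interceptors apply to it as well.
func ownerOfDecision(ctx context.Context, client *ent.Client, d *ent.Decision) (*ent.Alert, error) {
	return client.Decision.QueryOwner(d).Only(ctx)
}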
+// DeleteOneID returns a builder for deleting the given entity by its id. func (c *EventClient) DeleteOneID(id int) *EventDeleteOne { builder := c.Delete().Where(event.ID(id)) builder.mutation.id = &id @@ -673,6 +949,8 @@ func (c *EventClient) DeleteOneID(id int) *EventDeleteOne { func (c *EventClient) Query() *EventQuery { return &EventQuery{ config: c.config, + ctx: &QueryContext{Type: TypeEvent}, + inters: c.Interceptors(), } } @@ -692,8 +970,8 @@ func (c *EventClient) GetX(ctx context.Context, id int) *Event { // QueryOwner queries the owner edge of a Event. func (c *EventClient) QueryOwner(e *Event) *AlertQuery { - query := &AlertQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&AlertClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := e.ID step := sqlgraph.NewStep( sqlgraph.From(event.Table, event.FieldID, id), @@ -711,6 +989,26 @@ func (c *EventClient) Hooks() []Hook { return c.hooks.Event } +// Interceptors returns the client interceptors. +func (c *EventClient) Interceptors() []Interceptor { + return c.inters.Event +} + +func (c *EventClient) mutate(ctx context.Context, m *EventMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&EventCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&EventUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&EventUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&EventDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Event mutation op: %q", m.Op()) + } +} + // MachineClient is a client for the Machine schema. type MachineClient struct { config @@ -727,6 +1025,12 @@ func (c *MachineClient) Use(hooks ...Hook) { c.hooks.Machine = append(c.hooks.Machine, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `machine.Intercept(f(g(h())))`. +func (c *MachineClient) Intercept(interceptors ...Interceptor) { + c.inters.Machine = append(c.inters.Machine, interceptors...) +} + // Create returns a builder for creating a Machine entity. func (c *MachineClient) Create() *MachineCreate { mutation := newMachineMutation(c.config, OpCreate) @@ -738,6 +1042,21 @@ func (c *MachineClient) CreateBulk(builders ...*MachineCreate) *MachineCreateBul return &MachineCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *MachineClient) MapCreateBulk(slice any, setFunc func(*MachineCreate, int)) *MachineCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &MachineCreateBulk{err: fmt.Errorf("calling to MachineClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*MachineCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &MachineCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Machine. 
func (c *MachineClient) Update() *MachineUpdate { mutation := newMachineMutation(c.config, OpUpdate) @@ -767,7 +1086,7 @@ func (c *MachineClient) DeleteOne(m *Machine) *MachineDeleteOne { return c.DeleteOneID(m.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MachineClient) DeleteOneID(id int) *MachineDeleteOne { builder := c.Delete().Where(machine.ID(id)) builder.mutation.id = &id @@ -779,6 +1098,8 @@ func (c *MachineClient) DeleteOneID(id int) *MachineDeleteOne { func (c *MachineClient) Query() *MachineQuery { return &MachineQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMachine}, + inters: c.Interceptors(), } } @@ -798,8 +1119,8 @@ func (c *MachineClient) GetX(ctx context.Context, id int) *Machine { // QueryAlerts queries the alerts edge of a Machine. func (c *MachineClient) QueryAlerts(m *Machine) *AlertQuery { - query := &AlertQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&AlertClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(machine.Table, machine.FieldID, id), @@ -817,6 +1138,26 @@ func (c *MachineClient) Hooks() []Hook { return c.hooks.Machine } +// Interceptors returns the client interceptors. +func (c *MachineClient) Interceptors() []Interceptor { + return c.inters.Machine +} + +func (c *MachineClient) mutate(ctx context.Context, m *MachineMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MachineCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MachineUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MachineUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MachineDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Machine mutation op: %q", m.Op()) + } +} + // MetaClient is a client for the Meta schema. type MetaClient struct { config @@ -833,6 +1174,12 @@ func (c *MetaClient) Use(hooks ...Hook) { c.hooks.Meta = append(c.hooks.Meta, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `meta.Intercept(f(g(h())))`. +func (c *MetaClient) Intercept(interceptors ...Interceptor) { + c.inters.Meta = append(c.inters.Meta, interceptors...) +} + // Create returns a builder for creating a Meta entity. func (c *MetaClient) Create() *MetaCreate { mutation := newMetaMutation(c.config, OpCreate) @@ -844,6 +1191,21 @@ func (c *MetaClient) CreateBulk(builders ...*MetaCreate) *MetaCreateBulk { return &MetaCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. 
+func (c *MetaClient) MapCreateBulk(slice any, setFunc func(*MetaCreate, int)) *MetaCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &MetaCreateBulk{err: fmt.Errorf("calling to MetaClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*MetaCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &MetaCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Meta. func (c *MetaClient) Update() *MetaUpdate { mutation := newMetaMutation(c.config, OpUpdate) @@ -873,7 +1235,7 @@ func (c *MetaClient) DeleteOne(m *Meta) *MetaDeleteOne { return c.DeleteOneID(m.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MetaClient) DeleteOneID(id int) *MetaDeleteOne { builder := c.Delete().Where(meta.ID(id)) builder.mutation.id = &id @@ -885,6 +1247,8 @@ func (c *MetaClient) DeleteOneID(id int) *MetaDeleteOne { func (c *MetaClient) Query() *MetaQuery { return &MetaQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMeta}, + inters: c.Interceptors(), } } @@ -904,8 +1268,8 @@ func (c *MetaClient) GetX(ctx context.Context, id int) *Meta { // QueryOwner queries the owner edge of a Meta. func (c *MetaClient) QueryOwner(m *Meta) *AlertQuery { - query := &AlertQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&AlertClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(meta.Table, meta.FieldID, id), @@ -922,3 +1286,33 @@ func (c *MetaClient) QueryOwner(m *Meta) *AlertQuery { func (c *MetaClient) Hooks() []Hook { return c.hooks.Meta } + +// Interceptors returns the client interceptors. +func (c *MetaClient) Interceptors() []Interceptor { + return c.inters.Meta +} + +func (c *MetaClient) mutate(ctx context.Context, m *MetaMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MetaCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MetaUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MetaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MetaDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Meta mutation op: %q", m.Op()) + } +} + +// hooks and interceptors per client, for fast access. +type ( + hooks struct { + Alert, Bouncer, ConfigItem, Decision, Event, Machine, Meta []ent.Hook + } + inters struct { + Alert, Bouncer, ConfigItem, Decision, Event, Machine, Meta []ent.Interceptor + } +) diff --git a/pkg/database/ent/config.go b/pkg/database/ent/config.go deleted file mode 100644 index 1a152809a32..00000000000 --- a/pkg/database/ent/config.go +++ /dev/null @@ -1,65 +0,0 @@ -// Code generated by ent, DO NOT EDIT. - -package ent - -import ( - "entgo.io/ent" - "entgo.io/ent/dialect" -) - -// Option function to configure the client. -type Option func(*config) - -// Config is the configuration for the client and its builder. -type config struct { - // driver used for executing database requests. - driver dialect.Driver - // debug enable a debug logging. - debug bool - // log used for logging on debug mode. 
- log func(...any) - // hooks to execute on mutations. - hooks *hooks -} - -// hooks per client, for fast access. -type hooks struct { - Alert []ent.Hook - Bouncer []ent.Hook - ConfigItem []ent.Hook - Decision []ent.Hook - Event []ent.Hook - Machine []ent.Hook - Meta []ent.Hook -} - -// Options applies the options on the config object. -func (c *config) options(opts ...Option) { - for _, opt := range opts { - opt(c) - } - if c.debug { - c.driver = dialect.Debug(c.driver, c.log) - } -} - -// Debug enables debug logging on the ent.Driver. -func Debug() Option { - return func(c *config) { - c.debug = true - } -} - -// Log sets the logging function for debug mode. -func Log(fn func(...any)) Option { - return func(c *config) { - c.log = fn - } -} - -// Driver configures the client driver. -func Driver(driver dialect.Driver) Option { - return func(c *config) { - c.driver = driver - } -} diff --git a/pkg/database/ent/configitem.go b/pkg/database/ent/configitem.go index 615780dbacc..467e54386f6 100644 --- a/pkg/database/ent/configitem.go +++ b/pkg/database/ent/configitem.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" ) @@ -23,7 +24,8 @@ type ConfigItem struct { // Name holds the value of the "name" field. Name string `json:"name"` // Value holds the value of the "value" field. - Value string `json:"value"` + Value string `json:"value"` + selectValues sql.SelectValues } // scanValues returns the types for scanning values from sql.Rows. @@ -38,7 +40,7 @@ func (*ConfigItem) scanValues(columns []string) ([]any, error) { case configitem.FieldCreatedAt, configitem.FieldUpdatedAt: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type ConfigItem", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -84,16 +86,24 @@ func (ci *ConfigItem) assignValues(columns []string, values []any) error { } else if value.Valid { ci.Value = value.String } + default: + ci.selectValues.Set(columns[i], values[i]) } } return nil } +// GetValue returns the ent.Value that was dynamically selected and assigned to the ConfigItem. +// This includes values selected through modifiers, order, etc. +func (ci *ConfigItem) GetValue(name string) (ent.Value, error) { + return ci.selectValues.Get(name) +} + // Update returns a builder for updating this ConfigItem. // Note that you need to call ConfigItem.Unwrap() before calling this method if this ConfigItem // was returned from a transaction, and the transaction was committed or rolled back. func (ci *ConfigItem) Update() *ConfigItemUpdateOne { - return (&ConfigItemClient{config: ci.config}).UpdateOne(ci) + return NewConfigItemClient(ci.config).UpdateOne(ci) } // Unwrap unwraps the ConfigItem entity that was returned from a transaction after it was closed, @@ -133,9 +143,3 @@ func (ci *ConfigItem) String() string { // ConfigItems is a parsable slice of ConfigItem. type ConfigItems []*ConfigItem - -func (ci ConfigItems) config(cfg config) { - for _i := range ci { - ci[_i].config = cfg - } -} diff --git a/pkg/database/ent/configitem/configitem.go b/pkg/database/ent/configitem/configitem.go index 80e93e4cc7e..a6ff6c32d57 100644 --- a/pkg/database/ent/configitem/configitem.go +++ b/pkg/database/ent/configitem/configitem.go @@ -4,6 +4,8 @@ package configitem import ( "time" + + "entgo.io/ent/dialect/sql" ) const ( @@ -52,3 +54,31 @@ var ( // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. 
UpdateDefaultUpdatedAt func() time.Time ) + +// OrderOption defines the ordering options for the ConfigItem queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} diff --git a/pkg/database/ent/configitem/where.go b/pkg/database/ent/configitem/where.go index 6d06938a855..767f0b420f1 100644 --- a/pkg/database/ent/configitem/where.go +++ b/pkg/database/ent/configitem/where.go @@ -11,485 +11,310 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.ConfigItem(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.ConfigItem(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. 
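Editor's note: the new OrderOption helpers (ByID, ByCreatedAt, ByUpdatedAt, ByName, ByValue) give typed ordering terms for ConfigItem queries. A sketch of sorting with them; sql.OrderDesc is assumed to be the descending option from entgo.io/ent/dialect/sql.

package examples

import (
	"context"

	"entgo.io/ent/dialect/sql"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
)

// newestConfigItems lists configuration items, most recently created first,
// using the typed OrderOption helpers introduced above.
func newestConfigItems(ctx context.Context, client *ent.Client) ([]*ent.ConfigItem, error) {
	return client.ConfigItem.Query().
		Order(configitem.ByCreatedAt(sql.OrderDesc())).
		All(ctx)
}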
func IDLTE(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldUpdatedAt, v)) } // Name applies equality check predicate on the "name" field. It's identical to NameEQ. func Name(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldName, v)) } // Value applies equality check predicate on the "value" field. It's identical to ValueEQ. func Value(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldValue, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.ConfigItem(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.ConfigItem(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. 
func CreatedAtLT(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldLTE(FieldCreatedAt, v)) } // CreatedAtIsNil applies the IsNil predicate on the "created_at" field. func CreatedAtIsNil() predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) + return predicate.ConfigItem(sql.FieldIsNull(FieldCreatedAt)) } // CreatedAtNotNil applies the NotNil predicate on the "created_at" field. func CreatedAtNotNil() predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.ConfigItem(sql.FieldNotNull(FieldCreatedAt)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.ConfigItem(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.ConfigItem(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. 
func UpdatedAtLTE(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldLTE(FieldUpdatedAt, v)) } // UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. func UpdatedAtIsNil() predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) + return predicate.ConfigItem(sql.FieldIsNull(FieldUpdatedAt)) } // UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. func UpdatedAtNotNil() predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.ConfigItem(sql.FieldNotNull(FieldUpdatedAt)) } // NameEQ applies the EQ predicate on the "name" field. func NameEQ(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldName, v)) } // NameNEQ applies the NEQ predicate on the "name" field. func NameNEQ(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldNEQ(FieldName, v)) } // NameIn applies the In predicate on the "name" field. func NameIn(vs ...string) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldName), v...)) - }) + return predicate.ConfigItem(sql.FieldIn(FieldName, vs...)) } // NameNotIn applies the NotIn predicate on the "name" field. func NameNotIn(vs ...string) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldName), v...)) - }) + return predicate.ConfigItem(sql.FieldNotIn(FieldName, vs...)) } // NameGT applies the GT predicate on the "name" field. func NameGT(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldGT(FieldName, v)) } // NameGTE applies the GTE predicate on the "name" field. func NameGTE(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldGTE(FieldName, v)) } // NameLT applies the LT predicate on the "name" field. func NameLT(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldLT(FieldName, v)) } // NameLTE applies the LTE predicate on the "name" field. func NameLTE(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldLTE(FieldName, v)) } // NameContains applies the Contains predicate on the "name" field. func NameContains(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldContains(FieldName, v)) } // NameHasPrefix applies the HasPrefix predicate on the "name" field. 
func NameHasPrefix(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldHasPrefix(FieldName, v)) } // NameHasSuffix applies the HasSuffix predicate on the "name" field. func NameHasSuffix(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldHasSuffix(FieldName, v)) } // NameEqualFold applies the EqualFold predicate on the "name" field. func NameEqualFold(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldEqualFold(FieldName, v)) } // NameContainsFold applies the ContainsFold predicate on the "name" field. func NameContainsFold(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldContainsFold(FieldName, v)) } // ValueEQ applies the EQ predicate on the "value" field. func ValueEQ(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "value" field. func ValueNEQ(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "value" field. func ValueIn(vs ...string) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.ConfigItem(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "value" field. func ValueNotIn(vs ...string) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.ConfigItem(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "value" field. func ValueGT(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "value" field. func ValueGTE(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "value" field. func ValueLT(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "value" field. func ValueLTE(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "value" field. 
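Editor's note: the where.go rewrite replaces the hand-rolled selector closures with the sql.Field* helpers; callers see no behavioural change. A short sketch of the unchanged call site, with the lookup-by-name flow as an assumed use case:

package examples

import (
	"context"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
)

// configValue fetches a single ConfigItem by name. NameEQ now expands to
// sql.FieldEQ under the hood, but the predicate is used exactly as before.
func configValue(ctx context.Context, client *ent.Client, name string) (string, error) {
	ci, err := client.ConfigItem.Query().
		Where(configitem.NameEQ(name)).
		Only(ctx)
	if err != nil {
		return "", err
	}
	return ci.Value, nil
}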
func ValueContains(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "value" field. func ValueHasPrefix(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "value" field. func ValueHasSuffix(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "value" field. func ValueEqualFold(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "value" field. func ValueContainsFold(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldContainsFold(FieldValue, v)) } // And groups predicates with the AND operator between them. func And(predicates ...predicate.ConfigItem) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.ConfigItem(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.ConfigItem) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.ConfigItem(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.ConfigItem) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.ConfigItem(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/configitem_create.go b/pkg/database/ent/configitem_create.go index 736e6a50514..19e73dea41c 100644 --- a/pkg/database/ent/configitem_create.go +++ b/pkg/database/ent/configitem_create.go @@ -67,50 +67,8 @@ func (cic *ConfigItemCreate) Mutation() *ConfigItemMutation { // Save creates the ConfigItem in the database. 
func (cic *ConfigItemCreate) Save(ctx context.Context) (*ConfigItem, error) { - var ( - err error - node *ConfigItem - ) cic.defaults() - if len(cic.hooks) == 0 { - if err = cic.check(); err != nil { - return nil, err - } - node, err = cic.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*ConfigItemMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = cic.check(); err != nil { - return nil, err - } - cic.mutation = mutation - if node, err = cic.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(cic.hooks) - 1; i >= 0; i-- { - if cic.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = cic.hooks[i](mut) - } - v, err := mut.Mutate(ctx, cic.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*ConfigItem) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from ConfigItemMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, cic.sqlSave, cic.mutation, cic.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -159,6 +117,9 @@ func (cic *ConfigItemCreate) check() error { } func (cic *ConfigItemCreate) sqlSave(ctx context.Context) (*ConfigItem, error) { + if err := cic.check(); err != nil { + return nil, err + } _node, _spec := cic.createSpec() if err := sqlgraph.CreateNode(ctx, cic.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -168,50 +129,30 @@ func (cic *ConfigItemCreate) sqlSave(ctx context.Context) (*ConfigItem, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + cic.mutation.id = &_node.ID + cic.mutation.done = true return _node, nil } func (cic *ConfigItemCreate) createSpec() (*ConfigItem, *sqlgraph.CreateSpec) { var ( _node = &ConfigItem{config: cic.config} - _spec = &sqlgraph.CreateSpec{ - Table: configitem.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: configitem.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(configitem.Table, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt)) ) if value, ok := cic.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldCreatedAt, - }) + _spec.SetField(configitem.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = &value } if value, ok := cic.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldUpdatedAt, - }) + _spec.SetField(configitem.FieldUpdatedAt, field.TypeTime, value) _node.UpdatedAt = &value } if value, ok := cic.mutation.Name(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldName, - }) + _spec.SetField(configitem.FieldName, field.TypeString, value) _node.Name = value } if value, ok := cic.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldValue, - }) + _spec.SetField(configitem.FieldValue, field.TypeString, value) _node.Value = value } return _node, _spec @@ -220,11 +161,15 @@ func (cic *ConfigItemCreate) createSpec() (*ConfigItem, *sqlgraph.CreateSpec) { // ConfigItemCreateBulk is the builder for creating many ConfigItem entities in bulk. 
type ConfigItemCreateBulk struct { config + err error builders []*ConfigItemCreate } // Save creates the ConfigItem entities in the database. func (cicb *ConfigItemCreateBulk) Save(ctx context.Context) ([]*ConfigItem, error) { + if cicb.err != nil { + return nil, cicb.err + } specs := make([]*sqlgraph.CreateSpec, len(cicb.builders)) nodes := make([]*ConfigItem, len(cicb.builders)) mutators := make([]Mutator, len(cicb.builders)) @@ -241,8 +186,8 @@ func (cicb *ConfigItemCreateBulk) Save(ctx context.Context) ([]*ConfigItem, erro return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, cicb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/configitem_delete.go b/pkg/database/ent/configitem_delete.go index 223fa9eefbf..a5dc811f60d 100644 --- a/pkg/database/ent/configitem_delete.go +++ b/pkg/database/ent/configitem_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (cid *ConfigItemDelete) Where(ps ...predicate.ConfigItem) *ConfigItemDelete // Exec executes the deletion query and returns how many vertices were deleted. func (cid *ConfigItemDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(cid.hooks) == 0 { - affected, err = cid.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*ConfigItemMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - cid.mutation = mutation - affected, err = cid.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(cid.hooks) - 1; i >= 0; i-- { - if cid.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = cid.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, cid.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, cid.sqlExec, cid.mutation, cid.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (cid *ConfigItemDelete) ExecX(ctx context.Context) int { } func (cid *ConfigItemDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: configitem.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: configitem.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(configitem.Table, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt)) if ps := cid.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (cid *ConfigItemDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + cid.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type ConfigItemDeleteOne struct { cid *ConfigItemDelete } +// Where appends a list predicates to the ConfigItemDelete builder. +func (cido *ConfigItemDeleteOne) Where(ps ...predicate.ConfigItem) *ConfigItemDeleteOne { + cido.cid.mutation.Where(ps...) + return cido +} + // Exec executes the deletion query. 
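Editor's note: ConfigItemDeleteOne gains a Where method, so a delete-by-id can carry an extra guard predicate and report NotFound when the guard does not match. A sketch follows; treating the mismatch as a no-op via ent.IsNotFound is an assumed policy, not something this patch prescribes.

package examples

import (
	"context"

	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
)

// deleteIfNamed deletes the ConfigItem with the given id only if it still carries
// the expected name, using the Where method added to the DeleteOne builder.
func deleteIfNamed(ctx context.Context, client *ent.Client, id int, name string) error {
	err := client.ConfigItem.DeleteOneID(id).
		Where(configitem.NameEQ(name)).
		Exec(ctx)
	if ent.IsNotFound(err) {
		return nil // either the id is gone or the name no longer matches
	}
	return err
}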
func (cido *ConfigItemDeleteOne) Exec(ctx context.Context) error { n, err := cido.cid.Exec(ctx) @@ -111,5 +82,7 @@ func (cido *ConfigItemDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (cido *ConfigItemDeleteOne) ExecX(ctx context.Context) { - cido.cid.ExecX(ctx) + if err := cido.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/configitem_query.go b/pkg/database/ent/configitem_query.go index 6c9e6732a9b..f68b8953ddb 100644 --- a/pkg/database/ent/configitem_query.go +++ b/pkg/database/ent/configitem_query.go @@ -17,11 +17,9 @@ import ( // ConfigItemQuery is the builder for querying ConfigItem entities. type ConfigItemQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []configitem.OrderOption + inters []Interceptor predicates []predicate.ConfigItem // intermediate query (i.e. traversal path). sql *sql.Selector @@ -34,27 +32,27 @@ func (ciq *ConfigItemQuery) Where(ps ...predicate.ConfigItem) *ConfigItemQuery { return ciq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (ciq *ConfigItemQuery) Limit(limit int) *ConfigItemQuery { - ciq.limit = &limit + ciq.ctx.Limit = &limit return ciq } -// Offset adds an offset step to the query. +// Offset to start from. func (ciq *ConfigItemQuery) Offset(offset int) *ConfigItemQuery { - ciq.offset = &offset + ciq.ctx.Offset = &offset return ciq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (ciq *ConfigItemQuery) Unique(unique bool) *ConfigItemQuery { - ciq.unique = &unique + ciq.ctx.Unique = &unique return ciq } -// Order adds an order step to the query. -func (ciq *ConfigItemQuery) Order(o ...OrderFunc) *ConfigItemQuery { +// Order specifies how the records should be ordered. +func (ciq *ConfigItemQuery) Order(o ...configitem.OrderOption) *ConfigItemQuery { ciq.order = append(ciq.order, o...) return ciq } @@ -62,7 +60,7 @@ func (ciq *ConfigItemQuery) Order(o ...OrderFunc) *ConfigItemQuery { // First returns the first ConfigItem entity from the query. // Returns a *NotFoundError when no ConfigItem was found. func (ciq *ConfigItemQuery) First(ctx context.Context) (*ConfigItem, error) { - nodes, err := ciq.Limit(1).All(ctx) + nodes, err := ciq.Limit(1).All(setContextOp(ctx, ciq.ctx, "First")) if err != nil { return nil, err } @@ -85,7 +83,7 @@ func (ciq *ConfigItemQuery) FirstX(ctx context.Context) *ConfigItem { // Returns a *NotFoundError when no ConfigItem ID was found. func (ciq *ConfigItemQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ciq.Limit(1).IDs(ctx); err != nil { + if ids, err = ciq.Limit(1).IDs(setContextOp(ctx, ciq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -108,7 +106,7 @@ func (ciq *ConfigItemQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one ConfigItem entity is found. // Returns a *NotFoundError when no ConfigItem entities are found. func (ciq *ConfigItemQuery) Only(ctx context.Context) (*ConfigItem, error) { - nodes, err := ciq.Limit(2).All(ctx) + nodes, err := ciq.Limit(2).All(setContextOp(ctx, ciq.ctx, "Only")) if err != nil { return nil, err } @@ -136,7 +134,7 @@ func (ciq *ConfigItemQuery) OnlyX(ctx context.Context) *ConfigItem { // Returns a *NotFoundError when no entities are found. 
func (ciq *ConfigItemQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ciq.Limit(2).IDs(ctx); err != nil { + if ids, err = ciq.Limit(2).IDs(setContextOp(ctx, ciq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -161,10 +159,12 @@ func (ciq *ConfigItemQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of ConfigItems. func (ciq *ConfigItemQuery) All(ctx context.Context) ([]*ConfigItem, error) { + ctx = setContextOp(ctx, ciq.ctx, "All") if err := ciq.prepareQuery(ctx); err != nil { return nil, err } - return ciq.sqlAll(ctx) + qr := querierAll[[]*ConfigItem, *ConfigItemQuery]() + return withInterceptors[[]*ConfigItem](ctx, ciq, qr, ciq.inters) } // AllX is like All, but panics if an error occurs. @@ -177,9 +177,12 @@ func (ciq *ConfigItemQuery) AllX(ctx context.Context) []*ConfigItem { } // IDs executes the query and returns a list of ConfigItem IDs. -func (ciq *ConfigItemQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := ciq.Select(configitem.FieldID).Scan(ctx, &ids); err != nil { +func (ciq *ConfigItemQuery) IDs(ctx context.Context) (ids []int, err error) { + if ciq.ctx.Unique == nil && ciq.path != nil { + ciq.Unique(true) + } + ctx = setContextOp(ctx, ciq.ctx, "IDs") + if err = ciq.Select(configitem.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -196,10 +199,11 @@ func (ciq *ConfigItemQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (ciq *ConfigItemQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, ciq.ctx, "Count") if err := ciq.prepareQuery(ctx); err != nil { return 0, err } - return ciq.sqlCount(ctx) + return withInterceptors[int](ctx, ciq, querierCount[*ConfigItemQuery](), ciq.inters) } // CountX is like Count, but panics if an error occurs. @@ -213,10 +217,15 @@ func (ciq *ConfigItemQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (ciq *ConfigItemQuery) Exist(ctx context.Context) (bool, error) { - if err := ciq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, ciq.ctx, "Exist") + switch _, err := ciq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return ciq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -236,14 +245,13 @@ func (ciq *ConfigItemQuery) Clone() *ConfigItemQuery { } return &ConfigItemQuery{ config: ciq.config, - limit: ciq.limit, - offset: ciq.offset, - order: append([]OrderFunc{}, ciq.order...), + ctx: ciq.ctx.Clone(), + order: append([]configitem.OrderOption{}, ciq.order...), + inters: append([]Interceptor{}, ciq.inters...), predicates: append([]predicate.ConfigItem{}, ciq.predicates...), // clone intermediate query. - sql: ciq.sql.Clone(), - path: ciq.path, - unique: ciq.unique, + sql: ciq.sql.Clone(), + path: ciq.path, } } @@ -262,16 +270,11 @@ func (ciq *ConfigItemQuery) Clone() *ConfigItemQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (ciq *ConfigItemQuery) GroupBy(field string, fields ...string) *ConfigItemGroupBy { - grbuild := &ConfigItemGroupBy{config: ciq.config} - grbuild.fields = append([]string{field}, fields...) 
- grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := ciq.prepareQuery(ctx); err != nil { - return nil, err - } - return ciq.sqlQuery(ctx), nil - } + ciq.ctx.Fields = append([]string{field}, fields...) + grbuild := &ConfigItemGroupBy{build: ciq} + grbuild.flds = &ciq.ctx.Fields grbuild.label = configitem.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -288,15 +291,30 @@ func (ciq *ConfigItemQuery) GroupBy(field string, fields ...string) *ConfigItemG // Select(configitem.FieldCreatedAt). // Scan(ctx, &v) func (ciq *ConfigItemQuery) Select(fields ...string) *ConfigItemSelect { - ciq.fields = append(ciq.fields, fields...) - selbuild := &ConfigItemSelect{ConfigItemQuery: ciq} - selbuild.label = configitem.Label - selbuild.flds, selbuild.scan = &ciq.fields, selbuild.Scan - return selbuild + ciq.ctx.Fields = append(ciq.ctx.Fields, fields...) + sbuild := &ConfigItemSelect{ConfigItemQuery: ciq} + sbuild.label = configitem.Label + sbuild.flds, sbuild.scan = &ciq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a ConfigItemSelect configured with the given aggregations. +func (ciq *ConfigItemQuery) Aggregate(fns ...AggregateFunc) *ConfigItemSelect { + return ciq.Select().Aggregate(fns...) } func (ciq *ConfigItemQuery) prepareQuery(ctx context.Context) error { - for _, f := range ciq.fields { + for _, inter := range ciq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, ciq); err != nil { + return err + } + } + } + for _, f := range ciq.ctx.Fields { if !configitem.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -338,41 +356,22 @@ func (ciq *ConfigItemQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]* func (ciq *ConfigItemQuery) sqlCount(ctx context.Context) (int, error) { _spec := ciq.querySpec() - _spec.Node.Columns = ciq.fields - if len(ciq.fields) > 0 { - _spec.Unique = ciq.unique != nil && *ciq.unique + _spec.Node.Columns = ciq.ctx.Fields + if len(ciq.ctx.Fields) > 0 { + _spec.Unique = ciq.ctx.Unique != nil && *ciq.ctx.Unique } return sqlgraph.CountNodes(ctx, ciq.driver, _spec) } -func (ciq *ConfigItemQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := ciq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (ciq *ConfigItemQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: configitem.Table, - Columns: configitem.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: configitem.FieldID, - }, - }, - From: ciq.sql, - Unique: true, - } - if unique := ciq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(configitem.Table, configitem.Columns, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt)) + _spec.From = ciq.sql + if unique := ciq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if ciq.path != nil { + _spec.Unique = true } - if fields := ciq.fields; len(fields) > 0 { + if fields := ciq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, configitem.FieldID) for i := range fields { @@ -388,10 +387,10 @@ func (ciq *ConfigItemQuery) querySpec() 
*sqlgraph.QuerySpec { } } } - if limit := ciq.limit; limit != nil { + if limit := ciq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := ciq.offset; offset != nil { + if offset := ciq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := ciq.order; len(ps) > 0 { @@ -407,7 +406,7 @@ func (ciq *ConfigItemQuery) querySpec() *sqlgraph.QuerySpec { func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(ciq.driver.Dialect()) t1 := builder.Table(configitem.Table) - columns := ciq.fields + columns := ciq.ctx.Fields if len(columns) == 0 { columns = configitem.Columns } @@ -416,7 +415,7 @@ func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = ciq.sql selector.Select(selector.Columns(columns...)...) } - if ciq.unique != nil && *ciq.unique { + if ciq.ctx.Unique != nil && *ciq.ctx.Unique { selector.Distinct() } for _, p := range ciq.predicates { @@ -425,12 +424,12 @@ func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range ciq.order { p(selector) } - if offset := ciq.offset; offset != nil { + if offset := ciq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := ciq.limit; limit != nil { + if limit := ciq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -438,13 +437,8 @@ func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector { // ConfigItemGroupBy is the group-by builder for ConfigItem entities. type ConfigItemGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *ConfigItemQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -453,74 +447,77 @@ func (cigb *ConfigItemGroupBy) Aggregate(fns ...AggregateFunc) *ConfigItemGroupB return cigb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (cigb *ConfigItemGroupBy) Scan(ctx context.Context, v any) error { - query, err := cigb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, cigb.build.ctx, "GroupBy") + if err := cigb.build.prepareQuery(ctx); err != nil { return err } - cigb.sql = query - return cigb.sqlScan(ctx, v) + return scanWithInterceptors[*ConfigItemQuery, *ConfigItemGroupBy](ctx, cigb.build, cigb, cigb.build.inters, v) } -func (cigb *ConfigItemGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range cigb.fields { - if !configitem.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (cigb *ConfigItemGroupBy) sqlScan(ctx context.Context, root *ConfigItemQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(cigb.fns)) + for _, fn := range cigb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*cigb.flds)+len(cigb.fns)) + for _, f := range *cigb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := cigb.sqlQuery() + selector.GroupBy(selector.Columns(*cigb.flds...)...) 
if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := cigb.driver.Query(ctx, query, args, rows); err != nil { + if err := cigb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (cigb *ConfigItemGroupBy) sqlQuery() *sql.Selector { - selector := cigb.sql.Select() - aggregation := make([]string, 0, len(cigb.fns)) - for _, fn := range cigb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(cigb.fields)+len(cigb.fns)) - for _, f := range cigb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(cigb.fields...)...) -} - // ConfigItemSelect is the builder for selecting fields of ConfigItem entities. type ConfigItemSelect struct { *ConfigItemQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (cis *ConfigItemSelect) Aggregate(fns ...AggregateFunc) *ConfigItemSelect { + cis.fns = append(cis.fns, fns...) + return cis } // Scan applies the selector query and scans the result into the given value. func (cis *ConfigItemSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, cis.ctx, "Select") if err := cis.prepareQuery(ctx); err != nil { return err } - cis.sql = cis.ConfigItemQuery.sqlQuery(ctx) - return cis.sqlScan(ctx, v) + return scanWithInterceptors[*ConfigItemQuery, *ConfigItemSelect](ctx, cis.ConfigItemQuery, cis, cis.inters, v) } -func (cis *ConfigItemSelect) sqlScan(ctx context.Context, v any) error { +func (cis *ConfigItemSelect) sqlScan(ctx context.Context, root *ConfigItemQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(cis.fns)) + for _, fn := range cis.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*cis.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := cis.sql.Query() + query, args := selector.Query() if err := cis.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/configitem_update.go b/pkg/database/ent/configitem_update.go index e591347a0c3..0db3a0b5233 100644 --- a/pkg/database/ent/configitem_update.go +++ b/pkg/database/ent/configitem_update.go @@ -71,35 +71,8 @@ func (ciu *ConfigItemUpdate) Mutation() *ConfigItemMutation { // Save executes the query and returns the number of nodes affected by the update operation. 
func (ciu *ConfigItemUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) ciu.defaults() - if len(ciu.hooks) == 0 { - affected, err = ciu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*ConfigItemMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - ciu.mutation = mutation - affected, err = ciu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(ciu.hooks) - 1; i >= 0; i-- { - if ciu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ciu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, ciu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, ciu.sqlSave, ciu.mutation, ciu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -137,16 +110,7 @@ func (ciu *ConfigItemUpdate) defaults() { } func (ciu *ConfigItemUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: configitem.Table, - Columns: configitem.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: configitem.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(configitem.Table, configitem.Columns, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt)) if ps := ciu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -155,44 +119,22 @@ func (ciu *ConfigItemUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := ciu.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldCreatedAt, - }) + _spec.SetField(configitem.FieldCreatedAt, field.TypeTime, value) } if ciu.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: configitem.FieldCreatedAt, - }) + _spec.ClearField(configitem.FieldCreatedAt, field.TypeTime) } if value, ok := ciu.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldUpdatedAt, - }) + _spec.SetField(configitem.FieldUpdatedAt, field.TypeTime, value) } if ciu.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: configitem.FieldUpdatedAt, - }) + _spec.ClearField(configitem.FieldUpdatedAt, field.TypeTime) } if value, ok := ciu.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldName, - }) + _spec.SetField(configitem.FieldName, field.TypeString, value) } if value, ok := ciu.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldValue, - }) + _spec.SetField(configitem.FieldValue, field.TypeString, value) } if n, err = sqlgraph.UpdateNodes(ctx, ciu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { @@ -202,6 +144,7 @@ func (ciu *ConfigItemUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + ciu.mutation.done = true return n, nil } @@ -254,6 +197,12 @@ func (ciuo *ConfigItemUpdateOne) Mutation() *ConfigItemMutation { return ciuo.mutation } +// Where 
appends a list predicates to the ConfigItemUpdate builder. +func (ciuo *ConfigItemUpdateOne) Where(ps ...predicate.ConfigItem) *ConfigItemUpdateOne { + ciuo.mutation.Where(ps...) + return ciuo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (ciuo *ConfigItemUpdateOne) Select(field string, fields ...string) *ConfigItemUpdateOne { @@ -263,41 +212,8 @@ func (ciuo *ConfigItemUpdateOne) Select(field string, fields ...string) *ConfigI // Save executes the query and returns the updated ConfigItem entity. func (ciuo *ConfigItemUpdateOne) Save(ctx context.Context) (*ConfigItem, error) { - var ( - err error - node *ConfigItem - ) ciuo.defaults() - if len(ciuo.hooks) == 0 { - node, err = ciuo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*ConfigItemMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - ciuo.mutation = mutation - node, err = ciuo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(ciuo.hooks) - 1; i >= 0; i-- { - if ciuo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ciuo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, ciuo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*ConfigItem) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from ConfigItemMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, ciuo.sqlSave, ciuo.mutation, ciuo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -335,16 +251,7 @@ func (ciuo *ConfigItemUpdateOne) defaults() { } func (ciuo *ConfigItemUpdateOne) sqlSave(ctx context.Context) (_node *ConfigItem, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: configitem.Table, - Columns: configitem.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: configitem.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(configitem.Table, configitem.Columns, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt)) id, ok := ciuo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ConfigItem.id" for update`)} @@ -370,44 +277,22 @@ func (ciuo *ConfigItemUpdateOne) sqlSave(ctx context.Context) (_node *ConfigItem } } if value, ok := ciuo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldCreatedAt, - }) + _spec.SetField(configitem.FieldCreatedAt, field.TypeTime, value) } if ciuo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: configitem.FieldCreatedAt, - }) + _spec.ClearField(configitem.FieldCreatedAt, field.TypeTime) } if value, ok := ciuo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldUpdatedAt, - }) + _spec.SetField(configitem.FieldUpdatedAt, field.TypeTime, value) } if ciuo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: configitem.FieldUpdatedAt, - }) + _spec.ClearField(configitem.FieldUpdatedAt, field.TypeTime) } if value, ok := ciuo.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, 
&sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldName, - }) + _spec.SetField(configitem.FieldName, field.TypeString, value) } if value, ok := ciuo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldValue, - }) + _spec.SetField(configitem.FieldValue, field.TypeString, value) } _node = &ConfigItem{config: ciuo.config} _spec.Assign = _node.assignValues @@ -420,5 +305,6 @@ func (ciuo *ConfigItemUpdateOne) sqlSave(ctx context.Context) (_node *ConfigItem } return nil, err } + ciuo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/context.go b/pkg/database/ent/context.go deleted file mode 100644 index 7811bfa2349..00000000000 --- a/pkg/database/ent/context.go +++ /dev/null @@ -1,33 +0,0 @@ -// Code generated by ent, DO NOT EDIT. - -package ent - -import ( - "context" -) - -type clientCtxKey struct{} - -// FromContext returns a Client stored inside a context, or nil if there isn't one. -func FromContext(ctx context.Context) *Client { - c, _ := ctx.Value(clientCtxKey{}).(*Client) - return c -} - -// NewContext returns a new context with the given Client attached. -func NewContext(parent context.Context, c *Client) context.Context { - return context.WithValue(parent, clientCtxKey{}, c) -} - -type txCtxKey struct{} - -// TxFromContext returns a Tx stored inside a context, or nil if there isn't one. -func TxFromContext(ctx context.Context) *Tx { - tx, _ := ctx.Value(txCtxKey{}).(*Tx) - return tx -} - -// NewTxContext returns a new context with the given Tx attached. -func NewTxContext(parent context.Context, tx *Tx) context.Context { - return context.WithValue(parent, txCtxKey{}, tx) -} diff --git a/pkg/database/ent/decision.go b/pkg/database/ent/decision.go index c969e576724..8a08bc1dfd4 100644 --- a/pkg/database/ent/decision.go +++ b/pkg/database/ent/decision.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" @@ -51,7 +52,8 @@ type Decision struct { AlertDecisions int `json:"alert_decisions,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the DecisionQuery when eager-loading is set. - Edges DecisionEdges `json:"edges"` + Edges DecisionEdges `json:"edges"` + selectValues sql.SelectValues } // DecisionEdges holds the relations/edges for other nodes in the graph. @@ -90,7 +92,7 @@ func (*Decision) scanValues(columns []string) ([]any, error) { case decision.FieldCreatedAt, decision.FieldUpdatedAt, decision.FieldUntil: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type Decision", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -209,21 +211,29 @@ func (d *Decision) assignValues(columns []string, values []any) error { } else if value.Valid { d.AlertDecisions = int(value.Int64) } + default: + d.selectValues.Set(columns[i], values[i]) } } return nil } +// GetValue returns the ent.Value that was dynamically selected and assigned to the Decision. +// This includes values selected through modifiers, order, etc. +func (d *Decision) GetValue(name string) (ent.Value, error) { + return d.selectValues.Get(name) +} + // QueryOwner queries the "owner" edge of the Decision entity. 
func (d *Decision) QueryOwner() *AlertQuery { - return (&DecisionClient{config: d.config}).QueryOwner(d) + return NewDecisionClient(d.config).QueryOwner(d) } // Update returns a builder for updating this Decision. // Note that you need to call Decision.Unwrap() before calling this method if this Decision // was returned from a transaction, and the transaction was committed or rolled back. func (d *Decision) Update() *DecisionUpdateOne { - return (&DecisionClient{config: d.config}).UpdateOne(d) + return NewDecisionClient(d.config).UpdateOne(d) } // Unwrap unwraps the Decision entity that was returned from a transaction after it was closed, @@ -301,9 +311,3 @@ func (d *Decision) String() string { // Decisions is a parsable slice of Decision. type Decisions []*Decision - -func (d Decisions) config(cfg config) { - for _i := range d { - d[_i].config = cfg - } -} diff --git a/pkg/database/ent/decision/decision.go b/pkg/database/ent/decision/decision.go index a0012d940a8..d9f67623bd8 100644 --- a/pkg/database/ent/decision/decision.go +++ b/pkg/database/ent/decision/decision.go @@ -4,6 +4,9 @@ package decision import ( "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -99,3 +102,105 @@ var ( // DefaultSimulated holds the default value on creation for the "simulated" field. DefaultSimulated bool ) + +// OrderOption defines the ordering options for the Decision queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByUntil orders the results by the until field. +func ByUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUntil, opts...).ToFunc() +} + +// ByScenario orders the results by the scenario field. +func ByScenario(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScenario, opts...).ToFunc() +} + +// ByType orders the results by the type field. +func ByType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldType, opts...).ToFunc() +} + +// ByStartIP orders the results by the start_ip field. +func ByStartIP(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStartIP, opts...).ToFunc() +} + +// ByEndIP orders the results by the end_ip field. +func ByEndIP(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEndIP, opts...).ToFunc() +} + +// ByStartSuffix orders the results by the start_suffix field. +func ByStartSuffix(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStartSuffix, opts...).ToFunc() +} + +// ByEndSuffix orders the results by the end_suffix field. +func ByEndSuffix(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEndSuffix, opts...).ToFunc() +} + +// ByIPSize orders the results by the ip_size field. +func ByIPSize(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIPSize, opts...).ToFunc() +} + +// ByScope orders the results by the scope field. 
+func ByScope(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScope, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} + +// ByOrigin orders the results by the origin field. +func ByOrigin(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOrigin, opts...).ToFunc() +} + +// BySimulated orders the results by the simulated field. +func BySimulated(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSimulated, opts...).ToFunc() +} + +// ByUUID orders the results by the uuid field. +func ByUUID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUUID, opts...).ToFunc() +} + +// ByAlertDecisions orders the results by the alert_decisions field. +func ByAlertDecisions(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAlertDecisions, opts...).ToFunc() +} + +// ByOwnerField orders the results by owner field. +func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...)) + } +} +func newOwnerStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(OwnerInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) +} diff --git a/pkg/database/ent/decision/where.go b/pkg/database/ent/decision/where.go index 18716a4a7c1..36374f5714d 100644 --- a/pkg/database/ent/decision/where.go +++ b/pkg/database/ent/decision/where.go @@ -12,1481 +12,967 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. 
func IDGTE(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUpdatedAt, v)) } // Until applies equality check predicate on the "until" field. It's identical to UntilEQ. func Until(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUntil, v)) } // Scenario applies equality check predicate on the "scenario" field. It's identical to ScenarioEQ. func Scenario(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldScenario, v)) } // Type applies equality check predicate on the "type" field. It's identical to TypeEQ. func Type(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldType, v)) } // StartIP applies equality check predicate on the "start_ip" field. It's identical to StartIPEQ. func StartIP(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldStartIP, v)) } // EndIP applies equality check predicate on the "end_ip" field. It's identical to EndIPEQ. func EndIP(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldEndIP, v)) } // StartSuffix applies equality check predicate on the "start_suffix" field. It's identical to StartSuffixEQ. func StartSuffix(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldStartSuffix, v)) } // EndSuffix applies equality check predicate on the "end_suffix" field. It's identical to EndSuffixEQ. func EndSuffix(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldEndSuffix, v)) } // IPSize applies equality check predicate on the "ip_size" field. It's identical to IPSizeEQ. 
func IPSize(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldIPSize, v)) } // Scope applies equality check predicate on the "scope" field. It's identical to ScopeEQ. func Scope(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldScope, v)) } // Value applies equality check predicate on the "value" field. It's identical to ValueEQ. func Value(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldValue, v)) } // Origin applies equality check predicate on the "origin" field. It's identical to OriginEQ. func Origin(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldOrigin, v)) } // Simulated applies equality check predicate on the "simulated" field. It's identical to SimulatedEQ. func Simulated(v bool) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSimulated), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldSimulated, v)) } // UUID applies equality check predicate on the "uuid" field. It's identical to UUIDEQ. func UUID(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUUID, v)) } // AlertDecisions applies equality check predicate on the "alert_decisions" field. It's identical to AlertDecisionsEQ. func AlertDecisions(v int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertDecisions), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldAlertDecisions, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. 
func CreatedAtGT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldCreatedAt, v)) } // CreatedAtIsNil applies the IsNil predicate on the "created_at" field. func CreatedAtIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) + return predicate.Decision(sql.FieldIsNull(FieldCreatedAt)) } // CreatedAtNotNil applies the NotNil predicate on the "created_at" field. func CreatedAtNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Decision(sql.FieldNotNull(FieldCreatedAt)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. 
func UpdatedAtLT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldUpdatedAt, v)) } // UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. func UpdatedAtIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) + return predicate.Decision(sql.FieldIsNull(FieldUpdatedAt)) } // UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. func UpdatedAtNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Decision(sql.FieldNotNull(FieldUpdatedAt)) } // UntilEQ applies the EQ predicate on the "until" field. func UntilEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUntil, v)) } // UntilNEQ applies the NEQ predicate on the "until" field. func UntilNEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldUntil, v)) } // UntilIn applies the In predicate on the "until" field. func UntilIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUntil), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldUntil, vs...)) } // UntilNotIn applies the NotIn predicate on the "until" field. func UntilNotIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUntil), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldUntil, vs...)) } // UntilGT applies the GT predicate on the "until" field. func UntilGT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldGT(FieldUntil, v)) } // UntilGTE applies the GTE predicate on the "until" field. func UntilGTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldUntil, v)) } // UntilLT applies the LT predicate on the "until" field. func UntilLT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldLT(FieldUntil, v)) } // UntilLTE applies the LTE predicate on the "until" field. func UntilLTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldUntil, v)) } // UntilIsNil applies the IsNil predicate on the "until" field. 
func UntilIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUntil))) - }) + return predicate.Decision(sql.FieldIsNull(FieldUntil)) } // UntilNotNil applies the NotNil predicate on the "until" field. func UntilNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUntil))) - }) + return predicate.Decision(sql.FieldNotNull(FieldUntil)) } // ScenarioEQ applies the EQ predicate on the "scenario" field. func ScenarioEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldScenario, v)) } // ScenarioNEQ applies the NEQ predicate on the "scenario" field. func ScenarioNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldScenario, v)) } // ScenarioIn applies the In predicate on the "scenario" field. func ScenarioIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScenario), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldScenario, vs...)) } // ScenarioNotIn applies the NotIn predicate on the "scenario" field. func ScenarioNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScenario), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldScenario, vs...)) } // ScenarioGT applies the GT predicate on the "scenario" field. func ScenarioGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldGT(FieldScenario, v)) } // ScenarioGTE applies the GTE predicate on the "scenario" field. func ScenarioGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldScenario, v)) } // ScenarioLT applies the LT predicate on the "scenario" field. func ScenarioLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldLT(FieldScenario, v)) } // ScenarioLTE applies the LTE predicate on the "scenario" field. func ScenarioLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldScenario, v)) } // ScenarioContains applies the Contains predicate on the "scenario" field. func ScenarioContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldContains(FieldScenario, v)) } // ScenarioHasPrefix applies the HasPrefix predicate on the "scenario" field. func ScenarioHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldScenario, v)) } // ScenarioHasSuffix applies the HasSuffix predicate on the "scenario" field. 
func ScenarioHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldScenario, v)) } // ScenarioEqualFold applies the EqualFold predicate on the "scenario" field. func ScenarioEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldScenario, v)) } // ScenarioContainsFold applies the ContainsFold predicate on the "scenario" field. func ScenarioContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldScenario, v)) } // TypeEQ applies the EQ predicate on the "type" field. func TypeEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldType, v)) } // TypeNEQ applies the NEQ predicate on the "type" field. func TypeNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldType, v)) } // TypeIn applies the In predicate on the "type" field. func TypeIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldType), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldType, vs...)) } // TypeNotIn applies the NotIn predicate on the "type" field. func TypeNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldType), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldType, vs...)) } // TypeGT applies the GT predicate on the "type" field. func TypeGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldGT(FieldType, v)) } // TypeGTE applies the GTE predicate on the "type" field. func TypeGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldType, v)) } // TypeLT applies the LT predicate on the "type" field. func TypeLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldLT(FieldType, v)) } // TypeLTE applies the LTE predicate on the "type" field. func TypeLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldType, v)) } // TypeContains applies the Contains predicate on the "type" field. func TypeContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldContains(FieldType, v)) } // TypeHasPrefix applies the HasPrefix predicate on the "type" field. 
func TypeHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldType, v)) } // TypeHasSuffix applies the HasSuffix predicate on the "type" field. func TypeHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldType, v)) } // TypeEqualFold applies the EqualFold predicate on the "type" field. func TypeEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldType, v)) } // TypeContainsFold applies the ContainsFold predicate on the "type" field. func TypeContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldType, v)) } // StartIPEQ applies the EQ predicate on the "start_ip" field. func StartIPEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldStartIP, v)) } // StartIPNEQ applies the NEQ predicate on the "start_ip" field. func StartIPNEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldStartIP, v)) } // StartIPIn applies the In predicate on the "start_ip" field. func StartIPIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldStartIP), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldStartIP, vs...)) } // StartIPNotIn applies the NotIn predicate on the "start_ip" field. func StartIPNotIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldStartIP), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldStartIP, vs...)) } // StartIPGT applies the GT predicate on the "start_ip" field. func StartIPGT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldGT(FieldStartIP, v)) } // StartIPGTE applies the GTE predicate on the "start_ip" field. func StartIPGTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldStartIP, v)) } // StartIPLT applies the LT predicate on the "start_ip" field. func StartIPLT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldLT(FieldStartIP, v)) } // StartIPLTE applies the LTE predicate on the "start_ip" field. func StartIPLTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldStartIP, v)) } // StartIPIsNil applies the IsNil predicate on the "start_ip" field. 
func StartIPIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldStartIP))) - }) + return predicate.Decision(sql.FieldIsNull(FieldStartIP)) } // StartIPNotNil applies the NotNil predicate on the "start_ip" field. func StartIPNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldStartIP))) - }) + return predicate.Decision(sql.FieldNotNull(FieldStartIP)) } // EndIPEQ applies the EQ predicate on the "end_ip" field. func EndIPEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldEndIP, v)) } // EndIPNEQ applies the NEQ predicate on the "end_ip" field. func EndIPNEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldEndIP, v)) } // EndIPIn applies the In predicate on the "end_ip" field. func EndIPIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldEndIP), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldEndIP, vs...)) } // EndIPNotIn applies the NotIn predicate on the "end_ip" field. func EndIPNotIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldEndIP), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldEndIP, vs...)) } // EndIPGT applies the GT predicate on the "end_ip" field. func EndIPGT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldGT(FieldEndIP, v)) } // EndIPGTE applies the GTE predicate on the "end_ip" field. func EndIPGTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldEndIP, v)) } // EndIPLT applies the LT predicate on the "end_ip" field. func EndIPLT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldLT(FieldEndIP, v)) } // EndIPLTE applies the LTE predicate on the "end_ip" field. func EndIPLTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldEndIP, v)) } // EndIPIsNil applies the IsNil predicate on the "end_ip" field. func EndIPIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldEndIP))) - }) + return predicate.Decision(sql.FieldIsNull(FieldEndIP)) } // EndIPNotNil applies the NotNil predicate on the "end_ip" field. func EndIPNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldEndIP))) - }) + return predicate.Decision(sql.FieldNotNull(FieldEndIP)) } // StartSuffixEQ applies the EQ predicate on the "start_suffix" field. 
func StartSuffixEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldStartSuffix, v)) } // StartSuffixNEQ applies the NEQ predicate on the "start_suffix" field. func StartSuffixNEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldStartSuffix, v)) } // StartSuffixIn applies the In predicate on the "start_suffix" field. func StartSuffixIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldStartSuffix), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldStartSuffix, vs...)) } // StartSuffixNotIn applies the NotIn predicate on the "start_suffix" field. func StartSuffixNotIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldStartSuffix), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldStartSuffix, vs...)) } // StartSuffixGT applies the GT predicate on the "start_suffix" field. func StartSuffixGT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldGT(FieldStartSuffix, v)) } // StartSuffixGTE applies the GTE predicate on the "start_suffix" field. func StartSuffixGTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldStartSuffix, v)) } // StartSuffixLT applies the LT predicate on the "start_suffix" field. func StartSuffixLT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldLT(FieldStartSuffix, v)) } // StartSuffixLTE applies the LTE predicate on the "start_suffix" field. func StartSuffixLTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldStartSuffix, v)) } // StartSuffixIsNil applies the IsNil predicate on the "start_suffix" field. func StartSuffixIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldStartSuffix))) - }) + return predicate.Decision(sql.FieldIsNull(FieldStartSuffix)) } // StartSuffixNotNil applies the NotNil predicate on the "start_suffix" field. func StartSuffixNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldStartSuffix))) - }) + return predicate.Decision(sql.FieldNotNull(FieldStartSuffix)) } // EndSuffixEQ applies the EQ predicate on the "end_suffix" field. func EndSuffixEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldEndSuffix, v)) } // EndSuffixNEQ applies the NEQ predicate on the "end_suffix" field. 
func EndSuffixNEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldEndSuffix, v)) } // EndSuffixIn applies the In predicate on the "end_suffix" field. func EndSuffixIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldEndSuffix), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldEndSuffix, vs...)) } // EndSuffixNotIn applies the NotIn predicate on the "end_suffix" field. func EndSuffixNotIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldEndSuffix), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldEndSuffix, vs...)) } // EndSuffixGT applies the GT predicate on the "end_suffix" field. func EndSuffixGT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldGT(FieldEndSuffix, v)) } // EndSuffixGTE applies the GTE predicate on the "end_suffix" field. func EndSuffixGTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldEndSuffix, v)) } // EndSuffixLT applies the LT predicate on the "end_suffix" field. func EndSuffixLT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldLT(FieldEndSuffix, v)) } // EndSuffixLTE applies the LTE predicate on the "end_suffix" field. func EndSuffixLTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldEndSuffix, v)) } // EndSuffixIsNil applies the IsNil predicate on the "end_suffix" field. func EndSuffixIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldEndSuffix))) - }) + return predicate.Decision(sql.FieldIsNull(FieldEndSuffix)) } // EndSuffixNotNil applies the NotNil predicate on the "end_suffix" field. func EndSuffixNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldEndSuffix))) - }) + return predicate.Decision(sql.FieldNotNull(FieldEndSuffix)) } // IPSizeEQ applies the EQ predicate on the "ip_size" field. func IPSizeEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldIPSize, v)) } // IPSizeNEQ applies the NEQ predicate on the "ip_size" field. func IPSizeNEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldIPSize, v)) } // IPSizeIn applies the In predicate on the "ip_size" field. func IPSizeIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldIPSize), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldIPSize, vs...)) } // IPSizeNotIn applies the NotIn predicate on the "ip_size" field. 
func IPSizeNotIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldIPSize), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldIPSize, vs...)) } // IPSizeGT applies the GT predicate on the "ip_size" field. func IPSizeGT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldGT(FieldIPSize, v)) } // IPSizeGTE applies the GTE predicate on the "ip_size" field. func IPSizeGTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldIPSize, v)) } // IPSizeLT applies the LT predicate on the "ip_size" field. func IPSizeLT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldLT(FieldIPSize, v)) } // IPSizeLTE applies the LTE predicate on the "ip_size" field. func IPSizeLTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldIPSize, v)) } // IPSizeIsNil applies the IsNil predicate on the "ip_size" field. func IPSizeIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldIPSize))) - }) + return predicate.Decision(sql.FieldIsNull(FieldIPSize)) } // IPSizeNotNil applies the NotNil predicate on the "ip_size" field. func IPSizeNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldIPSize))) - }) + return predicate.Decision(sql.FieldNotNull(FieldIPSize)) } // ScopeEQ applies the EQ predicate on the "scope" field. func ScopeEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldScope, v)) } // ScopeNEQ applies the NEQ predicate on the "scope" field. func ScopeNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldScope, v)) } // ScopeIn applies the In predicate on the "scope" field. func ScopeIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScope), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldScope, vs...)) } // ScopeNotIn applies the NotIn predicate on the "scope" field. func ScopeNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScope), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldScope, vs...)) } // ScopeGT applies the GT predicate on the "scope" field. func ScopeGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldGT(FieldScope, v)) } // ScopeGTE applies the GTE predicate on the "scope" field. 
func ScopeGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldScope, v)) } // ScopeLT applies the LT predicate on the "scope" field. func ScopeLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldLT(FieldScope, v)) } // ScopeLTE applies the LTE predicate on the "scope" field. func ScopeLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldScope, v)) } // ScopeContains applies the Contains predicate on the "scope" field. func ScopeContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldContains(FieldScope, v)) } // ScopeHasPrefix applies the HasPrefix predicate on the "scope" field. func ScopeHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldScope, v)) } // ScopeHasSuffix applies the HasSuffix predicate on the "scope" field. func ScopeHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldScope, v)) } // ScopeEqualFold applies the EqualFold predicate on the "scope" field. func ScopeEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldScope, v)) } // ScopeContainsFold applies the ContainsFold predicate on the "scope" field. func ScopeContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldScope, v)) } // ValueEQ applies the EQ predicate on the "value" field. func ValueEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "value" field. func ValueNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "value" field. func ValueIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "value" field. func ValueNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "value" field. 
func ValueGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "value" field. func ValueGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "value" field. func ValueLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "value" field. func ValueLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "value" field. func ValueContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "value" field. func ValueHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "value" field. func ValueHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "value" field. func ValueEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "value" field. func ValueContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldValue, v)) } // OriginEQ applies the EQ predicate on the "origin" field. func OriginEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldOrigin, v)) } // OriginNEQ applies the NEQ predicate on the "origin" field. func OriginNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldOrigin, v)) } // OriginIn applies the In predicate on the "origin" field. func OriginIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldOrigin), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldOrigin, vs...)) } // OriginNotIn applies the NotIn predicate on the "origin" field. 
func OriginNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldOrigin), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldOrigin, vs...)) } // OriginGT applies the GT predicate on the "origin" field. func OriginGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldGT(FieldOrigin, v)) } // OriginGTE applies the GTE predicate on the "origin" field. func OriginGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldOrigin, v)) } // OriginLT applies the LT predicate on the "origin" field. func OriginLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldLT(FieldOrigin, v)) } // OriginLTE applies the LTE predicate on the "origin" field. func OriginLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldOrigin, v)) } // OriginContains applies the Contains predicate on the "origin" field. func OriginContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldContains(FieldOrigin, v)) } // OriginHasPrefix applies the HasPrefix predicate on the "origin" field. func OriginHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldOrigin, v)) } // OriginHasSuffix applies the HasSuffix predicate on the "origin" field. func OriginHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldOrigin, v)) } // OriginEqualFold applies the EqualFold predicate on the "origin" field. func OriginEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldOrigin, v)) } // OriginContainsFold applies the ContainsFold predicate on the "origin" field. func OriginContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldOrigin, v)) } // SimulatedEQ applies the EQ predicate on the "simulated" field. func SimulatedEQ(v bool) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSimulated), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldSimulated, v)) } // SimulatedNEQ applies the NEQ predicate on the "simulated" field. func SimulatedNEQ(v bool) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSimulated), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldSimulated, v)) } // UUIDEQ applies the EQ predicate on the "uuid" field. 
func UUIDEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUUID, v)) } // UUIDNEQ applies the NEQ predicate on the "uuid" field. func UUIDNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldUUID, v)) } // UUIDIn applies the In predicate on the "uuid" field. func UUIDIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUUID), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldUUID, vs...)) } // UUIDNotIn applies the NotIn predicate on the "uuid" field. func UUIDNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUUID), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldUUID, vs...)) } // UUIDGT applies the GT predicate on the "uuid" field. func UUIDGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldGT(FieldUUID, v)) } // UUIDGTE applies the GTE predicate on the "uuid" field. func UUIDGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldUUID, v)) } // UUIDLT applies the LT predicate on the "uuid" field. func UUIDLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldLT(FieldUUID, v)) } // UUIDLTE applies the LTE predicate on the "uuid" field. func UUIDLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldUUID, v)) } // UUIDContains applies the Contains predicate on the "uuid" field. func UUIDContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldContains(FieldUUID, v)) } // UUIDHasPrefix applies the HasPrefix predicate on the "uuid" field. func UUIDHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldUUID, v)) } // UUIDHasSuffix applies the HasSuffix predicate on the "uuid" field. func UUIDHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldUUID, v)) } // UUIDIsNil applies the IsNil predicate on the "uuid" field. func UUIDIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUUID))) - }) + return predicate.Decision(sql.FieldIsNull(FieldUUID)) } // UUIDNotNil applies the NotNil predicate on the "uuid" field. 
func UUIDNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUUID))) - }) + return predicate.Decision(sql.FieldNotNull(FieldUUID)) } // UUIDEqualFold applies the EqualFold predicate on the "uuid" field. func UUIDEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldUUID, v)) } // UUIDContainsFold applies the ContainsFold predicate on the "uuid" field. func UUIDContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldUUID, v)) } // AlertDecisionsEQ applies the EQ predicate on the "alert_decisions" field. func AlertDecisionsEQ(v int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertDecisions), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldAlertDecisions, v)) } // AlertDecisionsNEQ applies the NEQ predicate on the "alert_decisions" field. func AlertDecisionsNEQ(v int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAlertDecisions), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldAlertDecisions, v)) } // AlertDecisionsIn applies the In predicate on the "alert_decisions" field. func AlertDecisionsIn(vs ...int) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAlertDecisions), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldAlertDecisions, vs...)) } // AlertDecisionsNotIn applies the NotIn predicate on the "alert_decisions" field. func AlertDecisionsNotIn(vs ...int) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAlertDecisions), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldAlertDecisions, vs...)) } // AlertDecisionsIsNil applies the IsNil predicate on the "alert_decisions" field. func AlertDecisionsIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldAlertDecisions))) - }) + return predicate.Decision(sql.FieldIsNull(FieldAlertDecisions)) } // AlertDecisionsNotNil applies the NotNil predicate on the "alert_decisions" field. func AlertDecisionsNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldAlertDecisions))) - }) + return predicate.Decision(sql.FieldNotNull(FieldAlertDecisions)) } // HasOwner applies the HasEdge predicate on the "owner" edge. @@ -1494,7 +980,6 @@ func HasOwner() predicate.Decision { return predicate.Decision(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), ) sqlgraph.HasNeighbors(s, step) @@ -1504,11 +989,7 @@ func HasOwner() predicate.Decision { // HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates). 
func HasOwnerWith(preds ...predicate.Alert) predicate.Decision { return predicate.Decision(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), - ) + step := newOwnerStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -1519,32 +1000,15 @@ func HasOwnerWith(preds ...predicate.Alert) predicate.Decision { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Decision) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Decision(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Decision) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Decision(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Decision) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Decision(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/decision_create.go b/pkg/database/ent/decision_create.go index 64238cb7003..43a28c53114 100644 --- a/pkg/database/ent/decision_create.go +++ b/pkg/database/ent/decision_create.go @@ -231,50 +231,8 @@ func (dc *DecisionCreate) Mutation() *DecisionMutation { // Save creates the Decision in the database. func (dc *DecisionCreate) Save(ctx context.Context) (*Decision, error) { - var ( - err error - node *Decision - ) dc.defaults() - if len(dc.hooks) == 0 { - if err = dc.check(); err != nil { - return nil, err - } - node, err = dc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DecisionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = dc.check(); err != nil { - return nil, err - } - dc.mutation = mutation - if node, err = dc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(dc.hooks) - 1; i >= 0; i-- { - if dc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, dc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Decision) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from DecisionMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, dc.sqlSave, dc.mutation, dc.hooks) } // SaveX calls Save and panics if Save returns an error. 
@@ -339,6 +297,9 @@ func (dc *DecisionCreate) check() error { } func (dc *DecisionCreate) sqlSave(ctx context.Context) (*Decision, error) { + if err := dc.check(); err != nil { + return nil, err + } _node, _spec := dc.createSpec() if err := sqlgraph.CreateNode(ctx, dc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -348,138 +309,74 @@ func (dc *DecisionCreate) sqlSave(ctx context.Context) (*Decision, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + dc.mutation.id = &_node.ID + dc.mutation.done = true return _node, nil } func (dc *DecisionCreate) createSpec() (*Decision, *sqlgraph.CreateSpec) { var ( _node = &Decision{config: dc.config} - _spec = &sqlgraph.CreateSpec{ - Table: decision.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(decision.Table, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt)) ) if value, ok := dc.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldCreatedAt, - }) + _spec.SetField(decision.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = &value } if value, ok := dc.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUpdatedAt, - }) + _spec.SetField(decision.FieldUpdatedAt, field.TypeTime, value) _node.UpdatedAt = &value } if value, ok := dc.mutation.Until(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUntil, - }) + _spec.SetField(decision.FieldUntil, field.TypeTime, value) _node.Until = &value } if value, ok := dc.mutation.Scenario(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScenario, - }) + _spec.SetField(decision.FieldScenario, field.TypeString, value) _node.Scenario = value } if value, ok := dc.mutation.GetType(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldType, - }) + _spec.SetField(decision.FieldType, field.TypeString, value) _node.Type = value } if value, ok := dc.mutation.StartIP(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartIP, - }) + _spec.SetField(decision.FieldStartIP, field.TypeInt64, value) _node.StartIP = value } if value, ok := dc.mutation.EndIP(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndIP, - }) + _spec.SetField(decision.FieldEndIP, field.TypeInt64, value) _node.EndIP = value } if value, ok := dc.mutation.StartSuffix(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartSuffix, - }) + _spec.SetField(decision.FieldStartSuffix, field.TypeInt64, value) _node.StartSuffix = value } if value, ok := dc.mutation.EndSuffix(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndSuffix, - }) + _spec.SetField(decision.FieldEndSuffix, field.TypeInt64, value) _node.EndSuffix = value } if value, ok := dc.mutation.IPSize(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldIPSize, - 
}) + _spec.SetField(decision.FieldIPSize, field.TypeInt64, value) _node.IPSize = value } if value, ok := dc.mutation.Scope(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScope, - }) + _spec.SetField(decision.FieldScope, field.TypeString, value) _node.Scope = value } if value, ok := dc.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldValue, - }) + _spec.SetField(decision.FieldValue, field.TypeString, value) _node.Value = value } if value, ok := dc.mutation.Origin(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldOrigin, - }) + _spec.SetField(decision.FieldOrigin, field.TypeString, value) _node.Origin = value } if value, ok := dc.mutation.Simulated(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: decision.FieldSimulated, - }) + _spec.SetField(decision.FieldSimulated, field.TypeBool, value) _node.Simulated = value } if value, ok := dc.mutation.UUID(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldUUID, - }) + _spec.SetField(decision.FieldUUID, field.TypeString, value) _node.UUID = value } if nodes := dc.mutation.OwnerIDs(); len(nodes) > 0 { @@ -490,10 +387,7 @@ func (dc *DecisionCreate) createSpec() (*Decision, *sqlgraph.CreateSpec) { Columns: []string{decision.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -508,11 +402,15 @@ func (dc *DecisionCreate) createSpec() (*Decision, *sqlgraph.CreateSpec) { // DecisionCreateBulk is the builder for creating many Decision entities in bulk. type DecisionCreateBulk struct { config + err error builders []*DecisionCreate } // Save creates the Decision entities in the database. func (dcb *DecisionCreateBulk) Save(ctx context.Context) ([]*Decision, error) { + if dcb.err != nil { + return nil, dcb.err + } specs := make([]*sqlgraph.CreateSpec, len(dcb.builders)) nodes := make([]*Decision, len(dcb.builders)) mutators := make([]Mutator, len(dcb.builders)) @@ -529,8 +427,8 @@ func (dcb *DecisionCreateBulk) Save(ctx context.Context) ([]*Decision, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, dcb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/decision_delete.go b/pkg/database/ent/decision_delete.go index 24b494b113e..35bb8767283 100644 --- a/pkg/database/ent/decision_delete.go +++ b/pkg/database/ent/decision_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (dd *DecisionDelete) Where(ps ...predicate.Decision) *DecisionDelete { // Exec executes the deletion query and returns how many vertices were deleted. 
func (dd *DecisionDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(dd.hooks) == 0 { - affected, err = dd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DecisionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - dd.mutation = mutation - affected, err = dd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(dd.hooks) - 1; i >= 0; i-- { - if dd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, dd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, dd.sqlExec, dd.mutation, dd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (dd *DecisionDelete) ExecX(ctx context.Context) int { } func (dd *DecisionDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: decision.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(decision.Table, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt)) if ps := dd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (dd *DecisionDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + dd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type DecisionDeleteOne struct { dd *DecisionDelete } +// Where appends a list predicates to the DecisionDelete builder. +func (ddo *DecisionDeleteOne) Where(ps ...predicate.Decision) *DecisionDeleteOne { + ddo.dd.mutation.Where(ps...) + return ddo +} + // Exec executes the deletion query. func (ddo *DecisionDeleteOne) Exec(ctx context.Context) error { n, err := ddo.dd.Exec(ctx) @@ -111,5 +82,7 @@ func (ddo *DecisionDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (ddo *DecisionDeleteOne) ExecX(ctx context.Context) { - ddo.dd.ExecX(ctx) + if err := ddo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/decision_query.go b/pkg/database/ent/decision_query.go index 91aebded968..b050a4d9649 100644 --- a/pkg/database/ent/decision_query.go +++ b/pkg/database/ent/decision_query.go @@ -18,11 +18,9 @@ import ( // DecisionQuery is the builder for querying Decision entities. type DecisionQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []decision.OrderOption + inters []Interceptor predicates []predicate.Decision withOwner *AlertQuery // intermediate query (i.e. traversal path). @@ -36,34 +34,34 @@ func (dq *DecisionQuery) Where(ps ...predicate.Decision) *DecisionQuery { return dq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (dq *DecisionQuery) Limit(limit int) *DecisionQuery { - dq.limit = &limit + dq.ctx.Limit = &limit return dq } -// Offset adds an offset step to the query. +// Offset to start from. func (dq *DecisionQuery) Offset(offset int) *DecisionQuery { - dq.offset = &offset + dq.ctx.Offset = &offset return dq } // Unique configures the query builder to filter duplicate records on query. 
// By default, unique is set to true, and can be disabled using this method. func (dq *DecisionQuery) Unique(unique bool) *DecisionQuery { - dq.unique = &unique + dq.ctx.Unique = &unique return dq } -// Order adds an order step to the query. -func (dq *DecisionQuery) Order(o ...OrderFunc) *DecisionQuery { +// Order specifies how the records should be ordered. +func (dq *DecisionQuery) Order(o ...decision.OrderOption) *DecisionQuery { dq.order = append(dq.order, o...) return dq } // QueryOwner chains the current query on the "owner" edge. func (dq *DecisionQuery) QueryOwner() *AlertQuery { - query := &AlertQuery{config: dq.config} + query := (&AlertClient{config: dq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := dq.prepareQuery(ctx); err != nil { return nil, err @@ -86,7 +84,7 @@ func (dq *DecisionQuery) QueryOwner() *AlertQuery { // First returns the first Decision entity from the query. // Returns a *NotFoundError when no Decision was found. func (dq *DecisionQuery) First(ctx context.Context) (*Decision, error) { - nodes, err := dq.Limit(1).All(ctx) + nodes, err := dq.Limit(1).All(setContextOp(ctx, dq.ctx, "First")) if err != nil { return nil, err } @@ -109,7 +107,7 @@ func (dq *DecisionQuery) FirstX(ctx context.Context) *Decision { // Returns a *NotFoundError when no Decision ID was found. func (dq *DecisionQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = dq.Limit(1).IDs(ctx); err != nil { + if ids, err = dq.Limit(1).IDs(setContextOp(ctx, dq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -132,7 +130,7 @@ func (dq *DecisionQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Decision entity is found. // Returns a *NotFoundError when no Decision entities are found. func (dq *DecisionQuery) Only(ctx context.Context) (*Decision, error) { - nodes, err := dq.Limit(2).All(ctx) + nodes, err := dq.Limit(2).All(setContextOp(ctx, dq.ctx, "Only")) if err != nil { return nil, err } @@ -160,7 +158,7 @@ func (dq *DecisionQuery) OnlyX(ctx context.Context) *Decision { // Returns a *NotFoundError when no entities are found. func (dq *DecisionQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = dq.Limit(2).IDs(ctx); err != nil { + if ids, err = dq.Limit(2).IDs(setContextOp(ctx, dq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -185,10 +183,12 @@ func (dq *DecisionQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Decisions. func (dq *DecisionQuery) All(ctx context.Context) ([]*Decision, error) { + ctx = setContextOp(ctx, dq.ctx, "All") if err := dq.prepareQuery(ctx); err != nil { return nil, err } - return dq.sqlAll(ctx) + qr := querierAll[[]*Decision, *DecisionQuery]() + return withInterceptors[[]*Decision](ctx, dq, qr, dq.inters) } // AllX is like All, but panics if an error occurs. @@ -201,9 +201,12 @@ func (dq *DecisionQuery) AllX(ctx context.Context) []*Decision { } // IDs executes the query and returns a list of Decision IDs. 
-func (dq *DecisionQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := dq.Select(decision.FieldID).Scan(ctx, &ids); err != nil { +func (dq *DecisionQuery) IDs(ctx context.Context) (ids []int, err error) { + if dq.ctx.Unique == nil && dq.path != nil { + dq.Unique(true) + } + ctx = setContextOp(ctx, dq.ctx, "IDs") + if err = dq.Select(decision.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -220,10 +223,11 @@ func (dq *DecisionQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (dq *DecisionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, dq.ctx, "Count") if err := dq.prepareQuery(ctx); err != nil { return 0, err } - return dq.sqlCount(ctx) + return withInterceptors[int](ctx, dq, querierCount[*DecisionQuery](), dq.inters) } // CountX is like Count, but panics if an error occurs. @@ -237,10 +241,15 @@ func (dq *DecisionQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (dq *DecisionQuery) Exist(ctx context.Context) (bool, error) { - if err := dq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, dq.ctx, "Exist") + switch _, err := dq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return dq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -260,22 +269,21 @@ func (dq *DecisionQuery) Clone() *DecisionQuery { } return &DecisionQuery{ config: dq.config, - limit: dq.limit, - offset: dq.offset, - order: append([]OrderFunc{}, dq.order...), + ctx: dq.ctx.Clone(), + order: append([]decision.OrderOption{}, dq.order...), + inters: append([]Interceptor{}, dq.inters...), predicates: append([]predicate.Decision{}, dq.predicates...), withOwner: dq.withOwner.Clone(), // clone intermediate query. - sql: dq.sql.Clone(), - path: dq.path, - unique: dq.unique, + sql: dq.sql.Clone(), + path: dq.path, } } // WithOwner tells the query-builder to eager-load the nodes that are connected to // the "owner" edge. The optional arguments are used to configure the query builder of the edge. func (dq *DecisionQuery) WithOwner(opts ...func(*AlertQuery)) *DecisionQuery { - query := &AlertQuery{config: dq.config} + query := (&AlertClient{config: dq.config}).Query() for _, opt := range opts { opt(query) } @@ -298,16 +306,11 @@ func (dq *DecisionQuery) WithOwner(opts ...func(*AlertQuery)) *DecisionQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (dq *DecisionQuery) GroupBy(field string, fields ...string) *DecisionGroupBy { - grbuild := &DecisionGroupBy{config: dq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := dq.prepareQuery(ctx); err != nil { - return nil, err - } - return dq.sqlQuery(ctx), nil - } + dq.ctx.Fields = append([]string{field}, fields...) + grbuild := &DecisionGroupBy{build: dq} + grbuild.flds = &dq.ctx.Fields grbuild.label = decision.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -324,15 +327,30 @@ func (dq *DecisionQuery) GroupBy(field string, fields ...string) *DecisionGroupB // Select(decision.FieldCreatedAt). // Scan(ctx, &v) func (dq *DecisionQuery) Select(fields ...string) *DecisionSelect { - dq.fields = append(dq.fields, fields...) 
- selbuild := &DecisionSelect{DecisionQuery: dq} - selbuild.label = decision.Label - selbuild.flds, selbuild.scan = &dq.fields, selbuild.Scan - return selbuild + dq.ctx.Fields = append(dq.ctx.Fields, fields...) + sbuild := &DecisionSelect{DecisionQuery: dq} + sbuild.label = decision.Label + sbuild.flds, sbuild.scan = &dq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a DecisionSelect configured with the given aggregations. +func (dq *DecisionQuery) Aggregate(fns ...AggregateFunc) *DecisionSelect { + return dq.Select().Aggregate(fns...) } func (dq *DecisionQuery) prepareQuery(ctx context.Context) error { - for _, f := range dq.fields { + for _, inter := range dq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, dq); err != nil { + return err + } + } + } + for _, f := range dq.ctx.Fields { if !decision.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -392,6 +410,9 @@ func (dq *DecisionQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(alert.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -411,41 +432,22 @@ func (dq *DecisionQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes func (dq *DecisionQuery) sqlCount(ctx context.Context) (int, error) { _spec := dq.querySpec() - _spec.Node.Columns = dq.fields - if len(dq.fields) > 0 { - _spec.Unique = dq.unique != nil && *dq.unique + _spec.Node.Columns = dq.ctx.Fields + if len(dq.ctx.Fields) > 0 { + _spec.Unique = dq.ctx.Unique != nil && *dq.ctx.Unique } return sqlgraph.CountNodes(ctx, dq.driver, _spec) } -func (dq *DecisionQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := dq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: decision.Table, - Columns: decision.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, - }, - From: dq.sql, - Unique: true, - } - if unique := dq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(decision.Table, decision.Columns, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt)) + _spec.From = dq.sql + if unique := dq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if dq.path != nil { + _spec.Unique = true } - if fields := dq.fields; len(fields) > 0 { + if fields := dq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, decision.FieldID) for i := range fields { @@ -453,6 +455,9 @@ func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec { _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) } } + if dq.withOwner != nil { + _spec.Node.AddColumnOnce(decision.FieldAlertDecisions) + } } if ps := dq.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { @@ -461,10 +466,10 @@ func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := dq.limit; limit != nil { + if limit := dq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := dq.offset; offset != nil { + if offset := dq.ctx.Offset; offset != nil { _spec.Offset = 
*offset } if ps := dq.order; len(ps) > 0 { @@ -480,7 +485,7 @@ func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec { func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(dq.driver.Dialect()) t1 := builder.Table(decision.Table) - columns := dq.fields + columns := dq.ctx.Fields if len(columns) == 0 { columns = decision.Columns } @@ -489,7 +494,7 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = dq.sql selector.Select(selector.Columns(columns...)...) } - if dq.unique != nil && *dq.unique { + if dq.ctx.Unique != nil && *dq.ctx.Unique { selector.Distinct() } for _, p := range dq.predicates { @@ -498,12 +503,12 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range dq.order { p(selector) } - if offset := dq.offset; offset != nil { + if offset := dq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := dq.limit; limit != nil { + if limit := dq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -511,13 +516,8 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { // DecisionGroupBy is the group-by builder for Decision entities. type DecisionGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *DecisionQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -526,74 +526,77 @@ func (dgb *DecisionGroupBy) Aggregate(fns ...AggregateFunc) *DecisionGroupBy { return dgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (dgb *DecisionGroupBy) Scan(ctx context.Context, v any) error { - query, err := dgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, dgb.build.ctx, "GroupBy") + if err := dgb.build.prepareQuery(ctx); err != nil { return err } - dgb.sql = query - return dgb.sqlScan(ctx, v) + return scanWithInterceptors[*DecisionQuery, *DecisionGroupBy](ctx, dgb.build, dgb, dgb.build.inters, v) } -func (dgb *DecisionGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range dgb.fields { - if !decision.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (dgb *DecisionGroupBy) sqlScan(ctx context.Context, root *DecisionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(dgb.fns)) + for _, fn := range dgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*dgb.flds)+len(dgb.fns)) + for _, f := range *dgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := dgb.sqlQuery() + selector.GroupBy(selector.Columns(*dgb.flds...)...) 
if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := dgb.driver.Query(ctx, query, args, rows); err != nil { + if err := dgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (dgb *DecisionGroupBy) sqlQuery() *sql.Selector { - selector := dgb.sql.Select() - aggregation := make([]string, 0, len(dgb.fns)) - for _, fn := range dgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(dgb.fields)+len(dgb.fns)) - for _, f := range dgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(dgb.fields...)...) -} - // DecisionSelect is the builder for selecting fields of Decision entities. type DecisionSelect struct { *DecisionQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ds *DecisionSelect) Aggregate(fns ...AggregateFunc) *DecisionSelect { + ds.fns = append(ds.fns, fns...) + return ds } // Scan applies the selector query and scans the result into the given value. func (ds *DecisionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ds.ctx, "Select") if err := ds.prepareQuery(ctx); err != nil { return err } - ds.sql = ds.DecisionQuery.sqlQuery(ctx) - return ds.sqlScan(ctx, v) + return scanWithInterceptors[*DecisionQuery, *DecisionSelect](ctx, ds.DecisionQuery, ds, ds.inters, v) } -func (ds *DecisionSelect) sqlScan(ctx context.Context, v any) error { +func (ds *DecisionSelect) sqlScan(ctx context.Context, root *DecisionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ds.fns)) + for _, fn := range ds.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ds.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ds.sql.Query() + query, args := selector.Query() if err := ds.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/decision_update.go b/pkg/database/ent/decision_update.go index 64b40871eca..1b62cc54c30 100644 --- a/pkg/database/ent/decision_update.go +++ b/pkg/database/ent/decision_update.go @@ -324,35 +324,8 @@ func (du *DecisionUpdate) ClearOwner() *DecisionUpdate { // Save executes the query and returns the number of nodes affected by the update operation. 
func (du *DecisionUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) du.defaults() - if len(du.hooks) == 0 { - affected, err = du.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DecisionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - du.mutation = mutation - affected, err = du.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(du.hooks) - 1; i >= 0; i-- { - if du.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = du.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, du.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, du.sqlSave, du.mutation, du.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -390,16 +363,7 @@ func (du *DecisionUpdate) defaults() { } func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: decision.Table, - Columns: decision.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(decision.Table, decision.Columns, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt)) if ps := du.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -408,198 +372,91 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := du.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldCreatedAt, - }) + _spec.SetField(decision.FieldCreatedAt, field.TypeTime, value) } if du.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldCreatedAt, - }) + _spec.ClearField(decision.FieldCreatedAt, field.TypeTime) } if value, ok := du.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUpdatedAt, - }) + _spec.SetField(decision.FieldUpdatedAt, field.TypeTime, value) } if du.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldUpdatedAt, - }) + _spec.ClearField(decision.FieldUpdatedAt, field.TypeTime) } if value, ok := du.mutation.Until(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUntil, - }) + _spec.SetField(decision.FieldUntil, field.TypeTime, value) } if du.mutation.UntilCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldUntil, - }) + _spec.ClearField(decision.FieldUntil, field.TypeTime) } if value, ok := du.mutation.Scenario(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScenario, - }) + _spec.SetField(decision.FieldScenario, field.TypeString, value) } if value, ok := du.mutation.GetType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldType, - }) + _spec.SetField(decision.FieldType, field.TypeString, 
value) } if value, ok := du.mutation.StartIP(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartIP, - }) + _spec.SetField(decision.FieldStartIP, field.TypeInt64, value) } if value, ok := du.mutation.AddedStartIP(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartIP, - }) + _spec.AddField(decision.FieldStartIP, field.TypeInt64, value) } if du.mutation.StartIPCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldStartIP, - }) + _spec.ClearField(decision.FieldStartIP, field.TypeInt64) } if value, ok := du.mutation.EndIP(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndIP, - }) + _spec.SetField(decision.FieldEndIP, field.TypeInt64, value) } if value, ok := du.mutation.AddedEndIP(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndIP, - }) + _spec.AddField(decision.FieldEndIP, field.TypeInt64, value) } if du.mutation.EndIPCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldEndIP, - }) + _spec.ClearField(decision.FieldEndIP, field.TypeInt64) } if value, ok := du.mutation.StartSuffix(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartSuffix, - }) + _spec.SetField(decision.FieldStartSuffix, field.TypeInt64, value) } if value, ok := du.mutation.AddedStartSuffix(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartSuffix, - }) + _spec.AddField(decision.FieldStartSuffix, field.TypeInt64, value) } if du.mutation.StartSuffixCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldStartSuffix, - }) + _spec.ClearField(decision.FieldStartSuffix, field.TypeInt64) } if value, ok := du.mutation.EndSuffix(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndSuffix, - }) + _spec.SetField(decision.FieldEndSuffix, field.TypeInt64, value) } if value, ok := du.mutation.AddedEndSuffix(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndSuffix, - }) + _spec.AddField(decision.FieldEndSuffix, field.TypeInt64, value) } if du.mutation.EndSuffixCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldEndSuffix, - }) + _spec.ClearField(decision.FieldEndSuffix, field.TypeInt64) } if value, ok := du.mutation.IPSize(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldIPSize, - }) + _spec.SetField(decision.FieldIPSize, field.TypeInt64, value) } if value, ok := du.mutation.AddedIPSize(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldIPSize, - }) + _spec.AddField(decision.FieldIPSize, field.TypeInt64, value) } if 
du.mutation.IPSizeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldIPSize, - }) + _spec.ClearField(decision.FieldIPSize, field.TypeInt64) } if value, ok := du.mutation.Scope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScope, - }) + _spec.SetField(decision.FieldScope, field.TypeString, value) } if value, ok := du.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldValue, - }) + _spec.SetField(decision.FieldValue, field.TypeString, value) } if value, ok := du.mutation.Origin(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldOrigin, - }) + _spec.SetField(decision.FieldOrigin, field.TypeString, value) } if value, ok := du.mutation.Simulated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: decision.FieldSimulated, - }) + _spec.SetField(decision.FieldSimulated, field.TypeBool, value) } if value, ok := du.mutation.UUID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldUUID, - }) + _spec.SetField(decision.FieldUUID, field.TypeString, value) } if du.mutation.UUIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: decision.FieldUUID, - }) + _spec.ClearField(decision.FieldUUID, field.TypeString) } if du.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -609,10 +466,7 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{decision.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -625,10 +479,7 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{decision.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -644,6 +495,7 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + du.mutation.done = true return n, nil } @@ -948,6 +800,12 @@ func (duo *DecisionUpdateOne) ClearOwner() *DecisionUpdateOne { return duo } +// Where appends a list predicates to the DecisionUpdate builder. +func (duo *DecisionUpdateOne) Where(ps ...predicate.Decision) *DecisionUpdateOne { + duo.mutation.Where(ps...) + return duo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (duo *DecisionUpdateOne) Select(field string, fields ...string) *DecisionUpdateOne { @@ -957,41 +815,8 @@ func (duo *DecisionUpdateOne) Select(field string, fields ...string) *DecisionUp // Save executes the query and returns the updated Decision entity. 
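The hunks around this point (continuing just below) collapse the hand-rolled hook loops of DecisionUpdate.Save and DecisionUpdateOne.Save into the shared withHooks helper and add a Where method to DecisionUpdateOne. A minimal usage sketch, assuming a *ent.Client named client, the generated decision predicate package, and the UpdateOneID/SetUntil helpers from the unmodified parts of the generated code:

    // Sketch only: conditional single-row update through the new DecisionUpdateOne.Where.
    // imports: context, time, "github.com/crowdsecurity/crowdsec/pkg/database/ent",
    //          "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
    func extendDecision(ctx context.Context, client *ent.Client, id int, until time.Time) (*ent.Decision, error) {
        // If the extra predicate filters the row out, Save reports a *NotFoundError.
        return client.Decision.UpdateOneID(id).
            Where(decision.SimulatedEQ(false)).
            SetUntil(until).
            Save(ctx)
    }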
func (duo *DecisionUpdateOne) Save(ctx context.Context) (*Decision, error) { - var ( - err error - node *Decision - ) duo.defaults() - if len(duo.hooks) == 0 { - node, err = duo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DecisionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - duo.mutation = mutation - node, err = duo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(duo.hooks) - 1; i >= 0; i-- { - if duo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = duo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, duo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Decision) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from DecisionMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, duo.sqlSave, duo.mutation, duo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -1029,16 +854,7 @@ func (duo *DecisionUpdateOne) defaults() { } func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: decision.Table, - Columns: decision.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(decision.Table, decision.Columns, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt)) id, ok := duo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Decision.id" for update`)} @@ -1064,198 +880,91 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err } } if value, ok := duo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldCreatedAt, - }) + _spec.SetField(decision.FieldCreatedAt, field.TypeTime, value) } if duo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldCreatedAt, - }) + _spec.ClearField(decision.FieldCreatedAt, field.TypeTime) } if value, ok := duo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUpdatedAt, - }) + _spec.SetField(decision.FieldUpdatedAt, field.TypeTime, value) } if duo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldUpdatedAt, - }) + _spec.ClearField(decision.FieldUpdatedAt, field.TypeTime) } if value, ok := duo.mutation.Until(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUntil, - }) + _spec.SetField(decision.FieldUntil, field.TypeTime, value) } if duo.mutation.UntilCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldUntil, - }) + _spec.ClearField(decision.FieldUntil, field.TypeTime) } if value, ok := duo.mutation.Scenario(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScenario, - }) + _spec.SetField(decision.FieldScenario, field.TypeString, value) } if value, ok := duo.mutation.GetType(); 
ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldType, - }) + _spec.SetField(decision.FieldType, field.TypeString, value) } if value, ok := duo.mutation.StartIP(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartIP, - }) + _spec.SetField(decision.FieldStartIP, field.TypeInt64, value) } if value, ok := duo.mutation.AddedStartIP(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartIP, - }) + _spec.AddField(decision.FieldStartIP, field.TypeInt64, value) } if duo.mutation.StartIPCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldStartIP, - }) + _spec.ClearField(decision.FieldStartIP, field.TypeInt64) } if value, ok := duo.mutation.EndIP(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndIP, - }) + _spec.SetField(decision.FieldEndIP, field.TypeInt64, value) } if value, ok := duo.mutation.AddedEndIP(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndIP, - }) + _spec.AddField(decision.FieldEndIP, field.TypeInt64, value) } if duo.mutation.EndIPCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldEndIP, - }) + _spec.ClearField(decision.FieldEndIP, field.TypeInt64) } if value, ok := duo.mutation.StartSuffix(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartSuffix, - }) + _spec.SetField(decision.FieldStartSuffix, field.TypeInt64, value) } if value, ok := duo.mutation.AddedStartSuffix(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartSuffix, - }) + _spec.AddField(decision.FieldStartSuffix, field.TypeInt64, value) } if duo.mutation.StartSuffixCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldStartSuffix, - }) + _spec.ClearField(decision.FieldStartSuffix, field.TypeInt64) } if value, ok := duo.mutation.EndSuffix(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndSuffix, - }) + _spec.SetField(decision.FieldEndSuffix, field.TypeInt64, value) } if value, ok := duo.mutation.AddedEndSuffix(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndSuffix, - }) + _spec.AddField(decision.FieldEndSuffix, field.TypeInt64, value) } if duo.mutation.EndSuffixCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldEndSuffix, - }) + _spec.ClearField(decision.FieldEndSuffix, field.TypeInt64) } if value, ok := duo.mutation.IPSize(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldIPSize, - }) + _spec.SetField(decision.FieldIPSize, field.TypeInt64, value) } if value, ok := duo.mutation.AddedIPSize(); ok { - 
_spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldIPSize, - }) + _spec.AddField(decision.FieldIPSize, field.TypeInt64, value) } if duo.mutation.IPSizeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldIPSize, - }) + _spec.ClearField(decision.FieldIPSize, field.TypeInt64) } if value, ok := duo.mutation.Scope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScope, - }) + _spec.SetField(decision.FieldScope, field.TypeString, value) } if value, ok := duo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldValue, - }) + _spec.SetField(decision.FieldValue, field.TypeString, value) } if value, ok := duo.mutation.Origin(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldOrigin, - }) + _spec.SetField(decision.FieldOrigin, field.TypeString, value) } if value, ok := duo.mutation.Simulated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: decision.FieldSimulated, - }) + _spec.SetField(decision.FieldSimulated, field.TypeBool, value) } if value, ok := duo.mutation.UUID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldUUID, - }) + _spec.SetField(decision.FieldUUID, field.TypeString, value) } if duo.mutation.UUIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: decision.FieldUUID, - }) + _spec.ClearField(decision.FieldUUID, field.TypeString) } if duo.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -1265,10 +974,7 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err Columns: []string{decision.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1281,10 +987,7 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err Columns: []string{decision.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1303,5 +1006,6 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err } return nil, err } + duo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/ent.go b/pkg/database/ent/ent.go index 0455af444d2..393ce9f1869 100644 --- a/pkg/database/ent/ent.go +++ b/pkg/database/ent/ent.go @@ -6,6 +6,8 @@ import ( "context" "errors" "fmt" + "reflect" + "sync" "entgo.io/ent" "entgo.io/ent/dialect/sql" @@ -21,50 +23,79 @@ import ( // ent aliases to avoid import conflicts in user's code. 
type ( - Op = ent.Op - Hook = ent.Hook - Value = ent.Value - Query = ent.Query - Policy = ent.Policy - Mutator = ent.Mutator - Mutation = ent.Mutation - MutateFunc = ent.MutateFunc + Op = ent.Op + Hook = ent.Hook + Value = ent.Value + Query = ent.Query + QueryContext = ent.QueryContext + Querier = ent.Querier + QuerierFunc = ent.QuerierFunc + Interceptor = ent.Interceptor + InterceptFunc = ent.InterceptFunc + Traverser = ent.Traverser + TraverseFunc = ent.TraverseFunc + Policy = ent.Policy + Mutator = ent.Mutator + Mutation = ent.Mutation + MutateFunc = ent.MutateFunc ) +type clientCtxKey struct{} + +// FromContext returns a Client stored inside a context, or nil if there isn't one. +func FromContext(ctx context.Context) *Client { + c, _ := ctx.Value(clientCtxKey{}).(*Client) + return c +} + +// NewContext returns a new context with the given Client attached. +func NewContext(parent context.Context, c *Client) context.Context { + return context.WithValue(parent, clientCtxKey{}, c) +} + +type txCtxKey struct{} + +// TxFromContext returns a Tx stored inside a context, or nil if there isn't one. +func TxFromContext(ctx context.Context) *Tx { + tx, _ := ctx.Value(txCtxKey{}).(*Tx) + return tx +} + +// NewTxContext returns a new context with the given Tx attached. +func NewTxContext(parent context.Context, tx *Tx) context.Context { + return context.WithValue(parent, txCtxKey{}, tx) +} + // OrderFunc applies an ordering on the sql selector. +// Deprecated: Use Asc/Desc functions or the package builders instead. type OrderFunc func(*sql.Selector) -// columnChecker returns a function indicates if the column exists in the given column. -func columnChecker(table string) func(string) error { - checks := map[string]func(string) bool{ - alert.Table: alert.ValidColumn, - bouncer.Table: bouncer.ValidColumn, - configitem.Table: configitem.ValidColumn, - decision.Table: decision.ValidColumn, - event.Table: event.ValidColumn, - machine.Table: machine.ValidColumn, - meta.Table: meta.ValidColumn, - } - check, ok := checks[table] - if !ok { - return func(string) error { - return fmt.Errorf("unknown table %q", table) - } - } - return func(column string) error { - if !check(column) { - return fmt.Errorf("unknown column %q for table %q", column, table) - } - return nil - } +var ( + initCheck sync.Once + columnCheck sql.ColumnCheck +) + +// columnChecker checks if the column exists in the given table. +func checkColumn(table, column string) error { + initCheck.Do(func() { + columnCheck = sql.NewColumnCheck(map[string]func(string) bool{ + alert.Table: alert.ValidColumn, + bouncer.Table: bouncer.ValidColumn, + configitem.Table: configitem.ValidColumn, + decision.Table: decision.ValidColumn, + event.Table: event.ValidColumn, + machine.Table: machine.ValidColumn, + meta.Table: meta.ValidColumn, + }) + }) + return columnCheck(table, column) } // Asc applies the given fields in ASC order. -func Asc(fields ...string) OrderFunc { +func Asc(fields ...string) func(*sql.Selector) { return func(s *sql.Selector) { - check := columnChecker(s.TableName()) for _, f := range fields { - if err := check(f); err != nil { + if err := checkColumn(s.TableName(), f); err != nil { s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) } s.OrderBy(sql.Asc(s.C(f))) @@ -73,11 +104,10 @@ func Asc(fields ...string) OrderFunc { } // Desc applies the given fields in DESC order. 
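The new NewContext/FromContext and NewTxContext/TxFromContext helpers above make it possible to carry the client, or an open transaction, inside a context.Context instead of threading it through every call. A small sketch, with withClient and doWork as purely illustrative names:

    // Sketch only: attach the ent client to a context and recover it later.
    // imports: context, errors, "github.com/crowdsecurity/crowdsec/pkg/database/ent"
    func withClient(ctx context.Context, client *ent.Client) context.Context {
        return ent.NewContext(ctx, client)
    }

    func doWork(ctx context.Context) error {
        client := ent.FromContext(ctx)
        if client == nil {
            return errors.New("no ent client attached to context")
        }
        _, err := client.Event.Query().Count(ctx)
        return err
    }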
-func Desc(fields ...string) OrderFunc { +func Desc(fields ...string) func(*sql.Selector) { return func(s *sql.Selector) { - check := columnChecker(s.TableName()) for _, f := range fields { - if err := check(f); err != nil { + if err := checkColumn(s.TableName(), f); err != nil { s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) } s.OrderBy(sql.Desc(s.C(f))) @@ -109,8 +139,7 @@ func Count() AggregateFunc { // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { return func(s *sql.Selector) string { - check := columnChecker(s.TableName()) - if err := check(field); err != nil { + if err := checkColumn(s.TableName(), field); err != nil { s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) return "" } @@ -121,8 +150,7 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { return func(s *sql.Selector) string { - check := columnChecker(s.TableName()) - if err := check(field); err != nil { + if err := checkColumn(s.TableName(), field); err != nil { s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) return "" } @@ -133,8 +161,7 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { return func(s *sql.Selector) string { - check := columnChecker(s.TableName()) - if err := check(field); err != nil { + if err := checkColumn(s.TableName(), field); err != nil { s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) return "" } @@ -145,8 +172,7 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { return func(s *sql.Selector) string { - check := columnChecker(s.TableName()) - if err := check(field); err != nil { + if err := checkColumn(s.TableName(), field); err != nil { s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) return "" } @@ -275,6 +301,7 @@ func IsConstraintError(err error) bool { type selector struct { label string flds *[]string + fns []AggregateFunc scan func(context.Context, any) error } @@ -473,5 +500,121 @@ func (s *selector) BoolX(ctx context.Context) bool { return v } +// withHooks invokes the builder operation with the given hooks, if any. +func withHooks[V Value, M any, PM interface { + *M + Mutation +}](ctx context.Context, exec func(context.Context) (V, error), mutation PM, hooks []Hook) (value V, err error) { + if len(hooks) == 0 { + return exec(ctx) + } + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutationT, ok := any(m).(PM) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + // Set the mutation to the builder. + *mutation = *mutationT + return exec(ctx) + }) + for i := len(hooks) - 1; i >= 0; i-- { + if hooks[i] == nil { + return value, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") + } + mut = hooks[i](mut) + } + v, err := mut.Mutate(ctx, mutation) + if err != nil { + return value, err + } + nv, ok := v.(V) + if !ok { + return value, fmt.Errorf("unexpected node type %T returned from %T", v, mutation) + } + return nv, nil +} + +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. 
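The generic withHooks helper defined just above is what every Save/Exec builder now delegates to, replacing the per-builder mutator loops deleted earlier in this patch. A sketch of a hook written against the ent.Hook and ent.MutateFunc aliases; the client.Decision.Use registration is assumed from the generated client code, which this hunk does not touch:

    // Sketch only: a logging hook; withHooks wraps each builder's sqlSave in this chain.
    // imports: context, log, "github.com/crowdsecurity/crowdsec/pkg/database/ent"
    func loggingHook() ent.Hook {
        return func(next ent.Mutator) ent.Mutator {
            return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
                log.Printf("ent mutation: op=%v type=%s", m.Op(), m.Type())
                return next.Mutate(ctx, m)
            })
        }
    }

    // Registration (assumed, not part of this hunk): client.Decision.Use(loggingHook())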
+func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { + if ent.QueryFromContext(ctx) == nil { + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) + } + return ctx +} + +func querierAll[V Value, Q interface { + sqlAll(context.Context, ...queryHook) (V, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlAll(ctx) + }) +} + +func querierCount[Q interface { + sqlCount(context.Context) (int, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlCount(ctx) + }) +} + +func withInterceptors[V Value](ctx context.Context, q Query, qr Querier, inters []Interceptor) (v V, err error) { + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + rv, err := qr.Query(ctx, q) + if err != nil { + return v, err + } + vt, ok := rv.(V) + if !ok { + return v, fmt.Errorf("unexpected type %T returned from %T. expected type: %T", vt, q, v) + } + return vt, nil +} + +func scanWithInterceptors[Q1 ent.Query, Q2 interface { + sqlScan(context.Context, Q1, any) error +}](ctx context.Context, rootQuery Q1, selectOrGroup Q2, inters []Interceptor, v any) error { + rv := reflect.ValueOf(v) + var qr Querier = QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q1) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + if err := selectOrGroup.sqlScan(ctx, query, v); err != nil { + return nil, err + } + if k := rv.Kind(); k == reflect.Pointer && rv.Elem().CanInterface() { + return rv.Elem().Interface(), nil + } + return v, nil + }) + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + vv, err := qr.Query(ctx, rootQuery) + if err != nil { + return err + } + switch rv2 := reflect.ValueOf(vv); { + case rv.IsNil(), rv2.IsNil(), rv.Kind() != reflect.Pointer: + case rv.Type() == rv2.Type(): + rv.Elem().Set(rv2.Elem()) + case rv.Elem().Type() == rv2.Type(): + rv.Elem().Set(rv2) + } + return nil +} + // queryHook describes an internal hook for the different sqlAll methods. type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/pkg/database/ent/event.go b/pkg/database/ent/event.go index 4754107fddc..df4a2d10c8b 100644 --- a/pkg/database/ent/event.go +++ b/pkg/database/ent/event.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" @@ -29,7 +30,8 @@ type Event struct { AlertEvents int `json:"alert_events,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the EventQuery when eager-loading is set. - Edges EventEdges `json:"edges"` + Edges EventEdges `json:"edges"` + selectValues sql.SelectValues } // EventEdges holds the relations/edges for other nodes in the graph. 
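setContextOp, querierAll/querierCount and withInterceptors above are the plumbing behind the new Interceptor/Traverser aliases; the query builders call into them from prepareQuery, as the event_query.go hunk later in this patch shows. A sketch of a traverse-style interceptor, assuming ent.TraverseFunc satisfies both Traverser and Interceptor as in upstream ent, and that registration goes through the generated Intercept methods (not shown here):

    // Sketch only: narrow every EventQuery to the last 24 hours.
    // imports: context, time, "github.com/crowdsecurity/crowdsec/pkg/database/ent",
    //          "github.com/crowdsecurity/crowdsec/pkg/database/ent/event"
    func recentEventsOnly() ent.Interceptor {
        return ent.TraverseFunc(func(ctx context.Context, q ent.Query) error {
            if eq, ok := q.(*ent.EventQuery); ok {
                eq.Where(event.CreatedAtGT(time.Now().Add(-24 * time.Hour)))
            }
            return nil
        })
    }

    // Registration (assumed, not shown in this hunk): client.Event.Intercept(recentEventsOnly())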
@@ -66,7 +68,7 @@ func (*Event) scanValues(columns []string) ([]any, error) { case event.FieldCreatedAt, event.FieldUpdatedAt, event.FieldTime: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type Event", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -118,21 +120,29 @@ func (e *Event) assignValues(columns []string, values []any) error { } else if value.Valid { e.AlertEvents = int(value.Int64) } + default: + e.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Event. +// This includes values selected through modifiers, order, etc. +func (e *Event) Value(name string) (ent.Value, error) { + return e.selectValues.Get(name) +} + // QueryOwner queries the "owner" edge of the Event entity. func (e *Event) QueryOwner() *AlertQuery { - return (&EventClient{config: e.config}).QueryOwner(e) + return NewEventClient(e.config).QueryOwner(e) } // Update returns a builder for updating this Event. // Note that you need to call Event.Unwrap() before calling this method if this Event // was returned from a transaction, and the transaction was committed or rolled back. func (e *Event) Update() *EventUpdateOne { - return (&EventClient{config: e.config}).UpdateOne(e) + return NewEventClient(e.config).UpdateOne(e) } // Unwrap unwraps the Event entity that was returned from a transaction after it was closed, @@ -175,9 +185,3 @@ func (e *Event) String() string { // Events is a parsable slice of Event. type Events []*Event - -func (e Events) config(cfg config) { - for _i := range e { - e[_i].config = cfg - } -} diff --git a/pkg/database/ent/event/event.go b/pkg/database/ent/event/event.go index 33b9b67f8b9..48f5a355824 100644 --- a/pkg/database/ent/event/event.go +++ b/pkg/database/ent/event/event.go @@ -4,6 +4,9 @@ package event import ( "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -66,3 +69,50 @@ var ( // SerializedValidator is a validator for the "serialized" field. It is called by the builders before save. SerializedValidator func(string) error ) + +// OrderOption defines the ordering options for the Event queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByTime orders the results by the time field. +func ByTime(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTime, opts...).ToFunc() +} + +// BySerialized orders the results by the serialized field. +func BySerialized(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSerialized, opts...).ToFunc() +} + +// ByAlertEvents orders the results by the alert_events field. +func ByAlertEvents(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAlertEvents, opts...).ToFunc() +} + +// ByOwnerField orders the results by owner field. 
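The OrderOption helpers added above (ByCreatedAt, ByTime, ByAlertEvents, and the edge-aware ByOwnerField completed in the next lines) replace raw OrderFunc ordering, which the ent.go hunk marks as deprecated. A usage sketch, assuming a client plus the generated ent, event and alert packages:

    // Sketch only: typed ordering with the new OrderOption helpers.
    func recentOrdered(ctx context.Context, client *ent.Client) ([]*ent.Event, error) {
        return client.Event.Query().
            Order(
                event.ByTime(),                    // column on the event itself
                event.ByOwnerField(alert.FieldID), // column resolved through the owner edge
            ).
            Limit(50).
            All(ctx)
    }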
+func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...)) + } +} +func newOwnerStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(OwnerInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) +} diff --git a/pkg/database/ent/event/where.go b/pkg/database/ent/event/where.go index 7554e59e678..238bea988bd 100644 --- a/pkg/database/ent/event/where.go +++ b/pkg/database/ent/event/where.go @@ -12,477 +12,307 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Event(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldEQ(FieldUpdatedAt, v)) } // Time applies equality check predicate on the "time" field. 
It's identical to TimeEQ. func Time(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldEQ(FieldTime, v)) } // Serialized applies equality check predicate on the "serialized" field. It's identical to SerializedEQ. func Serialized(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldEQ(FieldSerialized, v)) } // AlertEvents applies equality check predicate on the "alert_events" field. It's identical to AlertEventsEQ. func AlertEvents(v int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertEvents), v)) - }) + return predicate.Event(sql.FieldEQ(FieldAlertEvents, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Event(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldLTE(FieldCreatedAt, v)) } // CreatedAtIsNil applies the IsNil predicate on the "created_at" field. func CreatedAtIsNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) + return predicate.Event(sql.FieldIsNull(FieldCreatedAt)) } // CreatedAtNotNil applies the NotNil predicate on the "created_at" field. 
func CreatedAtNotNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Event(sql.FieldNotNull(FieldCreatedAt)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Event(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldLTE(FieldUpdatedAt, v)) } // UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. func UpdatedAtIsNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) + return predicate.Event(sql.FieldIsNull(FieldUpdatedAt)) } // UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. func UpdatedAtNotNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Event(sql.FieldNotNull(FieldUpdatedAt)) } // TimeEQ applies the EQ predicate on the "time" field. func TimeEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldEQ(FieldTime, v)) } // TimeNEQ applies the NEQ predicate on the "time" field. 
func TimeNEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldNEQ(FieldTime, v)) } // TimeIn applies the In predicate on the "time" field. func TimeIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldTime), v...)) - }) + return predicate.Event(sql.FieldIn(FieldTime, vs...)) } // TimeNotIn applies the NotIn predicate on the "time" field. func TimeNotIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldTime), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldTime, vs...)) } // TimeGT applies the GT predicate on the "time" field. func TimeGT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldGT(FieldTime, v)) } // TimeGTE applies the GTE predicate on the "time" field. func TimeGTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldGTE(FieldTime, v)) } // TimeLT applies the LT predicate on the "time" field. func TimeLT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldLT(FieldTime, v)) } // TimeLTE applies the LTE predicate on the "time" field. func TimeLTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldLTE(FieldTime, v)) } // SerializedEQ applies the EQ predicate on the "serialized" field. func SerializedEQ(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldEQ(FieldSerialized, v)) } // SerializedNEQ applies the NEQ predicate on the "serialized" field. func SerializedNEQ(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldNEQ(FieldSerialized, v)) } // SerializedIn applies the In predicate on the "serialized" field. func SerializedIn(vs ...string) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSerialized), v...)) - }) + return predicate.Event(sql.FieldIn(FieldSerialized, vs...)) } // SerializedNotIn applies the NotIn predicate on the "serialized" field. func SerializedNotIn(vs ...string) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSerialized), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldSerialized, vs...)) } // SerializedGT applies the GT predicate on the "serialized" field. func SerializedGT(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldGT(FieldSerialized, v)) } // SerializedGTE applies the GTE predicate on the "serialized" field. 
func SerializedGTE(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldGTE(FieldSerialized, v)) } // SerializedLT applies the LT predicate on the "serialized" field. func SerializedLT(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldLT(FieldSerialized, v)) } // SerializedLTE applies the LTE predicate on the "serialized" field. func SerializedLTE(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldLTE(FieldSerialized, v)) } // SerializedContains applies the Contains predicate on the "serialized" field. func SerializedContains(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldContains(FieldSerialized, v)) } // SerializedHasPrefix applies the HasPrefix predicate on the "serialized" field. func SerializedHasPrefix(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldHasPrefix(FieldSerialized, v)) } // SerializedHasSuffix applies the HasSuffix predicate on the "serialized" field. func SerializedHasSuffix(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldHasSuffix(FieldSerialized, v)) } // SerializedEqualFold applies the EqualFold predicate on the "serialized" field. func SerializedEqualFold(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldEqualFold(FieldSerialized, v)) } // SerializedContainsFold applies the ContainsFold predicate on the "serialized" field. func SerializedContainsFold(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldContainsFold(FieldSerialized, v)) } // AlertEventsEQ applies the EQ predicate on the "alert_events" field. func AlertEventsEQ(v int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertEvents), v)) - }) + return predicate.Event(sql.FieldEQ(FieldAlertEvents, v)) } // AlertEventsNEQ applies the NEQ predicate on the "alert_events" field. func AlertEventsNEQ(v int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAlertEvents), v)) - }) + return predicate.Event(sql.FieldNEQ(FieldAlertEvents, v)) } // AlertEventsIn applies the In predicate on the "alert_events" field. func AlertEventsIn(vs ...int) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAlertEvents), v...)) - }) + return predicate.Event(sql.FieldIn(FieldAlertEvents, vs...)) } // AlertEventsNotIn applies the NotIn predicate on the "alert_events" field. 
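The where.go rewrites in this hunk swap the hand-built selector closures for the sql.Field* helpers; the predicate functions keep their names and signatures, so call sites do not change. A composition sketch, assuming context, time and the generated ent, event and alert packages, with alert.IDEQ taken from the untouched alert predicates:

    // Sketch only: predicates compose exactly as before; only their implementation changed.
    func sshEventsFor(ctx context.Context, client *ent.Client, alertID int, since time.Time) ([]*ent.Event, error) {
        return client.Event.Query().
            Where(
                event.TimeGTE(since),
                event.SerializedContains("ssh"),
                event.HasOwnerWith(alert.IDEQ(alertID)),
            ).
            All(ctx)
    }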
func AlertEventsNotIn(vs ...int) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAlertEvents), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldAlertEvents, vs...)) } // AlertEventsIsNil applies the IsNil predicate on the "alert_events" field. func AlertEventsIsNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldAlertEvents))) - }) + return predicate.Event(sql.FieldIsNull(FieldAlertEvents)) } // AlertEventsNotNil applies the NotNil predicate on the "alert_events" field. func AlertEventsNotNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldAlertEvents))) - }) + return predicate.Event(sql.FieldNotNull(FieldAlertEvents)) } // HasOwner applies the HasEdge predicate on the "owner" edge. @@ -490,7 +320,6 @@ func HasOwner() predicate.Event { return predicate.Event(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), ) sqlgraph.HasNeighbors(s, step) @@ -500,11 +329,7 @@ func HasOwner() predicate.Event { // HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates). func HasOwnerWith(preds ...predicate.Alert) predicate.Event { return predicate.Event(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), - ) + step := newOwnerStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -515,32 +340,15 @@ func HasOwnerWith(preds ...predicate.Alert) predicate.Event { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Event) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Event(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Event) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Event(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Event) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Event(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/event_create.go b/pkg/database/ent/event_create.go index c5861305130..98194f2fd33 100644 --- a/pkg/database/ent/event_create.go +++ b/pkg/database/ent/event_create.go @@ -101,50 +101,8 @@ func (ec *EventCreate) Mutation() *EventMutation { // Save creates the Event in the database. 
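The event_create.go hunk starting here routes EventCreate.Save through withHooks and moves the check() validation into sqlSave, leaving the builder surface unchanged. A creation sketch, assuming the generated SetOwnerID edge setter plus the usual context, time and ent imports:

    // Sketch only: create an event attached to an existing alert.
    func recordEvent(ctx context.Context, client *ent.Client, alertID int, payload string) (*ent.Event, error) {
        return client.Event.Create().
            SetTime(time.Now()).
            SetSerialized(payload). // still validated by SerializedValidator before save
            SetOwnerID(alertID).
            Save(ctx)
    }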
func (ec *EventCreate) Save(ctx context.Context) (*Event, error) { - var ( - err error - node *Event - ) ec.defaults() - if len(ec.hooks) == 0 { - if err = ec.check(); err != nil { - return nil, err - } - node, err = ec.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*EventMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = ec.check(); err != nil { - return nil, err - } - ec.mutation = mutation - if node, err = ec.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(ec.hooks) - 1; i >= 0; i-- { - if ec.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ec.hooks[i](mut) - } - v, err := mut.Mutate(ctx, ec.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Event) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from EventMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, ec.sqlSave, ec.mutation, ec.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -198,6 +156,9 @@ func (ec *EventCreate) check() error { } func (ec *EventCreate) sqlSave(ctx context.Context) (*Event, error) { + if err := ec.check(); err != nil { + return nil, err + } _node, _spec := ec.createSpec() if err := sqlgraph.CreateNode(ctx, ec.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -207,50 +168,30 @@ func (ec *EventCreate) sqlSave(ctx context.Context) (*Event, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + ec.mutation.id = &_node.ID + ec.mutation.done = true return _node, nil } func (ec *EventCreate) createSpec() (*Event, *sqlgraph.CreateSpec) { var ( _node = &Event{config: ec.config} - _spec = &sqlgraph.CreateSpec{ - Table: event.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(event.Table, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt)) ) if value, ok := ec.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldCreatedAt, - }) + _spec.SetField(event.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = &value } if value, ok := ec.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldUpdatedAt, - }) + _spec.SetField(event.FieldUpdatedAt, field.TypeTime, value) _node.UpdatedAt = &value } if value, ok := ec.mutation.Time(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldTime, - }) + _spec.SetField(event.FieldTime, field.TypeTime, value) _node.Time = value } if value, ok := ec.mutation.Serialized(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: event.FieldSerialized, - }) + _spec.SetField(event.FieldSerialized, field.TypeString, value) _node.Serialized = value } if nodes := ec.mutation.OwnerIDs(); len(nodes) > 0 { @@ -261,10 +202,7 @@ func (ec *EventCreate) createSpec() (*Event, *sqlgraph.CreateSpec) { Columns: []string{event.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k 
:= range nodes { @@ -279,11 +217,15 @@ func (ec *EventCreate) createSpec() (*Event, *sqlgraph.CreateSpec) { // EventCreateBulk is the builder for creating many Event entities in bulk. type EventCreateBulk struct { config + err error builders []*EventCreate } // Save creates the Event entities in the database. func (ecb *EventCreateBulk) Save(ctx context.Context) ([]*Event, error) { + if ecb.err != nil { + return nil, ecb.err + } specs := make([]*sqlgraph.CreateSpec, len(ecb.builders)) nodes := make([]*Event, len(ecb.builders)) mutators := make([]Mutator, len(ecb.builders)) @@ -300,8 +242,8 @@ func (ecb *EventCreateBulk) Save(ctx context.Context) ([]*Event, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, ecb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/event_delete.go b/pkg/database/ent/event_delete.go index 0220dc71d31..93dd1246b7e 100644 --- a/pkg/database/ent/event_delete.go +++ b/pkg/database/ent/event_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (ed *EventDelete) Where(ps ...predicate.Event) *EventDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (ed *EventDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(ed.hooks) == 0 { - affected, err = ed.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*EventMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - ed.mutation = mutation - affected, err = ed.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(ed.hooks) - 1; i >= 0; i-- { - if ed.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ed.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, ed.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, ed.sqlExec, ed.mutation, ed.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (ed *EventDelete) ExecX(ctx context.Context) int { } func (ed *EventDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: event.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(event.Table, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt)) if ps := ed.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (ed *EventDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + ed.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type EventDeleteOne struct { ed *EventDelete } +// Where appends a list predicates to the EventDelete builder. +func (edo *EventDeleteOne) Where(ps ...predicate.Event) *EventDeleteOne { + edo.ed.mutation.Where(ps...) + return edo +} + // Exec executes the deletion query. 
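The event_delete.go changes above route Exec through withHooks and give EventDeleteOne a Where method, mirroring the update builders. A sketch assuming the generated DeleteOneID constructor plus the context, ent and event imports:

    // Sketch only: delete a single event, but only if it is orphaned.
    func deleteOrphanedEvent(ctx context.Context, client *ent.Client, id int) error {
        // Exec reports a *NotFoundError when the extra predicate filters the row out.
        return client.Event.DeleteOneID(id).
            Where(event.AlertEventsIsNil()).
            Exec(ctx)
    }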
func (edo *EventDeleteOne) Exec(ctx context.Context) error { n, err := edo.ed.Exec(ctx) @@ -111,5 +82,7 @@ func (edo *EventDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (edo *EventDeleteOne) ExecX(ctx context.Context) { - edo.ed.ExecX(ctx) + if err := edo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/event_query.go b/pkg/database/ent/event_query.go index 045d750f818..1493d7bd32c 100644 --- a/pkg/database/ent/event_query.go +++ b/pkg/database/ent/event_query.go @@ -18,11 +18,9 @@ import ( // EventQuery is the builder for querying Event entities. type EventQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []event.OrderOption + inters []Interceptor predicates []predicate.Event withOwner *AlertQuery // intermediate query (i.e. traversal path). @@ -36,34 +34,34 @@ func (eq *EventQuery) Where(ps ...predicate.Event) *EventQuery { return eq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (eq *EventQuery) Limit(limit int) *EventQuery { - eq.limit = &limit + eq.ctx.Limit = &limit return eq } -// Offset adds an offset step to the query. +// Offset to start from. func (eq *EventQuery) Offset(offset int) *EventQuery { - eq.offset = &offset + eq.ctx.Offset = &offset return eq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (eq *EventQuery) Unique(unique bool) *EventQuery { - eq.unique = &unique + eq.ctx.Unique = &unique return eq } -// Order adds an order step to the query. -func (eq *EventQuery) Order(o ...OrderFunc) *EventQuery { +// Order specifies how the records should be ordered. +func (eq *EventQuery) Order(o ...event.OrderOption) *EventQuery { eq.order = append(eq.order, o...) return eq } // QueryOwner chains the current query on the "owner" edge. func (eq *EventQuery) QueryOwner() *AlertQuery { - query := &AlertQuery{config: eq.config} + query := (&AlertClient{config: eq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := eq.prepareQuery(ctx); err != nil { return nil, err @@ -86,7 +84,7 @@ func (eq *EventQuery) QueryOwner() *AlertQuery { // First returns the first Event entity from the query. // Returns a *NotFoundError when no Event was found. func (eq *EventQuery) First(ctx context.Context) (*Event, error) { - nodes, err := eq.Limit(1).All(ctx) + nodes, err := eq.Limit(1).All(setContextOp(ctx, eq.ctx, "First")) if err != nil { return nil, err } @@ -109,7 +107,7 @@ func (eq *EventQuery) FirstX(ctx context.Context) *Event { // Returns a *NotFoundError when no Event ID was found. func (eq *EventQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = eq.Limit(1).IDs(ctx); err != nil { + if ids, err = eq.Limit(1).IDs(setContextOp(ctx, eq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -132,7 +130,7 @@ func (eq *EventQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Event entity is found. // Returns a *NotFoundError when no Event entities are found. 
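The event_query.go hunk beginning above replaces the builder's own limit/offset/unique fields with a shared QueryContext and tags each operation via setContextOp so interceptors can observe it; day-to-day query code is unchanged. A sketch using First together with the generated IsNotFound helper (assumed from the untouched parts of ent.go), plus context, ent and event imports:

    // Sketch only: fetch the oldest event, treating an empty table as a non-error.
    func oldestEvent(ctx context.Context, client *ent.Client) (*ent.Event, error) {
        evt, err := client.Event.Query().
            Order(event.ByTime()).
            First(ctx)
        if ent.IsNotFound(err) {
            return nil, nil // no events recorded yet
        }
        return evt, err
    }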
func (eq *EventQuery) Only(ctx context.Context) (*Event, error) { - nodes, err := eq.Limit(2).All(ctx) + nodes, err := eq.Limit(2).All(setContextOp(ctx, eq.ctx, "Only")) if err != nil { return nil, err } @@ -160,7 +158,7 @@ func (eq *EventQuery) OnlyX(ctx context.Context) *Event { // Returns a *NotFoundError when no entities are found. func (eq *EventQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = eq.Limit(2).IDs(ctx); err != nil { + if ids, err = eq.Limit(2).IDs(setContextOp(ctx, eq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -185,10 +183,12 @@ func (eq *EventQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Events. func (eq *EventQuery) All(ctx context.Context) ([]*Event, error) { + ctx = setContextOp(ctx, eq.ctx, "All") if err := eq.prepareQuery(ctx); err != nil { return nil, err } - return eq.sqlAll(ctx) + qr := querierAll[[]*Event, *EventQuery]() + return withInterceptors[[]*Event](ctx, eq, qr, eq.inters) } // AllX is like All, but panics if an error occurs. @@ -201,9 +201,12 @@ func (eq *EventQuery) AllX(ctx context.Context) []*Event { } // IDs executes the query and returns a list of Event IDs. -func (eq *EventQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := eq.Select(event.FieldID).Scan(ctx, &ids); err != nil { +func (eq *EventQuery) IDs(ctx context.Context) (ids []int, err error) { + if eq.ctx.Unique == nil && eq.path != nil { + eq.Unique(true) + } + ctx = setContextOp(ctx, eq.ctx, "IDs") + if err = eq.Select(event.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -220,10 +223,11 @@ func (eq *EventQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (eq *EventQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, eq.ctx, "Count") if err := eq.prepareQuery(ctx); err != nil { return 0, err } - return eq.sqlCount(ctx) + return withInterceptors[int](ctx, eq, querierCount[*EventQuery](), eq.inters) } // CountX is like Count, but panics if an error occurs. @@ -237,10 +241,15 @@ func (eq *EventQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (eq *EventQuery) Exist(ctx context.Context) (bool, error) { - if err := eq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, eq.ctx, "Exist") + switch _, err := eq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return eq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -260,22 +269,21 @@ func (eq *EventQuery) Clone() *EventQuery { } return &EventQuery{ config: eq.config, - limit: eq.limit, - offset: eq.offset, - order: append([]OrderFunc{}, eq.order...), + ctx: eq.ctx.Clone(), + order: append([]event.OrderOption{}, eq.order...), + inters: append([]Interceptor{}, eq.inters...), predicates: append([]predicate.Event{}, eq.predicates...), withOwner: eq.withOwner.Clone(), // clone intermediate query. - sql: eq.sql.Clone(), - path: eq.path, - unique: eq.unique, + sql: eq.sql.Clone(), + path: eq.path, } } // WithOwner tells the query-builder to eager-load the nodes that are connected to // the "owner" edge. The optional arguments are used to configure the query builder of the edge. 
func (eq *EventQuery) WithOwner(opts ...func(*AlertQuery)) *EventQuery { - query := &AlertQuery{config: eq.config} + query := (&AlertClient{config: eq.config}).Query() for _, opt := range opts { opt(query) } @@ -298,16 +306,11 @@ func (eq *EventQuery) WithOwner(opts ...func(*AlertQuery)) *EventQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (eq *EventQuery) GroupBy(field string, fields ...string) *EventGroupBy { - grbuild := &EventGroupBy{config: eq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := eq.prepareQuery(ctx); err != nil { - return nil, err - } - return eq.sqlQuery(ctx), nil - } + eq.ctx.Fields = append([]string{field}, fields...) + grbuild := &EventGroupBy{build: eq} + grbuild.flds = &eq.ctx.Fields grbuild.label = event.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -324,15 +327,30 @@ func (eq *EventQuery) GroupBy(field string, fields ...string) *EventGroupBy { // Select(event.FieldCreatedAt). // Scan(ctx, &v) func (eq *EventQuery) Select(fields ...string) *EventSelect { - eq.fields = append(eq.fields, fields...) - selbuild := &EventSelect{EventQuery: eq} - selbuild.label = event.Label - selbuild.flds, selbuild.scan = &eq.fields, selbuild.Scan - return selbuild + eq.ctx.Fields = append(eq.ctx.Fields, fields...) + sbuild := &EventSelect{EventQuery: eq} + sbuild.label = event.Label + sbuild.flds, sbuild.scan = &eq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a EventSelect configured with the given aggregations. +func (eq *EventQuery) Aggregate(fns ...AggregateFunc) *EventSelect { + return eq.Select().Aggregate(fns...) } func (eq *EventQuery) prepareQuery(ctx context.Context) error { - for _, f := range eq.fields { + for _, inter := range eq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, eq); err != nil { + return err + } + } + } + for _, f := range eq.ctx.Fields { if !event.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -392,6 +410,9 @@ func (eq *EventQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes [] } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(alert.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -411,41 +432,22 @@ func (eq *EventQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes [] func (eq *EventQuery) sqlCount(ctx context.Context) (int, error) { _spec := eq.querySpec() - _spec.Node.Columns = eq.fields - if len(eq.fields) > 0 { - _spec.Unique = eq.unique != nil && *eq.unique + _spec.Node.Columns = eq.ctx.Fields + if len(eq.ctx.Fields) > 0 { + _spec.Unique = eq.ctx.Unique != nil && *eq.ctx.Unique } return sqlgraph.CountNodes(ctx, eq.driver, _spec) } -func (eq *EventQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := eq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: event.Table, - Columns: event.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, - }, - From: eq.sql, - Unique: 
true, - } - if unique := eq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(event.Table, event.Columns, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt)) + _spec.From = eq.sql + if unique := eq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if eq.path != nil { + _spec.Unique = true } - if fields := eq.fields; len(fields) > 0 { + if fields := eq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, event.FieldID) for i := range fields { @@ -453,6 +455,9 @@ func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec { _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) } } + if eq.withOwner != nil { + _spec.Node.AddColumnOnce(event.FieldAlertEvents) + } } if ps := eq.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { @@ -461,10 +466,10 @@ func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := eq.limit; limit != nil { + if limit := eq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := eq.offset; offset != nil { + if offset := eq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := eq.order; len(ps) > 0 { @@ -480,7 +485,7 @@ func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec { func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(eq.driver.Dialect()) t1 := builder.Table(event.Table) - columns := eq.fields + columns := eq.ctx.Fields if len(columns) == 0 { columns = event.Columns } @@ -489,7 +494,7 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = eq.sql selector.Select(selector.Columns(columns...)...) } - if eq.unique != nil && *eq.unique { + if eq.ctx.Unique != nil && *eq.ctx.Unique { selector.Distinct() } for _, p := range eq.predicates { @@ -498,12 +503,12 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range eq.order { p(selector) } - if offset := eq.offset; offset != nil { + if offset := eq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := eq.limit; limit != nil { + if limit := eq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -511,13 +516,8 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector { // EventGroupBy is the group-by builder for Event entities. type EventGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *EventQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -526,74 +526,77 @@ func (egb *EventGroupBy) Aggregate(fns ...AggregateFunc) *EventGroupBy { return egb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. 
func (egb *EventGroupBy) Scan(ctx context.Context, v any) error { - query, err := egb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, egb.build.ctx, "GroupBy") + if err := egb.build.prepareQuery(ctx); err != nil { return err } - egb.sql = query - return egb.sqlScan(ctx, v) + return scanWithInterceptors[*EventQuery, *EventGroupBy](ctx, egb.build, egb, egb.build.inters, v) } -func (egb *EventGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range egb.fields { - if !event.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (egb *EventGroupBy) sqlScan(ctx context.Context, root *EventQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(egb.fns)) + for _, fn := range egb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*egb.flds)+len(egb.fns)) + for _, f := range *egb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := egb.sqlQuery() + selector.GroupBy(selector.Columns(*egb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := egb.driver.Query(ctx, query, args, rows); err != nil { + if err := egb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (egb *EventGroupBy) sqlQuery() *sql.Selector { - selector := egb.sql.Select() - aggregation := make([]string, 0, len(egb.fns)) - for _, fn := range egb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(egb.fields)+len(egb.fns)) - for _, f := range egb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(egb.fields...)...) -} - // EventSelect is the builder for selecting fields of Event entities. type EventSelect struct { *EventQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (es *EventSelect) Aggregate(fns ...AggregateFunc) *EventSelect { + es.fns = append(es.fns, fns...) + return es } // Scan applies the selector query and scans the result into the given value. func (es *EventSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, es.ctx, "Select") if err := es.prepareQuery(ctx); err != nil { return err } - es.sql = es.EventQuery.sqlQuery(ctx) - return es.sqlScan(ctx, v) + return scanWithInterceptors[*EventQuery, *EventSelect](ctx, es.EventQuery, es, es.inters, v) } -func (es *EventSelect) sqlScan(ctx context.Context, v any) error { +func (es *EventSelect) sqlScan(ctx context.Context, root *EventQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(es.fns)) + for _, fn := range es.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*es.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) 
+ } rows := &sql.Rows{} - query, args := es.sql.Query() + query, args := selector.Query() if err := es.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/event_update.go b/pkg/database/ent/event_update.go index fcd0cc50c99..db748101519 100644 --- a/pkg/database/ent/event_update.go +++ b/pkg/database/ent/event_update.go @@ -117,41 +117,8 @@ func (eu *EventUpdate) ClearOwner() *EventUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (eu *EventUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) eu.defaults() - if len(eu.hooks) == 0 { - if err = eu.check(); err != nil { - return 0, err - } - affected, err = eu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*EventMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = eu.check(); err != nil { - return 0, err - } - eu.mutation = mutation - affected, err = eu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(eu.hooks) - 1; i >= 0; i-- { - if eu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = eu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, eu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, eu.sqlSave, eu.mutation, eu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -199,16 +166,10 @@ func (eu *EventUpdate) check() error { } func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: event.Table, - Columns: event.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, - }, + if err := eu.check(); err != nil { + return n, err } + _spec := sqlgraph.NewUpdateSpec(event.Table, event.Columns, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt)) if ps := eu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -217,44 +178,22 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := eu.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldCreatedAt, - }) + _spec.SetField(event.FieldCreatedAt, field.TypeTime, value) } if eu.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: event.FieldCreatedAt, - }) + _spec.ClearField(event.FieldCreatedAt, field.TypeTime) } if value, ok := eu.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldUpdatedAt, - }) + _spec.SetField(event.FieldUpdatedAt, field.TypeTime, value) } if eu.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: event.FieldUpdatedAt, - }) + _spec.ClearField(event.FieldUpdatedAt, field.TypeTime) } if value, ok := eu.mutation.Time(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldTime, - }) + _spec.SetField(event.FieldTime, field.TypeTime, value) } if value, ok := eu.mutation.Serialized(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, 
&sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: event.FieldSerialized, - }) + _spec.SetField(event.FieldSerialized, field.TypeString, value) } if eu.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -264,10 +203,7 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{event.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -280,10 +216,7 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{event.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -299,6 +232,7 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + eu.mutation.done = true return n, nil } @@ -396,6 +330,12 @@ func (euo *EventUpdateOne) ClearOwner() *EventUpdateOne { return euo } +// Where appends a list predicates to the EventUpdate builder. +func (euo *EventUpdateOne) Where(ps ...predicate.Event) *EventUpdateOne { + euo.mutation.Where(ps...) + return euo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (euo *EventUpdateOne) Select(field string, fields ...string) *EventUpdateOne { @@ -405,47 +345,8 @@ func (euo *EventUpdateOne) Select(field string, fields ...string) *EventUpdateOn // Save executes the query and returns the updated Event entity. func (euo *EventUpdateOne) Save(ctx context.Context) (*Event, error) { - var ( - err error - node *Event - ) euo.defaults() - if len(euo.hooks) == 0 { - if err = euo.check(); err != nil { - return nil, err - } - node, err = euo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*EventMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = euo.check(); err != nil { - return nil, err - } - euo.mutation = mutation - node, err = euo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(euo.hooks) - 1; i >= 0; i-- { - if euo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = euo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, euo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Event) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from EventMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, euo.sqlSave, euo.mutation, euo.hooks) } // SaveX is like Save, but panics if an error occurs. 
@@ -493,16 +394,10 @@ func (euo *EventUpdateOne) check() error { } func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: event.Table, - Columns: event.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, - }, + if err := euo.check(); err != nil { + return _node, err } + _spec := sqlgraph.NewUpdateSpec(event.Table, event.Columns, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt)) id, ok := euo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Event.id" for update`)} @@ -528,44 +423,22 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error } } if value, ok := euo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldCreatedAt, - }) + _spec.SetField(event.FieldCreatedAt, field.TypeTime, value) } if euo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: event.FieldCreatedAt, - }) + _spec.ClearField(event.FieldCreatedAt, field.TypeTime) } if value, ok := euo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldUpdatedAt, - }) + _spec.SetField(event.FieldUpdatedAt, field.TypeTime, value) } if euo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: event.FieldUpdatedAt, - }) + _spec.ClearField(event.FieldUpdatedAt, field.TypeTime) } if value, ok := euo.mutation.Time(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldTime, - }) + _spec.SetField(event.FieldTime, field.TypeTime, value) } if value, ok := euo.mutation.Serialized(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: event.FieldSerialized, - }) + _spec.SetField(event.FieldSerialized, field.TypeString, value) } if euo.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -575,10 +448,7 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error Columns: []string{event.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -591,10 +461,7 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error Columns: []string{event.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -613,5 +480,6 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error } return nil, err } + euo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/hook/hook.go b/pkg/database/ent/hook/hook.go index 85ab00b01fb..7ec9c3ab1d8 100644 --- a/pkg/database/ent/hook/hook.go +++ b/pkg/database/ent/hook/hook.go @@ -15,11 +15,10 @@ type AlertFunc func(context.Context, *ent.AlertMutation) (ent.Value, error) // Mutate calls f(ctx, m). 
func (f AlertFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.AlertMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AlertMutation", m) + if mv, ok := m.(*ent.AlertMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AlertMutation", m) } // The BouncerFunc type is an adapter to allow the use of ordinary @@ -28,11 +27,10 @@ type BouncerFunc func(context.Context, *ent.BouncerMutation) (ent.Value, error) // Mutate calls f(ctx, m). func (f BouncerFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.BouncerMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.BouncerMutation", m) + if mv, ok := m.(*ent.BouncerMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.BouncerMutation", m) } // The ConfigItemFunc type is an adapter to allow the use of ordinary @@ -41,11 +39,10 @@ type ConfigItemFunc func(context.Context, *ent.ConfigItemMutation) (ent.Value, e // Mutate calls f(ctx, m). func (f ConfigItemFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.ConfigItemMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ConfigItemMutation", m) + if mv, ok := m.(*ent.ConfigItemMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ConfigItemMutation", m) } // The DecisionFunc type is an adapter to allow the use of ordinary @@ -54,11 +51,10 @@ type DecisionFunc func(context.Context, *ent.DecisionMutation) (ent.Value, error // Mutate calls f(ctx, m). func (f DecisionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.DecisionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.DecisionMutation", m) + if mv, ok := m.(*ent.DecisionMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.DecisionMutation", m) } // The EventFunc type is an adapter to allow the use of ordinary @@ -67,11 +63,10 @@ type EventFunc func(context.Context, *ent.EventMutation) (ent.Value, error) // Mutate calls f(ctx, m). func (f EventFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.EventMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.EventMutation", m) + if mv, ok := m.(*ent.EventMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.EventMutation", m) } // The MachineFunc type is an adapter to allow the use of ordinary @@ -80,11 +75,10 @@ type MachineFunc func(context.Context, *ent.MachineMutation) (ent.Value, error) // Mutate calls f(ctx, m). func (f MachineFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MachineMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MachineMutation", m) + if mv, ok := m.(*ent.MachineMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MachineMutation", m) } // The MetaFunc type is an adapter to allow the use of ordinary @@ -93,11 +87,10 @@ type MetaFunc func(context.Context, *ent.MetaMutation) (ent.Value, error) // Mutate calls f(ctx, m). 
func (f MetaFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MetaMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MetaMutation", m) + if mv, ok := m.(*ent.MetaMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MetaMutation", m) } // Condition is a hook condition function. diff --git a/pkg/database/ent/machine.go b/pkg/database/ent/machine.go index dc2b18ee81c..346a8d084ba 100644 --- a/pkg/database/ent/machine.go +++ b/pkg/database/ent/machine.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" ) @@ -42,7 +43,8 @@ type Machine struct { AuthType string `json:"auth_type"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the MachineQuery when eager-loading is set. - Edges MachineEdges `json:"edges"` + Edges MachineEdges `json:"edges"` + selectValues sql.SelectValues } // MachineEdges holds the relations/edges for other nodes in the graph. @@ -77,7 +79,7 @@ func (*Machine) scanValues(columns []string) ([]any, error) { case machine.FieldCreatedAt, machine.FieldUpdatedAt, machine.FieldLastPush, machine.FieldLastHeartbeat: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type Machine", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -173,21 +175,29 @@ func (m *Machine) assignValues(columns []string, values []any) error { } else if value.Valid { m.AuthType = value.String } + default: + m.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Machine. +// This includes values selected through modifiers, order, etc. +func (m *Machine) Value(name string) (ent.Value, error) { + return m.selectValues.Get(name) +} + // QueryAlerts queries the "alerts" edge of the Machine entity. func (m *Machine) QueryAlerts() *AlertQuery { - return (&MachineClient{config: m.config}).QueryAlerts(m) + return NewMachineClient(m.config).QueryAlerts(m) } // Update returns a builder for updating this Machine. // Note that you need to call Machine.Unwrap() before calling this method if this Machine // was returned from a transaction, and the transaction was committed or rolled back. func (m *Machine) Update() *MachineUpdateOne { - return (&MachineClient{config: m.config}).UpdateOne(m) + return NewMachineClient(m.config).UpdateOne(m) } // Unwrap unwraps the Machine entity that was returned from a transaction after it was closed, @@ -254,9 +264,3 @@ func (m *Machine) String() string { // Machines is a parsable slice of Machine. type Machines []*Machine - -func (m Machines) config(cfg config) { - for _i := range m { - m[_i].config = cfg - } -} diff --git a/pkg/database/ent/machine/machine.go b/pkg/database/ent/machine/machine.go index e6900dd21e1..5456935e04c 100644 --- a/pkg/database/ent/machine/machine.go +++ b/pkg/database/ent/machine/machine.go @@ -4,6 +4,9 @@ package machine import ( "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -99,3 +102,92 @@ var ( // DefaultAuthType holds the default value on creation for the "auth_type" field. DefaultAuthType string ) + +// OrderOption defines the ordering options for the Machine queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. 
+func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByLastPush orders the results by the last_push field. +func ByLastPush(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastPush, opts...).ToFunc() +} + +// ByLastHeartbeat orders the results by the last_heartbeat field. +func ByLastHeartbeat(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastHeartbeat, opts...).ToFunc() +} + +// ByMachineId orders the results by the machineId field. +func ByMachineId(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMachineId, opts...).ToFunc() +} + +// ByPassword orders the results by the password field. +func ByPassword(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPassword, opts...).ToFunc() +} + +// ByIpAddress orders the results by the ipAddress field. +func ByIpAddress(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIpAddress, opts...).ToFunc() +} + +// ByScenarios orders the results by the scenarios field. +func ByScenarios(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScenarios, opts...).ToFunc() +} + +// ByVersion orders the results by the version field. +func ByVersion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldVersion, opts...).ToFunc() +} + +// ByIsValidated orders the results by the isValidated field. +func ByIsValidated(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIsValidated, opts...).ToFunc() +} + +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + +// ByAuthType orders the results by the auth_type field. +func ByAuthType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAuthType, opts...).ToFunc() +} + +// ByAlertsCount orders the results by alerts count. +func ByAlertsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAlertsStep(), opts...) + } +} + +// ByAlerts orders the results by alerts terms. +func ByAlerts(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAlertsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newAlertsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AlertsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AlertsTable, AlertsColumn), + ) +} diff --git a/pkg/database/ent/machine/where.go b/pkg/database/ent/machine/where.go index 7d0227731cc..e9d00e7e01e 100644 --- a/pkg/database/ent/machine/where.go +++ b/pkg/database/ent/machine/where.go @@ -12,1218 +12,802 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. 
func IDEQ(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldUpdatedAt, v)) } // LastPush applies equality check predicate on the "last_push" field. It's identical to LastPushEQ. func LastPush(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldLastPush, v)) } // LastHeartbeat applies equality check predicate on the "last_heartbeat" field. It's identical to LastHeartbeatEQ. func LastHeartbeat(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldLastHeartbeat, v)) } // MachineId applies equality check predicate on the "machineId" field. It's identical to MachineIdEQ. 
func MachineId(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldMachineId, v)) } // Password applies equality check predicate on the "password" field. It's identical to PasswordEQ. func Password(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldPassword, v)) } // IpAddress applies equality check predicate on the "ipAddress" field. It's identical to IpAddressEQ. func IpAddress(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldIpAddress, v)) } // Scenarios applies equality check predicate on the "scenarios" field. It's identical to ScenariosEQ. func Scenarios(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldScenarios, v)) } // Version applies equality check predicate on the "version" field. It's identical to VersionEQ. func Version(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldVersion, v)) } // IsValidated applies equality check predicate on the "isValidated" field. It's identical to IsValidatedEQ. func IsValidated(v bool) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIsValidated), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldIsValidated, v)) } // Status applies equality check predicate on the "status" field. It's identical to StatusEQ. func Status(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStatus), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldStatus, v)) } // AuthType applies equality check predicate on the "auth_type" field. It's identical to AuthTypeEQ. func AuthType(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAuthType), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldAuthType, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. 
func CreatedAtNotIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldCreatedAt, v)) } // CreatedAtIsNil applies the IsNil predicate on the "created_at" field. func CreatedAtIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) + return predicate.Machine(sql.FieldIsNull(FieldCreatedAt)) } // CreatedAtNotNil applies the NotNil predicate on the "created_at" field. func CreatedAtNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Machine(sql.FieldNotNull(FieldCreatedAt)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. 
func UpdatedAtGT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldUpdatedAt, v)) } // UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. func UpdatedAtIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) + return predicate.Machine(sql.FieldIsNull(FieldUpdatedAt)) } // UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. func UpdatedAtNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Machine(sql.FieldNotNull(FieldUpdatedAt)) } // LastPushEQ applies the EQ predicate on the "last_push" field. func LastPushEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldLastPush, v)) } // LastPushNEQ applies the NEQ predicate on the "last_push" field. func LastPushNEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldLastPush, v)) } // LastPushIn applies the In predicate on the "last_push" field. func LastPushIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldLastPush), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldLastPush, vs...)) } // LastPushNotIn applies the NotIn predicate on the "last_push" field. func LastPushNotIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldLastPush), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldLastPush, vs...)) } // LastPushGT applies the GT predicate on the "last_push" field. func LastPushGT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldGT(FieldLastPush, v)) } // LastPushGTE applies the GTE predicate on the "last_push" field. func LastPushGTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldLastPush, v)) } // LastPushLT applies the LT predicate on the "last_push" field. 
func LastPushLT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldLT(FieldLastPush, v)) } // LastPushLTE applies the LTE predicate on the "last_push" field. func LastPushLTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldLastPush, v)) } // LastPushIsNil applies the IsNil predicate on the "last_push" field. func LastPushIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldLastPush))) - }) + return predicate.Machine(sql.FieldIsNull(FieldLastPush)) } // LastPushNotNil applies the NotNil predicate on the "last_push" field. func LastPushNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldLastPush))) - }) + return predicate.Machine(sql.FieldNotNull(FieldLastPush)) } // LastHeartbeatEQ applies the EQ predicate on the "last_heartbeat" field. func LastHeartbeatEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldLastHeartbeat, v)) } // LastHeartbeatNEQ applies the NEQ predicate on the "last_heartbeat" field. func LastHeartbeatNEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldLastHeartbeat, v)) } // LastHeartbeatIn applies the In predicate on the "last_heartbeat" field. func LastHeartbeatIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldLastHeartbeat), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldLastHeartbeat, vs...)) } // LastHeartbeatNotIn applies the NotIn predicate on the "last_heartbeat" field. func LastHeartbeatNotIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldLastHeartbeat), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldLastHeartbeat, vs...)) } // LastHeartbeatGT applies the GT predicate on the "last_heartbeat" field. func LastHeartbeatGT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldGT(FieldLastHeartbeat, v)) } // LastHeartbeatGTE applies the GTE predicate on the "last_heartbeat" field. func LastHeartbeatGTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldLastHeartbeat, v)) } // LastHeartbeatLT applies the LT predicate on the "last_heartbeat" field. func LastHeartbeatLT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldLT(FieldLastHeartbeat, v)) } // LastHeartbeatLTE applies the LTE predicate on the "last_heartbeat" field. 
func LastHeartbeatLTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldLastHeartbeat, v)) } // LastHeartbeatIsNil applies the IsNil predicate on the "last_heartbeat" field. func LastHeartbeatIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldLastHeartbeat))) - }) + return predicate.Machine(sql.FieldIsNull(FieldLastHeartbeat)) } // LastHeartbeatNotNil applies the NotNil predicate on the "last_heartbeat" field. func LastHeartbeatNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldLastHeartbeat))) - }) + return predicate.Machine(sql.FieldNotNull(FieldLastHeartbeat)) } // MachineIdEQ applies the EQ predicate on the "machineId" field. func MachineIdEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldMachineId, v)) } // MachineIdNEQ applies the NEQ predicate on the "machineId" field. func MachineIdNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldMachineId, v)) } // MachineIdIn applies the In predicate on the "machineId" field. func MachineIdIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldMachineId), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldMachineId, vs...)) } // MachineIdNotIn applies the NotIn predicate on the "machineId" field. func MachineIdNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldMachineId), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldMachineId, vs...)) } // MachineIdGT applies the GT predicate on the "machineId" field. func MachineIdGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldGT(FieldMachineId, v)) } // MachineIdGTE applies the GTE predicate on the "machineId" field. func MachineIdGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldMachineId, v)) } // MachineIdLT applies the LT predicate on the "machineId" field. func MachineIdLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldLT(FieldMachineId, v)) } // MachineIdLTE applies the LTE predicate on the "machineId" field. func MachineIdLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldMachineId, v)) } // MachineIdContains applies the Contains predicate on the "machineId" field. 
func MachineIdContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldContains(FieldMachineId, v)) } // MachineIdHasPrefix applies the HasPrefix predicate on the "machineId" field. func MachineIdHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldMachineId, v)) } // MachineIdHasSuffix applies the HasSuffix predicate on the "machineId" field. func MachineIdHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldMachineId, v)) } // MachineIdEqualFold applies the EqualFold predicate on the "machineId" field. func MachineIdEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldMachineId, v)) } // MachineIdContainsFold applies the ContainsFold predicate on the "machineId" field. func MachineIdContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldMachineId, v)) } // PasswordEQ applies the EQ predicate on the "password" field. func PasswordEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldPassword, v)) } // PasswordNEQ applies the NEQ predicate on the "password" field. func PasswordNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldPassword, v)) } // PasswordIn applies the In predicate on the "password" field. func PasswordIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldPassword), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldPassword, vs...)) } // PasswordNotIn applies the NotIn predicate on the "password" field. func PasswordNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldPassword), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldPassword, vs...)) } // PasswordGT applies the GT predicate on the "password" field. func PasswordGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldGT(FieldPassword, v)) } // PasswordGTE applies the GTE predicate on the "password" field. func PasswordGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldPassword, v)) } // PasswordLT applies the LT predicate on the "password" field. 
func PasswordLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldLT(FieldPassword, v)) } // PasswordLTE applies the LTE predicate on the "password" field. func PasswordLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldPassword, v)) } // PasswordContains applies the Contains predicate on the "password" field. func PasswordContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldContains(FieldPassword, v)) } // PasswordHasPrefix applies the HasPrefix predicate on the "password" field. func PasswordHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldPassword, v)) } // PasswordHasSuffix applies the HasSuffix predicate on the "password" field. func PasswordHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldPassword, v)) } // PasswordEqualFold applies the EqualFold predicate on the "password" field. func PasswordEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldPassword, v)) } // PasswordContainsFold applies the ContainsFold predicate on the "password" field. func PasswordContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldPassword, v)) } // IpAddressEQ applies the EQ predicate on the "ipAddress" field. func IpAddressEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldIpAddress, v)) } // IpAddressNEQ applies the NEQ predicate on the "ipAddress" field. func IpAddressNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldIpAddress, v)) } // IpAddressIn applies the In predicate on the "ipAddress" field. func IpAddressIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldIpAddress), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldIpAddress, vs...)) } // IpAddressNotIn applies the NotIn predicate on the "ipAddress" field. func IpAddressNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldIpAddress), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldIpAddress, vs...)) } // IpAddressGT applies the GT predicate on the "ipAddress" field. 
func IpAddressGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldGT(FieldIpAddress, v)) } // IpAddressGTE applies the GTE predicate on the "ipAddress" field. func IpAddressGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldIpAddress, v)) } // IpAddressLT applies the LT predicate on the "ipAddress" field. func IpAddressLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldLT(FieldIpAddress, v)) } // IpAddressLTE applies the LTE predicate on the "ipAddress" field. func IpAddressLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldIpAddress, v)) } // IpAddressContains applies the Contains predicate on the "ipAddress" field. func IpAddressContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldContains(FieldIpAddress, v)) } // IpAddressHasPrefix applies the HasPrefix predicate on the "ipAddress" field. func IpAddressHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldIpAddress, v)) } // IpAddressHasSuffix applies the HasSuffix predicate on the "ipAddress" field. func IpAddressHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldIpAddress, v)) } // IpAddressEqualFold applies the EqualFold predicate on the "ipAddress" field. func IpAddressEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldIpAddress, v)) } // IpAddressContainsFold applies the ContainsFold predicate on the "ipAddress" field. func IpAddressContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldIpAddress, v)) } // ScenariosEQ applies the EQ predicate on the "scenarios" field. func ScenariosEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldScenarios, v)) } // ScenariosNEQ applies the NEQ predicate on the "scenarios" field. func ScenariosNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldScenarios, v)) } // ScenariosIn applies the In predicate on the "scenarios" field. func ScenariosIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScenarios), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldScenarios, vs...)) } // ScenariosNotIn applies the NotIn predicate on the "scenarios" field. 
func ScenariosNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScenarios), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldScenarios, vs...)) } // ScenariosGT applies the GT predicate on the "scenarios" field. func ScenariosGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldGT(FieldScenarios, v)) } // ScenariosGTE applies the GTE predicate on the "scenarios" field. func ScenariosGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldScenarios, v)) } // ScenariosLT applies the LT predicate on the "scenarios" field. func ScenariosLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldLT(FieldScenarios, v)) } // ScenariosLTE applies the LTE predicate on the "scenarios" field. func ScenariosLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldScenarios, v)) } // ScenariosContains applies the Contains predicate on the "scenarios" field. func ScenariosContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldContains(FieldScenarios, v)) } // ScenariosHasPrefix applies the HasPrefix predicate on the "scenarios" field. func ScenariosHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldScenarios, v)) } // ScenariosHasSuffix applies the HasSuffix predicate on the "scenarios" field. func ScenariosHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldScenarios, v)) } // ScenariosIsNil applies the IsNil predicate on the "scenarios" field. func ScenariosIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldScenarios))) - }) + return predicate.Machine(sql.FieldIsNull(FieldScenarios)) } // ScenariosNotNil applies the NotNil predicate on the "scenarios" field. func ScenariosNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldScenarios))) - }) + return predicate.Machine(sql.FieldNotNull(FieldScenarios)) } // ScenariosEqualFold applies the EqualFold predicate on the "scenarios" field. func ScenariosEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldScenarios, v)) } // ScenariosContainsFold applies the ContainsFold predicate on the "scenarios" field. func ScenariosContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldScenarios, v)) } // VersionEQ applies the EQ predicate on the "version" field. 
func VersionEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldVersion, v)) } // VersionNEQ applies the NEQ predicate on the "version" field. func VersionNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldVersion, v)) } // VersionIn applies the In predicate on the "version" field. func VersionIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldVersion), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldVersion, vs...)) } // VersionNotIn applies the NotIn predicate on the "version" field. func VersionNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldVersion), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldVersion, vs...)) } // VersionGT applies the GT predicate on the "version" field. func VersionGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldGT(FieldVersion, v)) } // VersionGTE applies the GTE predicate on the "version" field. func VersionGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldVersion, v)) } // VersionLT applies the LT predicate on the "version" field. func VersionLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldLT(FieldVersion, v)) } // VersionLTE applies the LTE predicate on the "version" field. func VersionLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldVersion, v)) } // VersionContains applies the Contains predicate on the "version" field. func VersionContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldContains(FieldVersion, v)) } // VersionHasPrefix applies the HasPrefix predicate on the "version" field. func VersionHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldVersion, v)) } // VersionHasSuffix applies the HasSuffix predicate on the "version" field. func VersionHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldVersion, v)) } // VersionIsNil applies the IsNil predicate on the "version" field. func VersionIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldVersion))) - }) + return predicate.Machine(sql.FieldIsNull(FieldVersion)) } // VersionNotNil applies the NotNil predicate on the "version" field. 
func VersionNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldVersion))) - }) + return predicate.Machine(sql.FieldNotNull(FieldVersion)) } // VersionEqualFold applies the EqualFold predicate on the "version" field. func VersionEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldVersion, v)) } // VersionContainsFold applies the ContainsFold predicate on the "version" field. func VersionContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldVersion, v)) } // IsValidatedEQ applies the EQ predicate on the "isValidated" field. func IsValidatedEQ(v bool) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIsValidated), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldIsValidated, v)) } // IsValidatedNEQ applies the NEQ predicate on the "isValidated" field. func IsValidatedNEQ(v bool) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldIsValidated), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldIsValidated, v)) } // StatusEQ applies the EQ predicate on the "status" field. func StatusEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStatus), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldStatus, v)) } // StatusNEQ applies the NEQ predicate on the "status" field. func StatusNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldStatus), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldStatus, v)) } // StatusIn applies the In predicate on the "status" field. func StatusIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldStatus), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldStatus, vs...)) } // StatusNotIn applies the NotIn predicate on the "status" field. func StatusNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldStatus), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldStatus, vs...)) } // StatusGT applies the GT predicate on the "status" field. func StatusGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldStatus), v)) - }) + return predicate.Machine(sql.FieldGT(FieldStatus, v)) } // StatusGTE applies the GTE predicate on the "status" field. func StatusGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldStatus), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldStatus, v)) } // StatusLT applies the LT predicate on the "status" field. func StatusLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldStatus), v)) - }) + return predicate.Machine(sql.FieldLT(FieldStatus, v)) } // StatusLTE applies the LTE predicate on the "status" field. 
func StatusLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldStatus), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldStatus, v)) } // StatusContains applies the Contains predicate on the "status" field. func StatusContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldStatus), v)) - }) + return predicate.Machine(sql.FieldContains(FieldStatus, v)) } // StatusHasPrefix applies the HasPrefix predicate on the "status" field. func StatusHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldStatus), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldStatus, v)) } // StatusHasSuffix applies the HasSuffix predicate on the "status" field. func StatusHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldStatus), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldStatus, v)) } // StatusIsNil applies the IsNil predicate on the "status" field. func StatusIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldStatus))) - }) + return predicate.Machine(sql.FieldIsNull(FieldStatus)) } // StatusNotNil applies the NotNil predicate on the "status" field. func StatusNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldStatus))) - }) + return predicate.Machine(sql.FieldNotNull(FieldStatus)) } // StatusEqualFold applies the EqualFold predicate on the "status" field. func StatusEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldStatus), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldStatus, v)) } // StatusContainsFold applies the ContainsFold predicate on the "status" field. func StatusContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldStatus), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldStatus, v)) } // AuthTypeEQ applies the EQ predicate on the "auth_type" field. func AuthTypeEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAuthType), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldAuthType, v)) } // AuthTypeNEQ applies the NEQ predicate on the "auth_type" field. func AuthTypeNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAuthType), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldAuthType, v)) } // AuthTypeIn applies the In predicate on the "auth_type" field. func AuthTypeIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAuthType), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldAuthType, vs...)) } // AuthTypeNotIn applies the NotIn predicate on the "auth_type" field. func AuthTypeNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAuthType), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldAuthType, vs...)) } // AuthTypeGT applies the GT predicate on the "auth_type" field. 
func AuthTypeGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldAuthType), v)) - }) + return predicate.Machine(sql.FieldGT(FieldAuthType, v)) } // AuthTypeGTE applies the GTE predicate on the "auth_type" field. func AuthTypeGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldAuthType), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldAuthType, v)) } // AuthTypeLT applies the LT predicate on the "auth_type" field. func AuthTypeLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldAuthType), v)) - }) + return predicate.Machine(sql.FieldLT(FieldAuthType, v)) } // AuthTypeLTE applies the LTE predicate on the "auth_type" field. func AuthTypeLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldAuthType), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldAuthType, v)) } // AuthTypeContains applies the Contains predicate on the "auth_type" field. func AuthTypeContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldAuthType), v)) - }) + return predicate.Machine(sql.FieldContains(FieldAuthType, v)) } // AuthTypeHasPrefix applies the HasPrefix predicate on the "auth_type" field. func AuthTypeHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldAuthType), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldAuthType, v)) } // AuthTypeHasSuffix applies the HasSuffix predicate on the "auth_type" field. func AuthTypeHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldAuthType), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldAuthType, v)) } // AuthTypeEqualFold applies the EqualFold predicate on the "auth_type" field. func AuthTypeEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldAuthType), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldAuthType, v)) } // AuthTypeContainsFold applies the ContainsFold predicate on the "auth_type" field. func AuthTypeContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldAuthType), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldAuthType, v)) } // HasAlerts applies the HasEdge predicate on the "alerts" edge. @@ -1231,7 +815,6 @@ func HasAlerts() predicate.Machine { return predicate.Machine(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(AlertsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, AlertsTable, AlertsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -1241,11 +824,7 @@ func HasAlerts() predicate.Machine { // HasAlertsWith applies the HasEdge predicate on the "alerts" edge with a given conditions (other predicates). 
func HasAlertsWith(preds ...predicate.Alert) predicate.Machine { return predicate.Machine(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(AlertsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, AlertsTable, AlertsColumn), - ) + step := newAlertsStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -1256,32 +835,15 @@ func HasAlertsWith(preds ...predicate.Alert) predicate.Machine { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Machine) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Machine(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Machine) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Machine(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Machine) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Machine(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/machine_create.go b/pkg/database/ent/machine_create.go index efe02782f6b..ff704e6ab74 100644 --- a/pkg/database/ent/machine_create.go +++ b/pkg/database/ent/machine_create.go @@ -187,50 +187,8 @@ func (mc *MachineCreate) Mutation() *MachineMutation { // Save creates the Machine in the database. func (mc *MachineCreate) Save(ctx context.Context) (*Machine, error) { - var ( - err error - node *Machine - ) mc.defaults() - if len(mc.hooks) == 0 { - if err = mc.check(); err != nil { - return nil, err - } - node, err = mc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MachineMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mc.check(); err != nil { - return nil, err - } - mc.mutation = mutation - if node, err = mc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mc.hooks) - 1; i >= 0; i-- { - if mc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Machine) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MachineMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, mc.sqlSave, mc.mutation, mc.hooks) } // SaveX calls Save and panics if Save returns an error. 
@@ -309,6 +267,9 @@ func (mc *MachineCreate) check() error { } func (mc *MachineCreate) sqlSave(ctx context.Context) (*Machine, error) { + if err := mc.check(); err != nil { + return nil, err + } _node, _spec := mc.createSpec() if err := sqlgraph.CreateNode(ctx, mc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -318,114 +279,62 @@ func (mc *MachineCreate) sqlSave(ctx context.Context) (*Machine, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + mc.mutation.id = &_node.ID + mc.mutation.done = true return _node, nil } func (mc *MachineCreate) createSpec() (*Machine, *sqlgraph.CreateSpec) { var ( _node = &Machine{config: mc.config} - _spec = &sqlgraph.CreateSpec{ - Table: machine.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(machine.Table, sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt)) ) if value, ok := mc.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldCreatedAt, - }) + _spec.SetField(machine.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = &value } if value, ok := mc.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldUpdatedAt, - }) + _spec.SetField(machine.FieldUpdatedAt, field.TypeTime, value) _node.UpdatedAt = &value } if value, ok := mc.mutation.LastPush(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastPush, - }) + _spec.SetField(machine.FieldLastPush, field.TypeTime, value) _node.LastPush = &value } if value, ok := mc.mutation.LastHeartbeat(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastHeartbeat, - }) + _spec.SetField(machine.FieldLastHeartbeat, field.TypeTime, value) _node.LastHeartbeat = &value } if value, ok := mc.mutation.MachineId(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldMachineId, - }) + _spec.SetField(machine.FieldMachineId, field.TypeString, value) _node.MachineId = value } if value, ok := mc.mutation.Password(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldPassword, - }) + _spec.SetField(machine.FieldPassword, field.TypeString, value) _node.Password = value } if value, ok := mc.mutation.IpAddress(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldIpAddress, - }) + _spec.SetField(machine.FieldIpAddress, field.TypeString, value) _node.IpAddress = value } if value, ok := mc.mutation.Scenarios(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldScenarios, - }) + _spec.SetField(machine.FieldScenarios, field.TypeString, value) _node.Scenarios = value } if value, ok := mc.mutation.Version(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldVersion, - }) + _spec.SetField(machine.FieldVersion, field.TypeString, value) _node.Version = value } if value, ok := mc.mutation.IsValidated(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - 
Column: machine.FieldIsValidated, - }) + _spec.SetField(machine.FieldIsValidated, field.TypeBool, value) _node.IsValidated = value } if value, ok := mc.mutation.Status(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldStatus, - }) + _spec.SetField(machine.FieldStatus, field.TypeString, value) _node.Status = value } if value, ok := mc.mutation.AuthType(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldAuthType, - }) + _spec.SetField(machine.FieldAuthType, field.TypeString, value) _node.AuthType = value } if nodes := mc.mutation.AlertsIDs(); len(nodes) > 0 { @@ -436,10 +345,7 @@ func (mc *MachineCreate) createSpec() (*Machine, *sqlgraph.CreateSpec) { Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -453,11 +359,15 @@ func (mc *MachineCreate) createSpec() (*Machine, *sqlgraph.CreateSpec) { // MachineCreateBulk is the builder for creating many Machine entities in bulk. type MachineCreateBulk struct { config + err error builders []*MachineCreate } // Save creates the Machine entities in the database. func (mcb *MachineCreateBulk) Save(ctx context.Context) ([]*Machine, error) { + if mcb.err != nil { + return nil, mcb.err + } specs := make([]*sqlgraph.CreateSpec, len(mcb.builders)) nodes := make([]*Machine, len(mcb.builders)) mutators := make([]Mutator, len(mcb.builders)) @@ -474,8 +384,8 @@ func (mcb *MachineCreateBulk) Save(ctx context.Context) ([]*Machine, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, mcb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/machine_delete.go b/pkg/database/ent/machine_delete.go index bead8acb46d..ac3aa751d5e 100644 --- a/pkg/database/ent/machine_delete.go +++ b/pkg/database/ent/machine_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (md *MachineDelete) Where(ps ...predicate.Machine) *MachineDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (md *MachineDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(md.hooks) == 0 { - affected, err = md.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MachineMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - md.mutation = mutation - affected, err = md.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(md.hooks) - 1; i >= 0; i-- { - if md.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = md.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, md.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, md.sqlExec, md.mutation, md.hooks) } // ExecX is like Exec, but panics if an error occurs. 
@@ -68,15 +40,7 @@ func (md *MachineDelete) ExecX(ctx context.Context) int { } func (md *MachineDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: machine.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(machine.Table, sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt)) if ps := md.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (md *MachineDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + md.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MachineDeleteOne struct { md *MachineDelete } +// Where appends a list predicates to the MachineDelete builder. +func (mdo *MachineDeleteOne) Where(ps ...predicate.Machine) *MachineDeleteOne { + mdo.md.mutation.Where(ps...) + return mdo +} + // Exec executes the deletion query. func (mdo *MachineDeleteOne) Exec(ctx context.Context) error { n, err := mdo.md.Exec(ctx) @@ -111,5 +82,7 @@ func (mdo *MachineDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mdo *MachineDeleteOne) ExecX(ctx context.Context) { - mdo.md.ExecX(ctx) + if err := mdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/machine_query.go b/pkg/database/ent/machine_query.go index 2839142196b..462c2cf35b1 100644 --- a/pkg/database/ent/machine_query.go +++ b/pkg/database/ent/machine_query.go @@ -19,11 +19,9 @@ import ( // MachineQuery is the builder for querying Machine entities. type MachineQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []machine.OrderOption + inters []Interceptor predicates []predicate.Machine withAlerts *AlertQuery // intermediate query (i.e. traversal path). @@ -37,34 +35,34 @@ func (mq *MachineQuery) Where(ps ...predicate.Machine) *MachineQuery { return mq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mq *MachineQuery) Limit(limit int) *MachineQuery { - mq.limit = &limit + mq.ctx.Limit = &limit return mq } -// Offset adds an offset step to the query. +// Offset to start from. func (mq *MachineQuery) Offset(offset int) *MachineQuery { - mq.offset = &offset + mq.ctx.Offset = &offset return mq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mq *MachineQuery) Unique(unique bool) *MachineQuery { - mq.unique = &unique + mq.ctx.Unique = &unique return mq } -// Order adds an order step to the query. -func (mq *MachineQuery) Order(o ...OrderFunc) *MachineQuery { +// Order specifies how the records should be ordered. +func (mq *MachineQuery) Order(o ...machine.OrderOption) *MachineQuery { mq.order = append(mq.order, o...) return mq } // QueryAlerts chains the current query on the "alerts" edge. 
func (mq *MachineQuery) QueryAlerts() *AlertQuery { - query := &AlertQuery{config: mq.config} + query := (&AlertClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -87,7 +85,7 @@ func (mq *MachineQuery) QueryAlerts() *AlertQuery { // First returns the first Machine entity from the query. // Returns a *NotFoundError when no Machine was found. func (mq *MachineQuery) First(ctx context.Context) (*Machine, error) { - nodes, err := mq.Limit(1).All(ctx) + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) if err != nil { return nil, err } @@ -110,7 +108,7 @@ func (mq *MachineQuery) FirstX(ctx context.Context) *Machine { // Returns a *NotFoundError when no Machine ID was found. func (mq *MachineQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(1).IDs(ctx); err != nil { + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -133,7 +131,7 @@ func (mq *MachineQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Machine entity is found. // Returns a *NotFoundError when no Machine entities are found. func (mq *MachineQuery) Only(ctx context.Context) (*Machine, error) { - nodes, err := mq.Limit(2).All(ctx) + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) if err != nil { return nil, err } @@ -161,7 +159,7 @@ func (mq *MachineQuery) OnlyX(ctx context.Context) *Machine { // Returns a *NotFoundError when no entities are found. func (mq *MachineQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(2).IDs(ctx); err != nil { + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -186,10 +184,12 @@ func (mq *MachineQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Machines. func (mq *MachineQuery) All(ctx context.Context) ([]*Machine, error) { + ctx = setContextOp(ctx, mq.ctx, "All") if err := mq.prepareQuery(ctx); err != nil { return nil, err } - return mq.sqlAll(ctx) + qr := querierAll[[]*Machine, *MachineQuery]() + return withInterceptors[[]*Machine](ctx, mq, qr, mq.inters) } // AllX is like All, but panics if an error occurs. @@ -202,9 +202,12 @@ func (mq *MachineQuery) AllX(ctx context.Context) []*Machine { } // IDs executes the query and returns a list of Machine IDs. -func (mq *MachineQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := mq.Select(machine.FieldID).Scan(ctx, &ids); err != nil { +func (mq *MachineQuery) IDs(ctx context.Context) (ids []int, err error) { + if mq.ctx.Unique == nil && mq.path != nil { + mq.Unique(true) + } + ctx = setContextOp(ctx, mq.ctx, "IDs") + if err = mq.Select(machine.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -221,10 +224,11 @@ func (mq *MachineQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (mq *MachineQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mq.ctx, "Count") if err := mq.prepareQuery(ctx); err != nil { return 0, err } - return mq.sqlCount(ctx) + return withInterceptors[int](ctx, mq, querierCount[*MachineQuery](), mq.inters) } // CountX is like Count, but panics if an error occurs. 
@@ -238,10 +242,15 @@ func (mq *MachineQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mq *MachineQuery) Exist(ctx context.Context) (bool, error) { - if err := mq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mq.ctx, "Exist") + switch _, err := mq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return mq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -261,22 +270,21 @@ func (mq *MachineQuery) Clone() *MachineQuery { } return &MachineQuery{ config: mq.config, - limit: mq.limit, - offset: mq.offset, - order: append([]OrderFunc{}, mq.order...), + ctx: mq.ctx.Clone(), + order: append([]machine.OrderOption{}, mq.order...), + inters: append([]Interceptor{}, mq.inters...), predicates: append([]predicate.Machine{}, mq.predicates...), withAlerts: mq.withAlerts.Clone(), // clone intermediate query. - sql: mq.sql.Clone(), - path: mq.path, - unique: mq.unique, + sql: mq.sql.Clone(), + path: mq.path, } } // WithAlerts tells the query-builder to eager-load the nodes that are connected to // the "alerts" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MachineQuery) WithAlerts(opts ...func(*AlertQuery)) *MachineQuery { - query := &AlertQuery{config: mq.config} + query := (&AlertClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -299,16 +307,11 @@ func (mq *MachineQuery) WithAlerts(opts ...func(*AlertQuery)) *MachineQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (mq *MachineQuery) GroupBy(field string, fields ...string) *MachineGroupBy { - grbuild := &MachineGroupBy{config: mq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mq.prepareQuery(ctx); err != nil { - return nil, err - } - return mq.sqlQuery(ctx), nil - } + mq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MachineGroupBy{build: mq} + grbuild.flds = &mq.ctx.Fields grbuild.label = machine.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -325,15 +328,30 @@ func (mq *MachineQuery) GroupBy(field string, fields ...string) *MachineGroupBy // Select(machine.FieldCreatedAt). // Scan(ctx, &v) func (mq *MachineQuery) Select(fields ...string) *MachineSelect { - mq.fields = append(mq.fields, fields...) - selbuild := &MachineSelect{MachineQuery: mq} - selbuild.label = machine.Label - selbuild.flds, selbuild.scan = &mq.fields, selbuild.Scan - return selbuild + mq.ctx.Fields = append(mq.ctx.Fields, fields...) + sbuild := &MachineSelect{MachineQuery: mq} + sbuild.label = machine.Label + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MachineSelect configured with the given aggregations. +func (mq *MachineQuery) Aggregate(fns ...AggregateFunc) *MachineSelect { + return mq.Select().Aggregate(fns...) 
} func (mq *MachineQuery) prepareQuery(ctx context.Context) error { - for _, f := range mq.fields { + for _, inter := range mq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mq); err != nil { + return err + } + } + } + for _, f := range mq.ctx.Fields { if !machine.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -396,7 +414,7 @@ func (mq *MachineQuery) loadAlerts(ctx context.Context, query *AlertQuery, nodes } query.withFKs = true query.Where(predicate.Alert(func(s *sql.Selector) { - s.Where(sql.InValues(machine.AlertsColumn, fks...)) + s.Where(sql.InValues(s.C(machine.AlertsColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -409,7 +427,7 @@ func (mq *MachineQuery) loadAlerts(ctx context.Context, query *AlertQuery, nodes } node, ok := nodeids[*fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "machine_alerts" returned %v for node %v`, *fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "machine_alerts" returned %v for node %v`, *fk, n.ID) } assign(node, n) } @@ -418,41 +436,22 @@ func (mq *MachineQuery) loadAlerts(ctx context.Context, query *AlertQuery, nodes func (mq *MachineQuery) sqlCount(ctx context.Context) (int, error) { _spec := mq.querySpec() - _spec.Node.Columns = mq.fields - if len(mq.fields) > 0 { - _spec.Unique = mq.unique != nil && *mq.unique + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique } return sqlgraph.CountNodes(ctx, mq.driver, _spec) } -func (mq *MachineQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := mq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (mq *MachineQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: machine.Table, - Columns: machine.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, - }, - From: mq.sql, - Unique: true, - } - if unique := mq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(machine.Table, machine.Columns, sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt)) + _spec.From = mq.sql + if unique := mq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mq.path != nil { + _spec.Unique = true } - if fields := mq.fields; len(fields) > 0 { + if fields := mq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, machine.FieldID) for i := range fields { @@ -468,10 +467,10 @@ func (mq *MachineQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mq.order; len(ps) > 0 { @@ -487,7 +486,7 @@ func (mq *MachineQuery) querySpec() *sqlgraph.QuerySpec { func (mq *MachineQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mq.driver.Dialect()) t1 := builder.Table(machine.Table) - columns := mq.fields + columns := mq.ctx.Fields if len(columns) == 0 { columns = machine.Columns } @@ -496,7 +495,7 @@ func (mq *MachineQuery) sqlQuery(ctx context.Context) 
*sql.Selector { selector = mq.sql selector.Select(selector.Columns(columns...)...) } - if mq.unique != nil && *mq.unique { + if mq.ctx.Unique != nil && *mq.ctx.Unique { selector.Distinct() } for _, p := range mq.predicates { @@ -505,12 +504,12 @@ func (mq *MachineQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mq.order { p(selector) } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -518,13 +517,8 @@ func (mq *MachineQuery) sqlQuery(ctx context.Context) *sql.Selector { // MachineGroupBy is the group-by builder for Machine entities. type MachineGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MachineQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -533,74 +527,77 @@ func (mgb *MachineGroupBy) Aggregate(fns ...AggregateFunc) *MachineGroupBy { return mgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (mgb *MachineGroupBy) Scan(ctx context.Context, v any) error { - query, err := mgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") + if err := mgb.build.prepareQuery(ctx); err != nil { return err } - mgb.sql = query - return mgb.sqlScan(ctx, v) + return scanWithInterceptors[*MachineQuery, *MachineGroupBy](ctx, mgb.build, mgb, mgb.build.inters, v) } -func (mgb *MachineGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range mgb.fields { - if !machine.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mgb *MachineGroupBy) sqlScan(ctx context.Context, root *MachineQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mgb.fns)) + for _, fn := range mgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mgb.flds)+len(mgb.fns)) + for _, f := range *mgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mgb.sqlQuery() + selector.GroupBy(selector.Columns(*mgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mgb *MachineGroupBy) sqlQuery() *sql.Selector { - selector := mgb.sql.Select() - aggregation := make([]string, 0, len(mgb.fns)) - for _, fn := range mgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. 
- if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mgb.fields)+len(mgb.fns)) - for _, f := range mgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mgb.fields...)...) -} - // MachineSelect is the builder for selecting fields of Machine entities. type MachineSelect struct { *MachineQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ms *MachineSelect) Aggregate(fns ...AggregateFunc) *MachineSelect { + ms.fns = append(ms.fns, fns...) + return ms } // Scan applies the selector query and scans the result into the given value. func (ms *MachineSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ms.ctx, "Select") if err := ms.prepareQuery(ctx); err != nil { return err } - ms.sql = ms.MachineQuery.sqlQuery(ctx) - return ms.sqlScan(ctx, v) + return scanWithInterceptors[*MachineQuery, *MachineSelect](ctx, ms.MachineQuery, ms, ms.inters, v) } -func (ms *MachineSelect) sqlScan(ctx context.Context, v any) error { +func (ms *MachineSelect) sqlScan(ctx context.Context, root *MachineQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ms.fns)) + for _, fn := range ms.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ms.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ms.sql.Query() + query, args := selector.Query() if err := ms.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/machine_update.go b/pkg/database/ent/machine_update.go index de9f8d12460..eb517081174 100644 --- a/pkg/database/ent/machine_update.go +++ b/pkg/database/ent/machine_update.go @@ -226,41 +226,8 @@ func (mu *MachineUpdate) RemoveAlerts(a ...*Alert) *MachineUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (mu *MachineUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) mu.defaults() - if len(mu.hooks) == 0 { - if err = mu.check(); err != nil { - return 0, err - } - affected, err = mu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MachineMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mu.check(); err != nil { - return 0, err - } - mu.mutation = mutation - affected, err = mu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mu.hooks) - 1; i >= 0; i-- { - if mu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, mu.sqlSave, mu.mutation, mu.hooks) } // SaveX is like Save, but panics if an error occurs. 
@@ -316,16 +283,10 @@ func (mu *MachineUpdate) check() error { } func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: machine.Table, - Columns: machine.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, - }, + if err := mu.check(); err != nil { + return n, err } + _spec := sqlgraph.NewUpdateSpec(machine.Table, machine.Columns, sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt)) if ps := mu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -334,130 +295,61 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := mu.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldCreatedAt, - }) + _spec.SetField(machine.FieldCreatedAt, field.TypeTime, value) } if mu.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldCreatedAt, - }) + _spec.ClearField(machine.FieldCreatedAt, field.TypeTime) } if value, ok := mu.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldUpdatedAt, - }) + _spec.SetField(machine.FieldUpdatedAt, field.TypeTime, value) } if mu.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldUpdatedAt, - }) + _spec.ClearField(machine.FieldUpdatedAt, field.TypeTime) } if value, ok := mu.mutation.LastPush(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastPush, - }) + _spec.SetField(machine.FieldLastPush, field.TypeTime, value) } if mu.mutation.LastPushCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldLastPush, - }) + _spec.ClearField(machine.FieldLastPush, field.TypeTime) } if value, ok := mu.mutation.LastHeartbeat(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastHeartbeat, - }) + _spec.SetField(machine.FieldLastHeartbeat, field.TypeTime, value) } if mu.mutation.LastHeartbeatCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldLastHeartbeat, - }) + _spec.ClearField(machine.FieldLastHeartbeat, field.TypeTime) } if value, ok := mu.mutation.MachineId(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldMachineId, - }) + _spec.SetField(machine.FieldMachineId, field.TypeString, value) } if value, ok := mu.mutation.Password(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldPassword, - }) + _spec.SetField(machine.FieldPassword, field.TypeString, value) } if value, ok := mu.mutation.IpAddress(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldIpAddress, - }) + _spec.SetField(machine.FieldIpAddress, field.TypeString, value) } if value, ok := mu.mutation.Scenarios(); ok { - 
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldScenarios, - }) + _spec.SetField(machine.FieldScenarios, field.TypeString, value) } if mu.mutation.ScenariosCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldScenarios, - }) + _spec.ClearField(machine.FieldScenarios, field.TypeString) } if value, ok := mu.mutation.Version(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldVersion, - }) + _spec.SetField(machine.FieldVersion, field.TypeString, value) } if mu.mutation.VersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldVersion, - }) + _spec.ClearField(machine.FieldVersion, field.TypeString) } if value, ok := mu.mutation.IsValidated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: machine.FieldIsValidated, - }) + _spec.SetField(machine.FieldIsValidated, field.TypeBool, value) } if value, ok := mu.mutation.Status(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldStatus, - }) + _spec.SetField(machine.FieldStatus, field.TypeString, value) } if mu.mutation.StatusCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldStatus, - }) + _spec.ClearField(machine.FieldStatus, field.TypeString) } if value, ok := mu.mutation.AuthType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldAuthType, - }) + _spec.SetField(machine.FieldAuthType, field.TypeString, value) } if mu.mutation.AlertsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -467,10 +359,7 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -483,10 +372,7 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -502,10 +388,7 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -521,6 +404,7 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mu.mutation.done = true return n, nil } @@ -727,6 +611,12 @@ func (muo *MachineUpdateOne) RemoveAlerts(a ...*Alert) *MachineUpdateOne { return muo.RemoveAlertIDs(ids...) } +// Where appends a list predicates to the MachineUpdate builder. 
+func (muo *MachineUpdateOne) Where(ps ...predicate.Machine) *MachineUpdateOne { + muo.mutation.Where(ps...) + return muo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (muo *MachineUpdateOne) Select(field string, fields ...string) *MachineUpdateOne { @@ -736,47 +626,8 @@ func (muo *MachineUpdateOne) Select(field string, fields ...string) *MachineUpda // Save executes the query and returns the updated Machine entity. func (muo *MachineUpdateOne) Save(ctx context.Context) (*Machine, error) { - var ( - err error - node *Machine - ) muo.defaults() - if len(muo.hooks) == 0 { - if err = muo.check(); err != nil { - return nil, err - } - node, err = muo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MachineMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = muo.check(); err != nil { - return nil, err - } - muo.mutation = mutation - node, err = muo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(muo.hooks) - 1; i >= 0; i-- { - if muo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = muo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, muo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Machine) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MachineMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, muo.sqlSave, muo.mutation, muo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -832,16 +683,10 @@ func (muo *MachineUpdateOne) check() error { } func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: machine.Table, - Columns: machine.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, - }, + if err := muo.check(); err != nil { + return _node, err } + _spec := sqlgraph.NewUpdateSpec(machine.Table, machine.Columns, sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt)) id, ok := muo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Machine.id" for update`)} @@ -867,130 +712,61 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e } } if value, ok := muo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldCreatedAt, - }) + _spec.SetField(machine.FieldCreatedAt, field.TypeTime, value) } if muo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldCreatedAt, - }) + _spec.ClearField(machine.FieldCreatedAt, field.TypeTime) } if value, ok := muo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldUpdatedAt, - }) + _spec.SetField(machine.FieldUpdatedAt, field.TypeTime, value) } if muo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldUpdatedAt, - }) + _spec.ClearField(machine.FieldUpdatedAt, field.TypeTime) } if value, ok := muo.mutation.LastPush(); ok { - _spec.Fields.Set = 
append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastPush, - }) + _spec.SetField(machine.FieldLastPush, field.TypeTime, value) } if muo.mutation.LastPushCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldLastPush, - }) + _spec.ClearField(machine.FieldLastPush, field.TypeTime) } if value, ok := muo.mutation.LastHeartbeat(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastHeartbeat, - }) + _spec.SetField(machine.FieldLastHeartbeat, field.TypeTime, value) } if muo.mutation.LastHeartbeatCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldLastHeartbeat, - }) + _spec.ClearField(machine.FieldLastHeartbeat, field.TypeTime) } if value, ok := muo.mutation.MachineId(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldMachineId, - }) + _spec.SetField(machine.FieldMachineId, field.TypeString, value) } if value, ok := muo.mutation.Password(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldPassword, - }) + _spec.SetField(machine.FieldPassword, field.TypeString, value) } if value, ok := muo.mutation.IpAddress(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldIpAddress, - }) + _spec.SetField(machine.FieldIpAddress, field.TypeString, value) } if value, ok := muo.mutation.Scenarios(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldScenarios, - }) + _spec.SetField(machine.FieldScenarios, field.TypeString, value) } if muo.mutation.ScenariosCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldScenarios, - }) + _spec.ClearField(machine.FieldScenarios, field.TypeString) } if value, ok := muo.mutation.Version(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldVersion, - }) + _spec.SetField(machine.FieldVersion, field.TypeString, value) } if muo.mutation.VersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldVersion, - }) + _spec.ClearField(machine.FieldVersion, field.TypeString) } if value, ok := muo.mutation.IsValidated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: machine.FieldIsValidated, - }) + _spec.SetField(machine.FieldIsValidated, field.TypeBool, value) } if value, ok := muo.mutation.Status(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldStatus, - }) + _spec.SetField(machine.FieldStatus, field.TypeString, value) } if muo.mutation.StatusCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldStatus, - }) + _spec.ClearField(machine.FieldStatus, field.TypeString) } if value, ok := muo.mutation.AuthType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, 
&sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldAuthType, - }) + _spec.SetField(machine.FieldAuthType, field.TypeString, value) } if muo.mutation.AlertsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -1000,10 +776,7 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1016,10 +789,7 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1035,10 +805,7 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1057,5 +824,6 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e } return nil, err } + muo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/meta.go b/pkg/database/ent/meta.go index 660f1a4db73..cadc210937e 100644 --- a/pkg/database/ent/meta.go +++ b/pkg/database/ent/meta.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" @@ -29,7 +30,8 @@ type Meta struct { AlertMetas int `json:"alert_metas,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the MetaQuery when eager-loading is set. - Edges MetaEdges `json:"edges"` + Edges MetaEdges `json:"edges"` + selectValues sql.SelectValues } // MetaEdges holds the relations/edges for other nodes in the graph. @@ -66,7 +68,7 @@ func (*Meta) scanValues(columns []string) ([]any, error) { case meta.FieldCreatedAt, meta.FieldUpdatedAt: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type Meta", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -118,21 +120,29 @@ func (m *Meta) assignValues(columns []string, values []any) error { } else if value.Valid { m.AlertMetas = int(value.Int64) } + default: + m.selectValues.Set(columns[i], values[i]) } } return nil } +// GetValue returns the ent.Value that was dynamically selected and assigned to the Meta. +// This includes values selected through modifiers, order, etc. +func (m *Meta) GetValue(name string) (ent.Value, error) { + return m.selectValues.Get(name) +} + // QueryOwner queries the "owner" edge of the Meta entity. func (m *Meta) QueryOwner() *AlertQuery { - return (&MetaClient{config: m.config}).QueryOwner(m) + return NewMetaClient(m.config).QueryOwner(m) } // Update returns a builder for updating this Meta. // Note that you need to call Meta.Unwrap() before calling this method if this Meta // was returned from a transaction, and the transaction was committed or rolled back. 
func (m *Meta) Update() *MetaUpdateOne { - return (&MetaClient{config: m.config}).UpdateOne(m) + return NewMetaClient(m.config).UpdateOne(m) } // Unwrap unwraps the Meta entity that was returned from a transaction after it was closed, @@ -175,9 +185,3 @@ func (m *Meta) String() string { // MetaSlice is a parsable slice of Meta. type MetaSlice []*Meta - -func (m MetaSlice) config(cfg config) { - for _i := range m { - m[_i].config = cfg - } -} diff --git a/pkg/database/ent/meta/meta.go b/pkg/database/ent/meta/meta.go index 6d10f258919..583496fb710 100644 --- a/pkg/database/ent/meta/meta.go +++ b/pkg/database/ent/meta/meta.go @@ -4,6 +4,9 @@ package meta import ( "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -66,3 +69,50 @@ var ( // ValueValidator is a validator for the "value" field. It is called by the builders before save. ValueValidator func(string) error ) + +// OrderOption defines the ordering options for the Meta queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByKey orders the results by the key field. +func ByKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldKey, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} + +// ByAlertMetas orders the results by the alert_metas field. +func ByAlertMetas(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAlertMetas, opts...).ToFunc() +} + +// ByOwnerField orders the results by owner field. +func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...)) + } +} +func newOwnerStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(OwnerInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) +} diff --git a/pkg/database/ent/meta/where.go b/pkg/database/ent/meta/where.go index 479792fd4a6..7fc99136972 100644 --- a/pkg/database/ent/meta/where.go +++ b/pkg/database/ent/meta/where.go @@ -12,512 +12,332 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. 
func IDIn(ids ...int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldUpdatedAt, v)) } // Key applies equality check predicate on the "key" field. It's identical to KeyEQ. func Key(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldKey, v)) } // Value applies equality check predicate on the "value" field. It's identical to ValueEQ. func Value(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldValue, v)) } // AlertMetas applies equality check predicate on the "alert_metas" field. It's identical to AlertMetasEQ. func AlertMetas(v int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertMetas), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldAlertMetas, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. 
func CreatedAtIn(vs ...time.Time) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldLTE(FieldCreatedAt, v)) } // CreatedAtIsNil applies the IsNil predicate on the "created_at" field. func CreatedAtIsNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) + return predicate.Meta(sql.FieldIsNull(FieldCreatedAt)) } // CreatedAtNotNil applies the NotNil predicate on the "created_at" field. func CreatedAtNotNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Meta(sql.FieldNotNull(FieldCreatedAt)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. 
func UpdatedAtNotIn(vs ...time.Time) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldLTE(FieldUpdatedAt, v)) } // UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. func UpdatedAtIsNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) + return predicate.Meta(sql.FieldIsNull(FieldUpdatedAt)) } // UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. func UpdatedAtNotNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Meta(sql.FieldNotNull(FieldUpdatedAt)) } // KeyEQ applies the EQ predicate on the "key" field. func KeyEQ(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldKey, v)) } // KeyNEQ applies the NEQ predicate on the "key" field. func KeyNEQ(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldNEQ(FieldKey, v)) } // KeyIn applies the In predicate on the "key" field. func KeyIn(vs ...string) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldKey), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldKey, vs...)) } // KeyNotIn applies the NotIn predicate on the "key" field. func KeyNotIn(vs ...string) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldKey), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldKey, vs...)) } // KeyGT applies the GT predicate on the "key" field. func KeyGT(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldGT(FieldKey, v)) } // KeyGTE applies the GTE predicate on the "key" field. func KeyGTE(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldGTE(FieldKey, v)) } // KeyLT applies the LT predicate on the "key" field. 
func KeyLT(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldLT(FieldKey, v)) } // KeyLTE applies the LTE predicate on the "key" field. func KeyLTE(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldLTE(FieldKey, v)) } // KeyContains applies the Contains predicate on the "key" field. func KeyContains(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldContains(FieldKey, v)) } // KeyHasPrefix applies the HasPrefix predicate on the "key" field. func KeyHasPrefix(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldHasPrefix(FieldKey, v)) } // KeyHasSuffix applies the HasSuffix predicate on the "key" field. func KeyHasSuffix(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldHasSuffix(FieldKey, v)) } // KeyEqualFold applies the EqualFold predicate on the "key" field. func KeyEqualFold(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldEqualFold(FieldKey, v)) } // KeyContainsFold applies the ContainsFold predicate on the "key" field. func KeyContainsFold(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldContainsFold(FieldKey, v)) } // ValueEQ applies the EQ predicate on the "value" field. func ValueEQ(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "value" field. func ValueNEQ(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "value" field. func ValueIn(vs ...string) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "value" field. func ValueNotIn(vs ...string) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "value" field. func ValueGT(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "value" field. func ValueGTE(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "value" field. 
func ValueLT(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "value" field. func ValueLTE(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "value" field. func ValueContains(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "value" field. func ValueHasPrefix(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "value" field. func ValueHasSuffix(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "value" field. func ValueEqualFold(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "value" field. func ValueContainsFold(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldContainsFold(FieldValue, v)) } // AlertMetasEQ applies the EQ predicate on the "alert_metas" field. func AlertMetasEQ(v int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertMetas), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldAlertMetas, v)) } // AlertMetasNEQ applies the NEQ predicate on the "alert_metas" field. func AlertMetasNEQ(v int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAlertMetas), v)) - }) + return predicate.Meta(sql.FieldNEQ(FieldAlertMetas, v)) } // AlertMetasIn applies the In predicate on the "alert_metas" field. func AlertMetasIn(vs ...int) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAlertMetas), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldAlertMetas, vs...)) } // AlertMetasNotIn applies the NotIn predicate on the "alert_metas" field. func AlertMetasNotIn(vs ...int) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAlertMetas), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldAlertMetas, vs...)) } // AlertMetasIsNil applies the IsNil predicate on the "alert_metas" field. func AlertMetasIsNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldAlertMetas))) - }) + return predicate.Meta(sql.FieldIsNull(FieldAlertMetas)) } // AlertMetasNotNil applies the NotNil predicate on the "alert_metas" field. 
func AlertMetasNotNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldAlertMetas))) - }) + return predicate.Meta(sql.FieldNotNull(FieldAlertMetas)) } // HasOwner applies the HasEdge predicate on the "owner" edge. @@ -525,7 +345,6 @@ func HasOwner() predicate.Meta { return predicate.Meta(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), ) sqlgraph.HasNeighbors(s, step) @@ -535,11 +354,7 @@ func HasOwner() predicate.Meta { // HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates). func HasOwnerWith(preds ...predicate.Alert) predicate.Meta { return predicate.Meta(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), - ) + step := newOwnerStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -550,32 +365,15 @@ func HasOwnerWith(preds ...predicate.Alert) predicate.Meta { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Meta) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Meta(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Meta) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Meta(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Meta) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Meta(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/meta_create.go b/pkg/database/ent/meta_create.go index df4f6315911..3bf30f0def9 100644 --- a/pkg/database/ent/meta_create.go +++ b/pkg/database/ent/meta_create.go @@ -101,50 +101,8 @@ func (mc *MetaCreate) Mutation() *MetaMutation { // Save creates the Meta in the database. 
func (mc *MetaCreate) Save(ctx context.Context) (*Meta, error) { - var ( - err error - node *Meta - ) mc.defaults() - if len(mc.hooks) == 0 { - if err = mc.check(); err != nil { - return nil, err - } - node, err = mc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MetaMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mc.check(); err != nil { - return nil, err - } - mc.mutation = mutation - if node, err = mc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mc.hooks) - 1; i >= 0; i-- { - if mc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Meta) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MetaMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, mc.sqlSave, mc.mutation, mc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -198,6 +156,9 @@ func (mc *MetaCreate) check() error { } func (mc *MetaCreate) sqlSave(ctx context.Context) (*Meta, error) { + if err := mc.check(); err != nil { + return nil, err + } _node, _spec := mc.createSpec() if err := sqlgraph.CreateNode(ctx, mc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -207,50 +168,30 @@ func (mc *MetaCreate) sqlSave(ctx context.Context) (*Meta, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + mc.mutation.id = &_node.ID + mc.mutation.done = true return _node, nil } func (mc *MetaCreate) createSpec() (*Meta, *sqlgraph.CreateSpec) { var ( _node = &Meta{config: mc.config} - _spec = &sqlgraph.CreateSpec{ - Table: meta.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(meta.Table, sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt)) ) if value, ok := mc.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldCreatedAt, - }) + _spec.SetField(meta.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = &value } if value, ok := mc.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldUpdatedAt, - }) + _spec.SetField(meta.FieldUpdatedAt, field.TypeTime, value) _node.UpdatedAt = &value } if value, ok := mc.mutation.Key(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: meta.FieldKey, - }) + _spec.SetField(meta.FieldKey, field.TypeString, value) _node.Key = value } if value, ok := mc.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: meta.FieldValue, - }) + _spec.SetField(meta.FieldValue, field.TypeString, value) _node.Value = value } if nodes := mc.mutation.OwnerIDs(); len(nodes) > 0 { @@ -261,10 +202,7 @@ func (mc *MetaCreate) createSpec() (*Meta, *sqlgraph.CreateSpec) { Columns: []string{meta.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -279,11 +217,15 @@ func (mc 
*MetaCreate) createSpec() (*Meta, *sqlgraph.CreateSpec) { // MetaCreateBulk is the builder for creating many Meta entities in bulk. type MetaCreateBulk struct { config + err error builders []*MetaCreate } // Save creates the Meta entities in the database. func (mcb *MetaCreateBulk) Save(ctx context.Context) ([]*Meta, error) { + if mcb.err != nil { + return nil, mcb.err + } specs := make([]*sqlgraph.CreateSpec, len(mcb.builders)) nodes := make([]*Meta, len(mcb.builders)) mutators := make([]Mutator, len(mcb.builders)) @@ -300,8 +242,8 @@ func (mcb *MetaCreateBulk) Save(ctx context.Context) ([]*Meta, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, mcb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/meta_delete.go b/pkg/database/ent/meta_delete.go index e1e49d2acdc..ee25dd07eb9 100644 --- a/pkg/database/ent/meta_delete.go +++ b/pkg/database/ent/meta_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (md *MetaDelete) Where(ps ...predicate.Meta) *MetaDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (md *MetaDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(md.hooks) == 0 { - affected, err = md.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MetaMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - md.mutation = mutation - affected, err = md.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(md.hooks) - 1; i >= 0; i-- { - if md.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = md.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, md.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, md.sqlExec, md.mutation, md.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (md *MetaDelete) ExecX(ctx context.Context) int { } func (md *MetaDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: meta.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(meta.Table, sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt)) if ps := md.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (md *MetaDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + md.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MetaDeleteOne struct { md *MetaDelete } +// Where appends a list predicates to the MetaDelete builder. +func (mdo *MetaDeleteOne) Where(ps ...predicate.Meta) *MetaDeleteOne { + mdo.md.mutation.Where(ps...) + return mdo +} + // Exec executes the deletion query. func (mdo *MetaDeleteOne) Exec(ctx context.Context) error { n, err := mdo.md.Exec(ctx) @@ -111,5 +82,7 @@ func (mdo *MetaDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. 
func (mdo *MetaDeleteOne) ExecX(ctx context.Context) { - mdo.md.ExecX(ctx) + if err := mdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/meta_query.go b/pkg/database/ent/meta_query.go index d6fd4f3d522..87d91d09e0e 100644 --- a/pkg/database/ent/meta_query.go +++ b/pkg/database/ent/meta_query.go @@ -18,11 +18,9 @@ import ( // MetaQuery is the builder for querying Meta entities. type MetaQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []meta.OrderOption + inters []Interceptor predicates []predicate.Meta withOwner *AlertQuery // intermediate query (i.e. traversal path). @@ -36,34 +34,34 @@ func (mq *MetaQuery) Where(ps ...predicate.Meta) *MetaQuery { return mq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mq *MetaQuery) Limit(limit int) *MetaQuery { - mq.limit = &limit + mq.ctx.Limit = &limit return mq } -// Offset adds an offset step to the query. +// Offset to start from. func (mq *MetaQuery) Offset(offset int) *MetaQuery { - mq.offset = &offset + mq.ctx.Offset = &offset return mq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mq *MetaQuery) Unique(unique bool) *MetaQuery { - mq.unique = &unique + mq.ctx.Unique = &unique return mq } -// Order adds an order step to the query. -func (mq *MetaQuery) Order(o ...OrderFunc) *MetaQuery { +// Order specifies how the records should be ordered. +func (mq *MetaQuery) Order(o ...meta.OrderOption) *MetaQuery { mq.order = append(mq.order, o...) return mq } // QueryOwner chains the current query on the "owner" edge. func (mq *MetaQuery) QueryOwner() *AlertQuery { - query := &AlertQuery{config: mq.config} + query := (&AlertClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -86,7 +84,7 @@ func (mq *MetaQuery) QueryOwner() *AlertQuery { // First returns the first Meta entity from the query. // Returns a *NotFoundError when no Meta was found. func (mq *MetaQuery) First(ctx context.Context) (*Meta, error) { - nodes, err := mq.Limit(1).All(ctx) + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) if err != nil { return nil, err } @@ -109,7 +107,7 @@ func (mq *MetaQuery) FirstX(ctx context.Context) *Meta { // Returns a *NotFoundError when no Meta ID was found. func (mq *MetaQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(1).IDs(ctx); err != nil { + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -132,7 +130,7 @@ func (mq *MetaQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Meta entity is found. // Returns a *NotFoundError when no Meta entities are found. func (mq *MetaQuery) Only(ctx context.Context) (*Meta, error) { - nodes, err := mq.Limit(2).All(ctx) + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) if err != nil { return nil, err } @@ -160,7 +158,7 @@ func (mq *MetaQuery) OnlyX(ctx context.Context) *Meta { // Returns a *NotFoundError when no entities are found. 
func (mq *MetaQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(2).IDs(ctx); err != nil { + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -185,10 +183,12 @@ func (mq *MetaQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of MetaSlice. func (mq *MetaQuery) All(ctx context.Context) ([]*Meta, error) { + ctx = setContextOp(ctx, mq.ctx, "All") if err := mq.prepareQuery(ctx); err != nil { return nil, err } - return mq.sqlAll(ctx) + qr := querierAll[[]*Meta, *MetaQuery]() + return withInterceptors[[]*Meta](ctx, mq, qr, mq.inters) } // AllX is like All, but panics if an error occurs. @@ -201,9 +201,12 @@ func (mq *MetaQuery) AllX(ctx context.Context) []*Meta { } // IDs executes the query and returns a list of Meta IDs. -func (mq *MetaQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := mq.Select(meta.FieldID).Scan(ctx, &ids); err != nil { +func (mq *MetaQuery) IDs(ctx context.Context) (ids []int, err error) { + if mq.ctx.Unique == nil && mq.path != nil { + mq.Unique(true) + } + ctx = setContextOp(ctx, mq.ctx, "IDs") + if err = mq.Select(meta.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -220,10 +223,11 @@ func (mq *MetaQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (mq *MetaQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mq.ctx, "Count") if err := mq.prepareQuery(ctx); err != nil { return 0, err } - return mq.sqlCount(ctx) + return withInterceptors[int](ctx, mq, querierCount[*MetaQuery](), mq.inters) } // CountX is like Count, but panics if an error occurs. @@ -237,10 +241,15 @@ func (mq *MetaQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mq *MetaQuery) Exist(ctx context.Context) (bool, error) { - if err := mq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mq.ctx, "Exist") + switch _, err := mq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return mq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -260,22 +269,21 @@ func (mq *MetaQuery) Clone() *MetaQuery { } return &MetaQuery{ config: mq.config, - limit: mq.limit, - offset: mq.offset, - order: append([]OrderFunc{}, mq.order...), + ctx: mq.ctx.Clone(), + order: append([]meta.OrderOption{}, mq.order...), + inters: append([]Interceptor{}, mq.inters...), predicates: append([]predicate.Meta{}, mq.predicates...), withOwner: mq.withOwner.Clone(), // clone intermediate query. - sql: mq.sql.Clone(), - path: mq.path, - unique: mq.unique, + sql: mq.sql.Clone(), + path: mq.path, } } // WithOwner tells the query-builder to eager-load the nodes that are connected to // the "owner" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MetaQuery) WithOwner(opts ...func(*AlertQuery)) *MetaQuery { - query := &AlertQuery{config: mq.config} + query := (&AlertClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -298,16 +306,11 @@ func (mq *MetaQuery) WithOwner(opts ...func(*AlertQuery)) *MetaQuery { // Aggregate(ent.Count()). 
// Scan(ctx, &v) func (mq *MetaQuery) GroupBy(field string, fields ...string) *MetaGroupBy { - grbuild := &MetaGroupBy{config: mq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mq.prepareQuery(ctx); err != nil { - return nil, err - } - return mq.sqlQuery(ctx), nil - } + mq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MetaGroupBy{build: mq} + grbuild.flds = &mq.ctx.Fields grbuild.label = meta.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -324,15 +327,30 @@ func (mq *MetaQuery) GroupBy(field string, fields ...string) *MetaGroupBy { // Select(meta.FieldCreatedAt). // Scan(ctx, &v) func (mq *MetaQuery) Select(fields ...string) *MetaSelect { - mq.fields = append(mq.fields, fields...) - selbuild := &MetaSelect{MetaQuery: mq} - selbuild.label = meta.Label - selbuild.flds, selbuild.scan = &mq.fields, selbuild.Scan - return selbuild + mq.ctx.Fields = append(mq.ctx.Fields, fields...) + sbuild := &MetaSelect{MetaQuery: mq} + sbuild.label = meta.Label + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MetaSelect configured with the given aggregations. +func (mq *MetaQuery) Aggregate(fns ...AggregateFunc) *MetaSelect { + return mq.Select().Aggregate(fns...) } func (mq *MetaQuery) prepareQuery(ctx context.Context) error { - for _, f := range mq.fields { + for _, inter := range mq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mq); err != nil { + return err + } + } + } + for _, f := range mq.ctx.Fields { if !meta.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -392,6 +410,9 @@ func (mq *MetaQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes []* } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(alert.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -411,41 +432,22 @@ func (mq *MetaQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes []* func (mq *MetaQuery) sqlCount(ctx context.Context) (int, error) { _spec := mq.querySpec() - _spec.Node.Columns = mq.fields - if len(mq.fields) > 0 { - _spec.Unique = mq.unique != nil && *mq.unique + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique } return sqlgraph.CountNodes(ctx, mq.driver, _spec) } -func (mq *MetaQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := mq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (mq *MetaQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: meta.Table, - Columns: meta.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, - }, - From: mq.sql, - Unique: true, - } - if unique := mq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(meta.Table, meta.Columns, sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt)) + _spec.From = mq.sql + if unique := mq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mq.path != nil { + _spec.Unique = true } - if fields := mq.fields; len(fields) > 0 { + if 
fields := mq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, meta.FieldID) for i := range fields { @@ -453,6 +455,9 @@ func (mq *MetaQuery) querySpec() *sqlgraph.QuerySpec { _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) } } + if mq.withOwner != nil { + _spec.Node.AddColumnOnce(meta.FieldAlertMetas) + } } if ps := mq.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { @@ -461,10 +466,10 @@ func (mq *MetaQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mq.order; len(ps) > 0 { @@ -480,7 +485,7 @@ func (mq *MetaQuery) querySpec() *sqlgraph.QuerySpec { func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mq.driver.Dialect()) t1 := builder.Table(meta.Table) - columns := mq.fields + columns := mq.ctx.Fields if len(columns) == 0 { columns = meta.Columns } @@ -489,7 +494,7 @@ func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mq.sql selector.Select(selector.Columns(columns...)...) } - if mq.unique != nil && *mq.unique { + if mq.ctx.Unique != nil && *mq.ctx.Unique { selector.Distinct() } for _, p := range mq.predicates { @@ -498,12 +503,12 @@ func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mq.order { p(selector) } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -511,13 +516,8 @@ func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector { // MetaGroupBy is the group-by builder for Meta entities. type MetaGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MetaQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -526,74 +526,77 @@ func (mgb *MetaGroupBy) Aggregate(fns ...AggregateFunc) *MetaGroupBy { return mgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. 
func (mgb *MetaGroupBy) Scan(ctx context.Context, v any) error { - query, err := mgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") + if err := mgb.build.prepareQuery(ctx); err != nil { return err } - mgb.sql = query - return mgb.sqlScan(ctx, v) + return scanWithInterceptors[*MetaQuery, *MetaGroupBy](ctx, mgb.build, mgb, mgb.build.inters, v) } -func (mgb *MetaGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range mgb.fields { - if !meta.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mgb *MetaGroupBy) sqlScan(ctx context.Context, root *MetaQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mgb.fns)) + for _, fn := range mgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mgb.flds)+len(mgb.fns)) + for _, f := range *mgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mgb.sqlQuery() + selector.GroupBy(selector.Columns(*mgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mgb *MetaGroupBy) sqlQuery() *sql.Selector { - selector := mgb.sql.Select() - aggregation := make([]string, 0, len(mgb.fns)) - for _, fn := range mgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mgb.fields)+len(mgb.fns)) - for _, f := range mgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mgb.fields...)...) -} - // MetaSelect is the builder for selecting fields of Meta entities. type MetaSelect struct { *MetaQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ms *MetaSelect) Aggregate(fns ...AggregateFunc) *MetaSelect { + ms.fns = append(ms.fns, fns...) + return ms } // Scan applies the selector query and scans the result into the given value. func (ms *MetaSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ms.ctx, "Select") if err := ms.prepareQuery(ctx); err != nil { return err } - ms.sql = ms.MetaQuery.sqlQuery(ctx) - return ms.sqlScan(ctx, v) + return scanWithInterceptors[*MetaQuery, *MetaSelect](ctx, ms.MetaQuery, ms, ms.inters, v) } -func (ms *MetaSelect) sqlScan(ctx context.Context, v any) error { +func (ms *MetaSelect) sqlScan(ctx context.Context, root *MetaQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ms.fns)) + for _, fn := range ms.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ms.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) 
+ } rows := &sql.Rows{} - query, args := ms.sql.Query() + query, args := selector.Query() if err := ms.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/meta_update.go b/pkg/database/ent/meta_update.go index 67a198dddfa..8071c4f0df5 100644 --- a/pkg/database/ent/meta_update.go +++ b/pkg/database/ent/meta_update.go @@ -117,41 +117,8 @@ func (mu *MetaUpdate) ClearOwner() *MetaUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (mu *MetaUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) mu.defaults() - if len(mu.hooks) == 0 { - if err = mu.check(); err != nil { - return 0, err - } - affected, err = mu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MetaMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mu.check(); err != nil { - return 0, err - } - mu.mutation = mutation - affected, err = mu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mu.hooks) - 1; i >= 0; i-- { - if mu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, mu.sqlSave, mu.mutation, mu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -199,16 +166,10 @@ func (mu *MetaUpdate) check() error { } func (mu *MetaUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: meta.Table, - Columns: meta.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, - }, + if err := mu.check(); err != nil { + return n, err } + _spec := sqlgraph.NewUpdateSpec(meta.Table, meta.Columns, sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt)) if ps := mu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -217,44 +178,22 @@ func (mu *MetaUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := mu.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldCreatedAt, - }) + _spec.SetField(meta.FieldCreatedAt, field.TypeTime, value) } if mu.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: meta.FieldCreatedAt, - }) + _spec.ClearField(meta.FieldCreatedAt, field.TypeTime) } if value, ok := mu.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldUpdatedAt, - }) + _spec.SetField(meta.FieldUpdatedAt, field.TypeTime, value) } if mu.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: meta.FieldUpdatedAt, - }) + _spec.ClearField(meta.FieldUpdatedAt, field.TypeTime) } if value, ok := mu.mutation.Key(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: meta.FieldKey, - }) + _spec.SetField(meta.FieldKey, field.TypeString, value) } if value, ok := mu.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: 
field.TypeString, - Value: value, - Column: meta.FieldValue, - }) + _spec.SetField(meta.FieldValue, field.TypeString, value) } if mu.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -264,10 +203,7 @@ func (mu *MetaUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{meta.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -280,10 +216,7 @@ func (mu *MetaUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{meta.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -299,6 +232,7 @@ func (mu *MetaUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mu.mutation.done = true return n, nil } @@ -396,6 +330,12 @@ func (muo *MetaUpdateOne) ClearOwner() *MetaUpdateOne { return muo } +// Where appends a list predicates to the MetaUpdate builder. +func (muo *MetaUpdateOne) Where(ps ...predicate.Meta) *MetaUpdateOne { + muo.mutation.Where(ps...) + return muo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (muo *MetaUpdateOne) Select(field string, fields ...string) *MetaUpdateOne { @@ -405,47 +345,8 @@ func (muo *MetaUpdateOne) Select(field string, fields ...string) *MetaUpdateOne // Save executes the query and returns the updated Meta entity. func (muo *MetaUpdateOne) Save(ctx context.Context) (*Meta, error) { - var ( - err error - node *Meta - ) muo.defaults() - if len(muo.hooks) == 0 { - if err = muo.check(); err != nil { - return nil, err - } - node, err = muo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MetaMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = muo.check(); err != nil { - return nil, err - } - muo.mutation = mutation - node, err = muo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(muo.hooks) - 1; i >= 0; i-- { - if muo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = muo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, muo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Meta) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MetaMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, muo.sqlSave, muo.mutation, muo.hooks) } // SaveX is like Save, but panics if an error occurs. 
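
A minimal usage sketch of the regenerated update builders shown above (not part of the diff; the helper name, the client/ctx parameters and the SetValue/KeyEQ helpers are assumptions based on the generated Meta schema, which has key and value string fields):

// Hypothetical helper, for illustration only.
// Assumed imports:
//
//	"context"
//
//	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
//	"github.com/crowdsecurity/crowdsec/pkg/database/ent/meta"
func updateMetaValue(ctx context.Context, client *ent.Client, key, value string) (int, error) {
	return client.Meta.Update().
		Where(meta.KeyEQ(key)). // predicate from the generated meta package (assumed)
		SetValue(value).        // persisted via _spec.SetField(meta.FieldValue, ...) in sqlSave
		Save(ctx)               // runs check() and sqlSave through withHooks
}
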
@@ -493,16 +394,10 @@ func (muo *MetaUpdateOne) check() error { } func (muo *MetaUpdateOne) sqlSave(ctx context.Context) (_node *Meta, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: meta.Table, - Columns: meta.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, - }, + if err := muo.check(); err != nil { + return _node, err } + _spec := sqlgraph.NewUpdateSpec(meta.Table, meta.Columns, sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt)) id, ok := muo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Meta.id" for update`)} @@ -528,44 +423,22 @@ func (muo *MetaUpdateOne) sqlSave(ctx context.Context) (_node *Meta, err error) } } if value, ok := muo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldCreatedAt, - }) + _spec.SetField(meta.FieldCreatedAt, field.TypeTime, value) } if muo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: meta.FieldCreatedAt, - }) + _spec.ClearField(meta.FieldCreatedAt, field.TypeTime) } if value, ok := muo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldUpdatedAt, - }) + _spec.SetField(meta.FieldUpdatedAt, field.TypeTime, value) } if muo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: meta.FieldUpdatedAt, - }) + _spec.ClearField(meta.FieldUpdatedAt, field.TypeTime) } if value, ok := muo.mutation.Key(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: meta.FieldKey, - }) + _spec.SetField(meta.FieldKey, field.TypeString, value) } if value, ok := muo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: meta.FieldValue, - }) + _spec.SetField(meta.FieldValue, field.TypeString, value) } if muo.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -575,10 +448,7 @@ func (muo *MetaUpdateOne) sqlSave(ctx context.Context) (_node *Meta, err error) Columns: []string{meta.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -591,10 +461,7 @@ func (muo *MetaUpdateOne) sqlSave(ctx context.Context) (_node *Meta, err error) Columns: []string{meta.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -613,5 +480,6 @@ func (muo *MetaUpdateOne) sqlSave(ctx context.Context) (_node *Meta, err error) } return nil, err } + muo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/mutation.go b/pkg/database/ent/mutation.go index 907c1ef015e..c5808d0d9b8 100644 --- a/pkg/database/ent/mutation.go +++ b/pkg/database/ent/mutation.go @@ -9,6 +9,8 @@ import ( "sync" "time" + "entgo.io/ent" + "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" 
"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" @@ -17,8 +19,6 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" - - "entgo.io/ent" ) const ( @@ -1578,11 +1578,26 @@ func (m *AlertMutation) Where(ps ...predicate.Alert) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the AlertMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AlertMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Alert, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *AlertMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *AlertMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Alert). func (m *AlertMutation) Type() string { return m.typ @@ -2997,11 +3012,26 @@ func (m *BouncerMutation) Where(ps ...predicate.Bouncer) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the BouncerMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *BouncerMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Bouncer, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *BouncerMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *BouncerMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Bouncer). func (m *BouncerMutation) Type() string { return m.typ @@ -3654,11 +3684,26 @@ func (m *ConfigItemMutation) Where(ps ...predicate.ConfigItem) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the ConfigItemMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *ConfigItemMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.ConfigItem, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *ConfigItemMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *ConfigItemMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (ConfigItem). func (m *ConfigItemMutation) Type() string { return m.typ @@ -4830,6 +4875,7 @@ func (m *DecisionMutation) SetOwnerID(id int) { // ClearOwner clears the "owner" edge to the Alert entity. func (m *DecisionMutation) ClearOwner() { m.clearedowner = true + m.clearedFields[decision.FieldAlertDecisions] = struct{}{} } // OwnerCleared reports if the "owner" edge to the Alert entity was cleared. @@ -4866,11 +4912,26 @@ func (m *DecisionMutation) Where(ps ...predicate.Decision) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the DecisionMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *DecisionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Decision, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. 
func (m *DecisionMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *DecisionMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Decision). func (m *DecisionMutation) Type() string { return m.typ @@ -5775,6 +5836,7 @@ func (m *EventMutation) SetOwnerID(id int) { // ClearOwner clears the "owner" edge to the Alert entity. func (m *EventMutation) ClearOwner() { m.clearedowner = true + m.clearedFields[event.FieldAlertEvents] = struct{}{} } // OwnerCleared reports if the "owner" edge to the Alert entity was cleared. @@ -5811,11 +5873,26 @@ func (m *EventMutation) Where(ps ...predicate.Event) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the EventMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *EventMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Event, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *EventMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *EventMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Event). func (m *EventMutation) Type() string { return m.typ @@ -6795,11 +6872,26 @@ func (m *MachineMutation) Where(ps ...predicate.Machine) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MachineMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MachineMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Machine, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *MachineMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *MachineMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Machine). func (m *MachineMutation) Type() string { return m.typ @@ -7565,6 +7657,7 @@ func (m *MetaMutation) SetOwnerID(id int) { // ClearOwner clears the "owner" edge to the Alert entity. func (m *MetaMutation) ClearOwner() { m.clearedowner = true + m.clearedFields[meta.FieldAlertMetas] = struct{}{} } // OwnerCleared reports if the "owner" edge to the Alert entity was cleared. @@ -7601,11 +7694,26 @@ func (m *MetaMutation) Where(ps ...predicate.Meta) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MetaMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MetaMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Meta, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *MetaMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *MetaMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Meta). 
func (m *MetaMutation) Type() string { return m.typ diff --git a/pkg/database/ent/runtime/runtime.go b/pkg/database/ent/runtime/runtime.go index e64f7bd7554..2a645f624d7 100644 --- a/pkg/database/ent/runtime/runtime.go +++ b/pkg/database/ent/runtime/runtime.go @@ -5,6 +5,6 @@ package runtime // The schema-stitching logic is generated in github.com/crowdsecurity/crowdsec/pkg/database/ent/runtime.go const ( - Version = "v0.11.3" // Version of ent codegen. - Sum = "h1:F5FBGAWiDCGder7YT+lqMnyzXl6d0xU3xMBM/SO3CMc=" // Sum of ent codegen. + Version = "v0.12.4" // Version of ent codegen. + Sum = "h1:LddPnAyxls/O7DTXZvUGDj0NZIdGSu317+aoNLJWbD8=" // Sum of ent codegen. ) diff --git a/pkg/database/ent/tx.go b/pkg/database/ent/tx.go index 2a1efd152a0..65c2ed00a44 100644 --- a/pkg/database/ent/tx.go +++ b/pkg/database/ent/tx.go @@ -30,12 +30,6 @@ type Tx struct { // lazily loaded. client *Client clientOnce sync.Once - - // completion callbacks. - mu sync.Mutex - onCommit []CommitHook - onRollback []RollbackHook - // ctx lives for the life of the transaction. It is // the same context used by the underlying connection. ctx context.Context @@ -80,9 +74,9 @@ func (tx *Tx) Commit() error { var fn Committer = CommitFunc(func(context.Context, *Tx) error { return txDriver.tx.Commit() }) - tx.mu.Lock() - hooks := append([]CommitHook(nil), tx.onCommit...) - tx.mu.Unlock() + txDriver.mu.Lock() + hooks := append([]CommitHook(nil), txDriver.onCommit...) + txDriver.mu.Unlock() for i := len(hooks) - 1; i >= 0; i-- { fn = hooks[i](fn) } @@ -91,9 +85,10 @@ func (tx *Tx) Commit() error { // OnCommit adds a hook to call on commit. func (tx *Tx) OnCommit(f CommitHook) { - tx.mu.Lock() - defer tx.mu.Unlock() - tx.onCommit = append(tx.onCommit, f) + txDriver := tx.config.driver.(*txDriver) + txDriver.mu.Lock() + txDriver.onCommit = append(txDriver.onCommit, f) + txDriver.mu.Unlock() } type ( @@ -135,9 +130,9 @@ func (tx *Tx) Rollback() error { var fn Rollbacker = RollbackFunc(func(context.Context, *Tx) error { return txDriver.tx.Rollback() }) - tx.mu.Lock() - hooks := append([]RollbackHook(nil), tx.onRollback...) - tx.mu.Unlock() + txDriver.mu.Lock() + hooks := append([]RollbackHook(nil), txDriver.onRollback...) + txDriver.mu.Unlock() for i := len(hooks) - 1; i >= 0; i-- { fn = hooks[i](fn) } @@ -146,9 +141,10 @@ func (tx *Tx) Rollback() error { // OnRollback adds a hook to call on rollback. func (tx *Tx) OnRollback(f RollbackHook) { - tx.mu.Lock() - defer tx.mu.Unlock() - tx.onRollback = append(tx.onRollback, f) + txDriver := tx.config.driver.(*txDriver) + txDriver.mu.Lock() + txDriver.onRollback = append(txDriver.onRollback, f) + txDriver.mu.Unlock() } // Client returns a Client that binds to current transaction. @@ -186,6 +182,10 @@ type txDriver struct { drv dialect.Driver // tx is the underlying transaction. tx dialect.Tx + // completion hooks. + mu sync.Mutex + onCommit []CommitHook + onRollback []RollbackHook } // newTx creates a new transactional driver. 
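
With the completion hooks now stored on txDriver (above), OnCommit/OnRollback hooks registered through any *ent.Tx handle bound to the same transaction share a single list; the caller-facing API is unchanged. A rough sketch, not part of the diff (the helper name and the use of the standard log package are assumptions):

// Hypothetical helper, for illustration only.
// Assumed imports: "context", "log" and the generated ent package.
func commitWithLogging(ctx context.Context, client *ent.Client) error {
	tx, err := client.Tx(ctx)
	if err != nil {
		return err
	}
	tx.OnCommit(func(next ent.Committer) ent.Committer {
		return ent.CommitFunc(func(ctx context.Context, tx *ent.Tx) error {
			log.Println("about to commit") // runs just before the underlying COMMIT
			return next.Commit(ctx, tx)
		})
	})
	// ... use the transactional clients on tx, then:
	return tx.Commit()
}
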
From 8de8bf0e0653beb80d847741b416ad51edd00272 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Wed, 14 Feb 2024 11:53:12 +0100 Subject: [PATCH 11/20] pkg/hubtest: extract methods + consistent error handling (#2756) * pkg/hubtest: extract methods + consistent error handling * lint * rename variables for further refactor --- pkg/hubtest/appsecrule.go | 80 ++++++++++++++++-------------- pkg/hubtest/parser.go | 99 +++++++++++++++++++------------------ pkg/hubtest/postoverflow.go | 98 ++++++++++++++++++------------------ pkg/hubtest/scenario.go | 78 +++++++++++++++-------------- 4 files changed, 185 insertions(+), 170 deletions(-) diff --git a/pkg/hubtest/appsecrule.go b/pkg/hubtest/appsecrule.go index 9b70e1441ac..fb4ad78cc18 100644 --- a/pkg/hubtest/appsecrule.go +++ b/pkg/hubtest/appsecrule.go @@ -11,75 +11,81 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func (t *HubTestItem) installAppsecRuleItem(hubAppsecRule *cwhub.Item) error { - appsecRuleSource, err := filepath.Abs(filepath.Join(t.HubPath, hubAppsecRule.RemotePath)) +func (t *HubTestItem) installAppsecRuleItem(item *cwhub.Item) error { + sourcePath, err := filepath.Abs(filepath.Join(t.HubPath, item.RemotePath)) if err != nil { - return fmt.Errorf("can't get absolute path of '%s': %s", appsecRuleSource, err) + return fmt.Errorf("can't get absolute path of '%s': %w", sourcePath, err) } - appsecRuleFilename := filepath.Base(appsecRuleSource) + sourceFilename := filepath.Base(sourcePath) // runtime/hub/appsec-rules/author/appsec-rule - hubDirAppsecRuleDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(hubAppsecRule.RemotePath)) + hubDirAppsecRuleDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(item.RemotePath)) // runtime/appsec-rules/ - appsecRuleDirDest := fmt.Sprintf("%s/appsec-rules/", t.RuntimePath) + itemTypeDirDest := fmt.Sprintf("%s/appsec-rules/", t.RuntimePath) if err := os.MkdirAll(hubDirAppsecRuleDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", hubDirAppsecRuleDest, err) + return fmt.Errorf("unable to create folder '%s': %w", hubDirAppsecRuleDest, err) } - if err := os.MkdirAll(appsecRuleDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", appsecRuleDirDest, err) + if err := os.MkdirAll(itemTypeDirDest, os.ModePerm); err != nil { + return fmt.Errorf("unable to create folder '%s': %w", itemTypeDirDest, err) } // runtime/hub/appsec-rules/crowdsecurity/rule.yaml - hubDirAppsecRulePath := filepath.Join(appsecRuleDirDest, appsecRuleFilename) - if err := Copy(appsecRuleSource, hubDirAppsecRulePath); err != nil { - return fmt.Errorf("unable to copy '%s' to '%s': %s", appsecRuleSource, hubDirAppsecRulePath, err) + hubDirAppsecRulePath := filepath.Join(itemTypeDirDest, sourceFilename) + if err := Copy(sourcePath, hubDirAppsecRulePath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %w", sourcePath, hubDirAppsecRulePath, err) } // runtime/appsec-rules/rule.yaml - appsecRulePath := filepath.Join(appsecRuleDirDest, appsecRuleFilename) + appsecRulePath := filepath.Join(itemTypeDirDest, sourceFilename) if err := os.Symlink(hubDirAppsecRulePath, appsecRulePath); err != nil { if !os.IsExist(err) { - return fmt.Errorf("unable to symlink appsec-rule '%s' to '%s': %s", hubDirAppsecRulePath, appsecRulePath, err) + return fmt.Errorf("unable to symlink appsec-rule '%s' to '%s': %w", hubDirAppsecRulePath, appsecRulePath, err) } } return nil } +func (t *HubTestItem) 
installAppsecRuleCustomFrom(appsecrule string, customPath string) (bool, error) { + // we check if its a custom appsec-rule + customAppsecRulePath := filepath.Join(customPath, appsecrule) + if _, err := os.Stat(customAppsecRulePath); os.IsNotExist(err) { + return false, nil + } + + customAppsecRulePathSplit := strings.Split(customAppsecRulePath, "/") + customAppsecRuleName := customAppsecRulePathSplit[len(customAppsecRulePathSplit)-1] + + itemTypeDirDest := fmt.Sprintf("%s/appsec-rules/", t.RuntimePath) + if err := os.MkdirAll(itemTypeDirDest, os.ModePerm); err != nil { + return false, fmt.Errorf("unable to create folder '%s': %w", itemTypeDirDest, err) + } + + customAppsecRuleDest := fmt.Sprintf("%s/appsec-rules/%s", t.RuntimePath, customAppsecRuleName) + if err := Copy(customAppsecRulePath, customAppsecRuleDest); err != nil { + return false, fmt.Errorf("unable to copy appsec-rule from '%s' to '%s': %w", customAppsecRulePath, customAppsecRuleDest, err) + } + + return true, nil +} + func (t *HubTestItem) installAppsecRuleCustom(appsecrule string) error { - customAppsecRuleExist := false for _, customPath := range t.CustomItemsLocation { - // we check if its a custom appsec-rule - customAppsecRulePath := filepath.Join(customPath, appsecrule) - if _, err := os.Stat(customAppsecRulePath); os.IsNotExist(err) { - continue - } - customAppsecRulePathSplit := strings.Split(customAppsecRulePath, "/") - customAppsecRuleName := customAppsecRulePathSplit[len(customAppsecRulePathSplit)-1] - - appsecRuleDirDest := fmt.Sprintf("%s/appsec-rules/", t.RuntimePath) - if err := os.MkdirAll(appsecRuleDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", appsecRuleDirDest, err) + found, err := t.installAppsecRuleCustomFrom(appsecrule, customPath) + if err != nil { + return err } - // runtime/appsec-rules/ - customAppsecRuleDest := fmt.Sprintf("%s/appsec-rules/%s", t.RuntimePath, customAppsecRuleName) - // if path to postoverflow exist, copy it - if err := Copy(customAppsecRulePath, customAppsecRuleDest); err != nil { - continue + if found { + return nil } - customAppsecRuleExist = true - break - } - if !customAppsecRuleExist { - return fmt.Errorf("couldn't find custom appsec-rule '%s' in the following location: %+v", appsecrule, t.CustomItemsLocation) } - return nil + return fmt.Errorf("couldn't find custom appsec-rule '%s' in the following location: %+v", appsecrule, t.CustomItemsLocation) } func (t *HubTestItem) installAppsecRule(name string) error { diff --git a/pkg/hubtest/parser.go b/pkg/hubtest/parser.go index b8dcdb8b1d0..d40301e3015 100644 --- a/pkg/hubtest/parser.go +++ b/pkg/hubtest/parser.go @@ -9,89 +9,90 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func (t *HubTestItem) installParserItem(hubParser *cwhub.Item) error { - parserSource, err := filepath.Abs(filepath.Join(t.HubPath, hubParser.RemotePath)) +func (t *HubTestItem) installParserItem(item *cwhub.Item) error { + sourcePath, err := filepath.Abs(filepath.Join(t.HubPath, item.RemotePath)) if err != nil { - return fmt.Errorf("can't get absolute path of '%s': %s", parserSource, err) + return fmt.Errorf("can't get absolute path of '%s': %w", sourcePath, err) } - parserFileName := filepath.Base(parserSource) + sourceFilename := filepath.Base(sourcePath) // runtime/hub/parsers/s00-raw/crowdsecurity/ - hubDirParserDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(hubParser.RemotePath)) + hubDirParserDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(item.RemotePath)) // 
runtime/parsers/s00-raw/ - parserDirDest := fmt.Sprintf("%s/parsers/%s/", t.RuntimePath, hubParser.Stage) + itemTypeDirDest := fmt.Sprintf("%s/parsers/%s/", t.RuntimePath, item.Stage) if err := os.MkdirAll(hubDirParserDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", hubDirParserDest, err) + return fmt.Errorf("unable to create folder '%s': %w", hubDirParserDest, err) } - if err := os.MkdirAll(parserDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", parserDirDest, err) + if err := os.MkdirAll(itemTypeDirDest, os.ModePerm); err != nil { + return fmt.Errorf("unable to create folder '%s': %w", itemTypeDirDest, err) } // runtime/hub/parsers/s00-raw/crowdsecurity/syslog-logs.yaml - hubDirParserPath := filepath.Join(hubDirParserDest, parserFileName) - if err := Copy(parserSource, hubDirParserPath); err != nil { - return fmt.Errorf("unable to copy '%s' to '%s': %s", parserSource, hubDirParserPath, err) + hubDirParserPath := filepath.Join(hubDirParserDest, sourceFilename) + if err := Copy(sourcePath, hubDirParserPath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %w", sourcePath, hubDirParserPath, err) } // runtime/parsers/s00-raw/syslog-logs.yaml - parserDirParserPath := filepath.Join(parserDirDest, parserFileName) + parserDirParserPath := filepath.Join(itemTypeDirDest, sourceFilename) if err := os.Symlink(hubDirParserPath, parserDirParserPath); err != nil { if !os.IsExist(err) { - return fmt.Errorf("unable to symlink parser '%s' to '%s': %s", hubDirParserPath, parserDirParserPath, err) + return fmt.Errorf("unable to symlink parser '%s' to '%s': %w", hubDirParserPath, parserDirParserPath, err) } } return nil } -func (t *HubTestItem) installParserCustom(parser string) error { - customParserExist := false - for _, customPath := range t.CustomItemsLocation { - // we check if its a custom parser - customParserPath := filepath.Join(customPath, parser) - if _, err := os.Stat(customParserPath); os.IsNotExist(err) { - continue - //return fmt.Errorf("parser '%s' doesn't exist in the hub and doesn't appear to be a custom one.", parser) - } +func (t *HubTestItem) installParserCustomFrom(parser string, customPath string) (bool, error) { + // we check if its a custom parser + customParserPath := filepath.Join(customPath, parser) + if _, err := os.Stat(customParserPath); os.IsNotExist(err) { + return false, nil + } - customParserPathSplit, customParserName := filepath.Split(customParserPath) - // because path is parsers///parser.yaml and we wan't the stage - splittedPath := strings.Split(customParserPathSplit, string(os.PathSeparator)) - customParserStage := splittedPath[len(splittedPath)-3] + customParserPathSplit, customParserName := filepath.Split(customParserPath) + // because path is parsers///parser.yaml and we wan't the stage + splitPath := strings.Split(customParserPathSplit, string(os.PathSeparator)) + customParserStage := splitPath[len(splitPath)-3] - // check if stage exist - hubStagePath := filepath.Join(t.HubPath, fmt.Sprintf("parsers/%s", customParserStage)) + // check if stage exist + hubStagePath := filepath.Join(t.HubPath, fmt.Sprintf("parsers/%s", customParserStage)) + if _, err := os.Stat(hubStagePath); os.IsNotExist(err) { + return false, fmt.Errorf("stage '%s' extracted from '%s' doesn't exist in the hub", customParserStage, hubStagePath) + } - if _, err := os.Stat(hubStagePath); os.IsNotExist(err) { - continue - //return fmt.Errorf("stage '%s' extracted from '%s' doesn't exist in the hub", 
customParserStage, hubStagePath) - } + stageDirDest := fmt.Sprintf("%s/parsers/%s/", t.RuntimePath, customParserStage) + if err := os.MkdirAll(stageDirDest, os.ModePerm); err != nil { + return false, fmt.Errorf("unable to create folder '%s': %w", stageDirDest, err) + } - parserDirDest := fmt.Sprintf("%s/parsers/%s/", t.RuntimePath, customParserStage) - if err := os.MkdirAll(parserDirDest, os.ModePerm); err != nil { - continue - //return fmt.Errorf("unable to create folder '%s': %s", parserDirDest, err) - } + customParserDest := filepath.Join(stageDirDest, customParserName) + // if path to parser exist, copy it + if err := Copy(customParserPath, customParserDest); err != nil { + return false, fmt.Errorf("unable to copy custom parser '%s' to '%s': %w", customParserPath, customParserDest, err) + } + + return true, nil +} - customParserDest := filepath.Join(parserDirDest, customParserName) - // if path to parser exist, copy it - if err := Copy(customParserPath, customParserDest); err != nil { - continue - //return fmt.Errorf("unable to copy custom parser '%s' to '%s': %s", customParserPath, customParserDest, err) +func (t *HubTestItem) installParserCustom(parser string) error { + for _, customPath := range t.CustomItemsLocation { + found, err := t.installParserCustomFrom(parser, customPath) + if err != nil { + return err } - customParserExist = true - break - } - if !customParserExist { - return fmt.Errorf("couldn't find custom parser '%s' in the following location: %+v", parser, t.CustomItemsLocation) + if found { + return nil + } } - return nil + return fmt.Errorf("couldn't find custom parser '%s' in the following locations: %+v", parser, t.CustomItemsLocation) } func (t *HubTestItem) installParser(name string) error { diff --git a/pkg/hubtest/postoverflow.go b/pkg/hubtest/postoverflow.go index d5d43ddc742..76a67b58b76 100644 --- a/pkg/hubtest/postoverflow.go +++ b/pkg/hubtest/postoverflow.go @@ -9,88 +9,90 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func (t *HubTestItem) installPostoverflowItem(hubPostOverflow *cwhub.Item) error { - postoverflowSource, err := filepath.Abs(filepath.Join(t.HubPath, hubPostOverflow.RemotePath)) +func (t *HubTestItem) installPostoverflowItem(item *cwhub.Item) error { + sourcePath, err := filepath.Abs(filepath.Join(t.HubPath, item.RemotePath)) if err != nil { - return fmt.Errorf("can't get absolute path of '%s': %s", postoverflowSource, err) + return fmt.Errorf("can't get absolute path of '%s': %w", sourcePath, err) } - postoverflowFileName := filepath.Base(postoverflowSource) + sourceFilename := filepath.Base(sourcePath) // runtime/hub/postoverflows/s00-enrich/crowdsecurity/ - hubDirPostoverflowDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(hubPostOverflow.RemotePath)) + hubDirPostoverflowDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(item.RemotePath)) // runtime/postoverflows/s00-enrich - postoverflowDirDest := fmt.Sprintf("%s/postoverflows/%s/", t.RuntimePath, hubPostOverflow.Stage) + itemTypeDirDest := fmt.Sprintf("%s/postoverflows/%s/", t.RuntimePath, item.Stage) if err := os.MkdirAll(hubDirPostoverflowDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", hubDirPostoverflowDest, err) + return fmt.Errorf("unable to create folder '%s': %w", hubDirPostoverflowDest, err) } - if err := os.MkdirAll(postoverflowDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", postoverflowDirDest, err) + if err := os.MkdirAll(itemTypeDirDest, os.ModePerm); err != nil { + 
return fmt.Errorf("unable to create folder '%s': %w", itemTypeDirDest, err) } // runtime/hub/postoverflows/s00-enrich/crowdsecurity/rdns.yaml - hubDirPostoverflowPath := filepath.Join(hubDirPostoverflowDest, postoverflowFileName) - if err := Copy(postoverflowSource, hubDirPostoverflowPath); err != nil { - return fmt.Errorf("unable to copy '%s' to '%s': %s", postoverflowSource, hubDirPostoverflowPath, err) + hubDirPostoverflowPath := filepath.Join(hubDirPostoverflowDest, sourceFilename) + if err := Copy(sourcePath, hubDirPostoverflowPath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %w", sourcePath, hubDirPostoverflowPath, err) } // runtime/postoverflows/s00-enrich/rdns.yaml - postoverflowDirParserPath := filepath.Join(postoverflowDirDest, postoverflowFileName) + postoverflowDirParserPath := filepath.Join(itemTypeDirDest, sourceFilename) if err := os.Symlink(hubDirPostoverflowPath, postoverflowDirParserPath); err != nil { if !os.IsExist(err) { - return fmt.Errorf("unable to symlink postoverflow '%s' to '%s': %s", hubDirPostoverflowPath, postoverflowDirParserPath, err) + return fmt.Errorf("unable to symlink postoverflow '%s' to '%s': %w", hubDirPostoverflowPath, postoverflowDirParserPath, err) } } return nil } -func (t *HubTestItem) installPostoverflowCustom(postoverflow string) error { - customPostoverflowExist := false - for _, customPath := range t.CustomItemsLocation { - // we check if its a custom postoverflow - customPostOverflowPath := filepath.Join(customPath, postoverflow) - if _, err := os.Stat(customPostOverflowPath); os.IsNotExist(err) { - continue - //return fmt.Errorf("postoverflow '%s' doesn't exist in the hub and doesn't appear to be a custom one.", postoverflow) - } +func (t *HubTestItem) installPostoverflowCustomFrom(postoverflow string, customPath string) (bool, error) { + // we check if its a custom postoverflow + customPostOverflowPath := filepath.Join(customPath, postoverflow) + if _, err := os.Stat(customPostOverflowPath); os.IsNotExist(err) { + return false, nil + } - customPostOverflowPathSplit := strings.Split(customPostOverflowPath, "/") - customPostoverflowName := customPostOverflowPathSplit[len(customPostOverflowPathSplit)-1] - // because path is postoverflows///parser.yaml and we wan't the stage - customPostoverflowStage := customPostOverflowPathSplit[len(customPostOverflowPathSplit)-3] + customPostOverflowPathSplit := strings.Split(customPostOverflowPath, "/") + customPostoverflowName := customPostOverflowPathSplit[len(customPostOverflowPathSplit)-1] + // because path is postoverflows///parser.yaml and we wan't the stage + customPostoverflowStage := customPostOverflowPathSplit[len(customPostOverflowPathSplit)-3] - // check if stage exist - hubStagePath := filepath.Join(t.HubPath, fmt.Sprintf("postoverflows/%s", customPostoverflowStage)) + // check if stage exist + hubStagePath := filepath.Join(t.HubPath, fmt.Sprintf("postoverflows/%s", customPostoverflowStage)) + if _, err := os.Stat(hubStagePath); os.IsNotExist(err) { + return false, fmt.Errorf("stage '%s' from extracted '%s' doesn't exist in the hub", customPostoverflowStage, hubStagePath) + } - if _, err := os.Stat(hubStagePath); os.IsNotExist(err) { - continue - //return fmt.Errorf("stage '%s' from extracted '%s' doesn't exist in the hub", customPostoverflowStage, hubStagePath) - } + stageDirDest := fmt.Sprintf("%s/postoverflows/%s/", t.RuntimePath, customPostoverflowStage) + if err := os.MkdirAll(stageDirDest, os.ModePerm); err != nil { + return false, fmt.Errorf("unable to create folder 
'%s': %w", stageDirDest, err) + } - postoverflowDirDest := fmt.Sprintf("%s/postoverflows/%s/", t.RuntimePath, customPostoverflowStage) - if err := os.MkdirAll(postoverflowDirDest, os.ModePerm); err != nil { - continue - //return fmt.Errorf("unable to create folder '%s': %s", postoverflowDirDest, err) + customPostoverflowDest := filepath.Join(stageDirDest, customPostoverflowName) + // if path to postoverflow exist, copy it + if err := Copy(customPostOverflowPath, customPostoverflowDest); err != nil { + return false, fmt.Errorf("unable to copy custom parser '%s' to '%s': %w", customPostOverflowPath, customPostoverflowDest, err) + } + + return true, nil +} + +func (t *HubTestItem) installPostoverflowCustom(postoverflow string) error { + for _, customPath := range t.CustomItemsLocation { + found, err := t.installPostoverflowCustomFrom(postoverflow, customPath) + if err != nil { + return err } - customPostoverflowDest := filepath.Join(postoverflowDirDest, customPostoverflowName) - // if path to postoverflow exist, copy it - if err := Copy(customPostOverflowPath, customPostoverflowDest); err != nil { - continue - //return fmt.Errorf("unable to copy custom parser '%s' to '%s': %s", customPostOverflowPath, customPostoverflowDest, err) + if found { + return nil } - customPostoverflowExist = true - break - } - if !customPostoverflowExist { - return fmt.Errorf("couldn't find custom postoverflow '%s' in the following location: %+v", postoverflow, t.CustomItemsLocation) } - return nil + return fmt.Errorf("couldn't find custom postoverflow '%s' in the following location: %+v", postoverflow, t.CustomItemsLocation) } func (t *HubTestItem) installPostoverflow(name string) error { diff --git a/pkg/hubtest/scenario.go b/pkg/hubtest/scenario.go index eaa831d8013..35ea465b7c0 100644 --- a/pkg/hubtest/scenario.go +++ b/pkg/hubtest/scenario.go @@ -8,74 +8,80 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func (t *HubTestItem) installScenarioItem(hubScenario *cwhub.Item) error { - scenarioSource, err := filepath.Abs(filepath.Join(t.HubPath, hubScenario.RemotePath)) +func (t *HubTestItem) installScenarioItem(item *cwhub.Item) error { + sourcePath, err := filepath.Abs(filepath.Join(t.HubPath, item.RemotePath)) if err != nil { - return fmt.Errorf("can't get absolute path to: %s", scenarioSource) + return fmt.Errorf("can't get absolute path of '%s': %w", sourcePath, err) } - scenarioFileName := filepath.Base(scenarioSource) + sourceFilename := filepath.Base(sourcePath) // runtime/hub/scenarios/crowdsecurity/ - hubDirScenarioDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(hubScenario.RemotePath)) + hubDirScenarioDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(item.RemotePath)) // runtime/parsers/scenarios/ - scenarioDirDest := fmt.Sprintf("%s/scenarios/", t.RuntimePath) + itemTypeDirDest := fmt.Sprintf("%s/scenarios/", t.RuntimePath) if err := os.MkdirAll(hubDirScenarioDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", hubDirScenarioDest, err) + return fmt.Errorf("unable to create folder '%s': %w", hubDirScenarioDest, err) } - if err := os.MkdirAll(scenarioDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", scenarioDirDest, err) + if err := os.MkdirAll(itemTypeDirDest, os.ModePerm); err != nil { + return fmt.Errorf("unable to create folder '%s': %w", itemTypeDirDest, err) } // runtime/hub/scenarios/crowdsecurity/ssh-bf.yaml - hubDirScenarioPath := filepath.Join(hubDirScenarioDest, scenarioFileName) - if err := 
Copy(scenarioSource, hubDirScenarioPath); err != nil { - return fmt.Errorf("unable to copy '%s' to '%s': %s", scenarioSource, hubDirScenarioPath, err) + hubDirScenarioPath := filepath.Join(hubDirScenarioDest, sourceFilename) + if err := Copy(sourcePath, hubDirScenarioPath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %w", sourcePath, hubDirScenarioPath, err) } // runtime/scenarios/ssh-bf.yaml - scenarioDirParserPath := filepath.Join(scenarioDirDest, scenarioFileName) + scenarioDirParserPath := filepath.Join(itemTypeDirDest, sourceFilename) if err := os.Symlink(hubDirScenarioPath, scenarioDirParserPath); err != nil { if !os.IsExist(err) { - return fmt.Errorf("unable to symlink scenario '%s' to '%s': %s", hubDirScenarioPath, scenarioDirParserPath, err) + return fmt.Errorf("unable to symlink scenario '%s' to '%s': %w", hubDirScenarioPath, scenarioDirParserPath, err) } } return nil } +func (t *HubTestItem) installScenarioCustomFrom(scenario string, customPath string) (bool, error) { + // we check if its a custom scenario + customScenarioPath := filepath.Join(customPath, scenario) + if _, err := os.Stat(customScenarioPath); os.IsNotExist(err) { + return false, nil + } + + itemTypeDirDest := fmt.Sprintf("%s/scenarios/", t.RuntimePath) + if err := os.MkdirAll(itemTypeDirDest, os.ModePerm); err != nil { + return false, fmt.Errorf("unable to create folder '%s': %w", itemTypeDirDest, err) + } + + scenarioFileName := filepath.Base(customScenarioPath) + + scenarioFileDest := filepath.Join(itemTypeDirDest, scenarioFileName) + if err := Copy(customScenarioPath, scenarioFileDest); err != nil { + return false, fmt.Errorf("unable to copy scenario from '%s' to '%s': %w", customScenarioPath, scenarioFileDest, err) + } + + return true, nil +} + func (t *HubTestItem) installScenarioCustom(scenario string) error { - customScenarioExist := false for _, customPath := range t.CustomItemsLocation { - // we check if its a custom scenario - customScenarioPath := filepath.Join(customPath, scenario) - if _, err := os.Stat(customScenarioPath); os.IsNotExist(err) { - continue - //return fmt.Errorf("scenarios '%s' doesn't exist in the hub and doesn't appear to be a custom one.", scenario) - } - - scenarioDirDest := fmt.Sprintf("%s/scenarios/", t.RuntimePath) - if err := os.MkdirAll(scenarioDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", scenarioDirDest, err) + found, err := t.installScenarioCustomFrom(scenario, customPath) + if err != nil { + return err } - scenarioFileName := filepath.Base(customScenarioPath) - scenarioFileDest := filepath.Join(scenarioDirDest, scenarioFileName) - if err := Copy(customScenarioPath, scenarioFileDest); err != nil { - continue - //return fmt.Errorf("unable to copy scenario from '%s' to '%s': %s", customScenarioPath, scenarioFileDest, err) + if found { + return nil } - customScenarioExist = true - break - } - if !customScenarioExist { - return fmt.Errorf("couldn't find custom scenario '%s' in the following location: %+v", scenario, t.CustomItemsLocation) } - return nil + return fmt.Errorf("couldn't find custom scenario '%s' in the following location: %+v", scenario, t.CustomItemsLocation) } func (t *HubTestItem) installScenario(name string) error { From 97c441dab6c387d8aff77c2b07a256b0b0321f16 Mon Sep 17 00:00:00 2001 From: he2ss Date: Wed, 14 Feb 2024 12:26:42 +0100 Subject: [PATCH 12/20] implement highAvailability feature (#2506) * implement highAvailability feature --------- Co-authored-by: Marco Mariani --- 
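Note on the locking behaviour, condensed from the apic.go hunk below: AcquirePullCAPILock, IsLocked and ReleasePullCAPILock are provided by the new pkg/database/lock.go, and PullTop becomes a no-op on the nodes that do not hold the pullCAPI lock, so only one LAPI instance pulls the community blocklist at a time.

	// Condensed from (*apic).PullTop below; error handling trimmed.
	if err := a.dbClient.AcquirePullCAPILock(); a.dbClient.IsLocked(err) {
		log.Info("PullCAPI is already running, skipping")
		return nil
	}
	// ... pull and apply the community blocklist ...
	if err := a.dbClient.ReleasePullCAPILock(); err != nil {
		return fmt.Errorf("while releasing lock: %w", err)
	}
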
pkg/apiserver/apic.go | 12 + pkg/apiserver/apic_metrics_test.go | 8 +- pkg/database/ent/client.go | 152 ++++++- pkg/database/ent/ent.go | 2 + pkg/database/ent/hook/hook.go | 12 + pkg/database/ent/lock.go | 117 ++++++ pkg/database/ent/lock/lock.go | 62 +++ pkg/database/ent/lock/where.go | 185 +++++++++ pkg/database/ent/lock_create.go | 215 ++++++++++ pkg/database/ent/lock_delete.go | 88 ++++ pkg/database/ent/lock_query.go | 526 ++++++++++++++++++++++++ pkg/database/ent/lock_update.go | 228 ++++++++++ pkg/database/ent/migrate/schema.go | 13 + pkg/database/ent/mutation.go | 382 +++++++++++++++++ pkg/database/ent/predicate/predicate.go | 3 + pkg/database/ent/runtime.go | 7 + pkg/database/ent/schema/lock.go | 22 + pkg/database/ent/tx.go | 3 + pkg/database/lock.go | 67 +++ 19 files changed, 2096 insertions(+), 8 deletions(-) create mode 100644 pkg/database/ent/lock.go create mode 100644 pkg/database/ent/lock/lock.go create mode 100644 pkg/database/ent/lock/where.go create mode 100644 pkg/database/ent/lock_create.go create mode 100644 pkg/database/ent/lock_delete.go create mode 100644 pkg/database/ent/lock_query.go create mode 100644 pkg/database/ent/lock_update.go create mode 100644 pkg/database/ent/schema/lock.go create mode 100644 pkg/database/lock.go diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index d0b205c254d..2fdb01144a0 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -633,6 +633,13 @@ func (a *apic) PullTop(forcePull bool) error { } } + log.Debug("Acquiring lock for pullCAPI") + err = a.dbClient.AcquirePullCAPILock() + if a.dbClient.IsLocked(err) { + log.Info("PullCAPI is already running, skipping") + return nil + } + log.Infof("Starting community-blocklist update") data, _, err := a.apiClient.Decisions.GetStreamV3(context.Background(), apiclient.DecisionsStreamOpts{Startup: a.startup}) @@ -684,6 +691,11 @@ func (a *apic) PullTop(forcePull bool) error { return fmt.Errorf("while updating blocklists: %w", err) } + log.Debug("Releasing lock for pullCAPI") + if err := a.dbClient.ReleasePullCAPILock(); err != nil { + return fmt.Errorf("while releasing lock: %w", err) + } + return nil } diff --git a/pkg/apiserver/apic_metrics_test.go b/pkg/apiserver/apic_metrics_test.go index 2bc0dd26966..529dd6c6839 100644 --- a/pkg/apiserver/apic_metrics_test.go +++ b/pkg/apiserver/apic_metrics_test.go @@ -26,15 +26,15 @@ func TestAPICSendMetrics(t *testing.T) { }{ { name: "basic", - duration: time.Millisecond * 60, - metricsInterval: time.Millisecond * 10, + duration: time.Millisecond * 120, + metricsInterval: time.Millisecond * 20, expectedCalls: 5, setUp: func(api *apic) {}, }, { name: "with some metrics", - duration: time.Millisecond * 60, - metricsInterval: time.Millisecond * 10, + duration: time.Millisecond * 120, + metricsInterval: time.Millisecond * 20, expectedCalls: 5, setUp: func(api *apic) { api.dbClient.Ent.Machine.Delete().ExecX(context.Background()) diff --git a/pkg/database/ent/client.go b/pkg/database/ent/client.go index 2761ff088b5..006d52ef9ba 100644 --- a/pkg/database/ent/client.go +++ b/pkg/database/ent/client.go @@ -20,6 +20,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" ) @@ -39,6 +40,8 @@ type Client struct { Decision 
*DecisionClient // Event is the client for interacting with the Event builders. Event *EventClient + // Lock is the client for interacting with the Lock builders. + Lock *LockClient // Machine is the client for interacting with the Machine builders. Machine *MachineClient // Meta is the client for interacting with the Meta builders. @@ -61,6 +64,7 @@ func (c *Client) init() { c.ConfigItem = NewConfigItemClient(c.config) c.Decision = NewDecisionClient(c.config) c.Event = NewEventClient(c.config) + c.Lock = NewLockClient(c.config) c.Machine = NewMachineClient(c.config) c.Meta = NewMetaClient(c.config) } @@ -153,6 +157,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { ConfigItem: NewConfigItemClient(cfg), Decision: NewDecisionClient(cfg), Event: NewEventClient(cfg), + Lock: NewLockClient(cfg), Machine: NewMachineClient(cfg), Meta: NewMetaClient(cfg), }, nil @@ -179,6 +184,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) ConfigItem: NewConfigItemClient(cfg), Decision: NewDecisionClient(cfg), Event: NewEventClient(cfg), + Lock: NewLockClient(cfg), Machine: NewMachineClient(cfg), Meta: NewMetaClient(cfg), }, nil @@ -210,7 +216,8 @@ func (c *Client) Close() error { // In order to add hooks to a specific client, call: `client.Node.Use(...)`. func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ - c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Machine, c.Meta, + c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Lock, c.Machine, + c.Meta, } { n.Use(hooks...) } @@ -220,7 +227,8 @@ func (c *Client) Use(hooks ...Hook) { // In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`. func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ - c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Machine, c.Meta, + c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Lock, c.Machine, + c.Meta, } { n.Intercept(interceptors...) } @@ -239,6 +247,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.Decision.mutate(ctx, m) case *EventMutation: return c.Event.mutate(ctx, m) + case *LockMutation: + return c.Lock.mutate(ctx, m) case *MachineMutation: return c.Machine.mutate(ctx, m) case *MetaMutation: @@ -1009,6 +1019,139 @@ func (c *EventClient) mutate(ctx context.Context, m *EventMutation) (Value, erro } } +// LockClient is a client for the Lock schema. +type LockClient struct { + config +} + +// NewLockClient returns a client for the Lock from the given config. +func NewLockClient(c config) *LockClient { + return &LockClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `lock.Hooks(f(g(h())))`. +func (c *LockClient) Use(hooks ...Hook) { + c.hooks.Lock = append(c.hooks.Lock, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `lock.Intercept(f(g(h())))`. +func (c *LockClient) Intercept(interceptors ...Interceptor) { + c.inters.Lock = append(c.inters.Lock, interceptors...) +} + +// Create returns a builder for creating a Lock entity. +func (c *LockClient) Create() *LockCreate { + mutation := newLockMutation(c.config, OpCreate) + return &LockCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of Lock entities. 
+func (c *LockClient) CreateBulk(builders ...*LockCreate) *LockCreateBulk { + return &LockCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *LockClient) MapCreateBulk(slice any, setFunc func(*LockCreate, int)) *LockCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &LockCreateBulk{err: fmt.Errorf("calling to LockClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*LockCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &LockCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for Lock. +func (c *LockClient) Update() *LockUpdate { + mutation := newLockMutation(c.config, OpUpdate) + return &LockUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *LockClient) UpdateOne(l *Lock) *LockUpdateOne { + mutation := newLockMutation(c.config, OpUpdateOne, withLock(l)) + return &LockUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *LockClient) UpdateOneID(id int) *LockUpdateOne { + mutation := newLockMutation(c.config, OpUpdateOne, withLockID(id)) + return &LockUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for Lock. +func (c *LockClient) Delete() *LockDelete { + mutation := newLockMutation(c.config, OpDelete) + return &LockDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *LockClient) DeleteOne(l *Lock) *LockDeleteOne { + return c.DeleteOneID(l.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *LockClient) DeleteOneID(id int) *LockDeleteOne { + builder := c.Delete().Where(lock.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &LockDeleteOne{builder} +} + +// Query returns a query builder for Lock. +func (c *LockClient) Query() *LockQuery { + return &LockQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeLock}, + inters: c.Interceptors(), + } +} + +// Get returns a Lock entity by its id. +func (c *LockClient) Get(ctx context.Context, id int) (*Lock, error) { + return c.Query().Where(lock.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *LockClient) GetX(ctx context.Context, id int) *Lock { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *LockClient) Hooks() []Hook { + return c.hooks.Lock +} + +// Interceptors returns the client interceptors. 
+func (c *LockClient) Interceptors() []Interceptor { + return c.inters.Lock +} + +func (c *LockClient) mutate(ctx context.Context, m *LockMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&LockCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&LockUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&LockUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&LockDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Lock mutation op: %q", m.Op()) + } +} + // MachineClient is a client for the Machine schema. type MachineClient struct { config @@ -1310,9 +1453,10 @@ func (c *MetaClient) mutate(ctx context.Context, m *MetaMutation) (Value, error) // hooks and interceptors per client, for fast access. type ( hooks struct { - Alert, Bouncer, ConfigItem, Decision, Event, Machine, Meta []ent.Hook + Alert, Bouncer, ConfigItem, Decision, Event, Lock, Machine, Meta []ent.Hook } inters struct { - Alert, Bouncer, ConfigItem, Decision, Event, Machine, Meta []ent.Interceptor + Alert, Bouncer, ConfigItem, Decision, Event, Lock, Machine, + Meta []ent.Interceptor } ) diff --git a/pkg/database/ent/ent.go b/pkg/database/ent/ent.go index 393ce9f1869..cb98ee9301c 100644 --- a/pkg/database/ent/ent.go +++ b/pkg/database/ent/ent.go @@ -17,6 +17,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" ) @@ -84,6 +85,7 @@ func checkColumn(table, column string) error { configitem.Table: configitem.ValidColumn, decision.Table: decision.ValidColumn, event.Table: event.ValidColumn, + lock.Table: lock.ValidColumn, machine.Table: machine.ValidColumn, meta.Table: meta.ValidColumn, }) diff --git a/pkg/database/ent/hook/hook.go b/pkg/database/ent/hook/hook.go index 7ec9c3ab1d8..fdc31539679 100644 --- a/pkg/database/ent/hook/hook.go +++ b/pkg/database/ent/hook/hook.go @@ -69,6 +69,18 @@ func (f EventFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.EventMutation", m) } +// The LockFunc type is an adapter to allow the use of ordinary +// function as Lock mutator. +type LockFunc func(context.Context, *ent.LockMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f LockFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.LockMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.LockMutation", m) +} + // The MachineFunc type is an adapter to allow the use of ordinary // function as Machine mutator. type MachineFunc func(context.Context, *ent.MachineMutation) (ent.Value, error) diff --git a/pkg/database/ent/lock.go b/pkg/database/ent/lock.go new file mode 100644 index 00000000000..85556a30644 --- /dev/null +++ b/pkg/database/ent/lock.go @@ -0,0 +1,117 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" +) + +// Lock is the model entity for the Lock schema. +type Lock struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*Lock) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case lock.FieldID: + values[i] = new(sql.NullInt64) + case lock.FieldName: + values[i] = new(sql.NullString) + case lock.FieldCreatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Lock fields. +func (l *Lock) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case lock.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + l.ID = int(value.Int64) + case lock.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + l.Name = value.String + } + case lock.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + l.CreatedAt = value.Time + } + default: + l.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the Lock. +// This includes values selected through modifiers, order, etc. +func (l *Lock) Value(name string) (ent.Value, error) { + return l.selectValues.Get(name) +} + +// Update returns a builder for updating this Lock. +// Note that you need to call Lock.Unwrap() before calling this method if this Lock +// was returned from a transaction, and the transaction was committed or rolled back. +func (l *Lock) Update() *LockUpdateOne { + return NewLockClient(l.config).UpdateOne(l) +} + +// Unwrap unwraps the Lock entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (l *Lock) Unwrap() *Lock { + _tx, ok := l.config.driver.(*txDriver) + if !ok { + panic("ent: Lock is not a transactional entity") + } + l.config.driver = _tx.drv + return l +} + +// String implements the fmt.Stringer. +func (l *Lock) String() string { + var builder strings.Builder + builder.WriteString("Lock(") + builder.WriteString(fmt.Sprintf("id=%v, ", l.ID)) + builder.WriteString("name=") + builder.WriteString(l.Name) + builder.WriteString(", ") + builder.WriteString("created_at=") + builder.WriteString(l.CreatedAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// Locks is a parsable slice of Lock. 
+type Locks []*Lock diff --git a/pkg/database/ent/lock/lock.go b/pkg/database/ent/lock/lock.go new file mode 100644 index 00000000000..d0143470a75 --- /dev/null +++ b/pkg/database/ent/lock/lock.go @@ -0,0 +1,62 @@ +// Code generated by ent, DO NOT EDIT. + +package lock + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the lock type in the database. + Label = "lock" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // Table holds the table name of the lock in the database. + Table = "locks" +) + +// Columns holds all SQL columns for lock fields. +var Columns = []string{ + FieldID, + FieldName, + FieldCreatedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time +) + +// OrderOption defines the ordering options for the Lock queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} diff --git a/pkg/database/ent/lock/where.go b/pkg/database/ent/lock/where.go new file mode 100644 index 00000000000..cf59362d203 --- /dev/null +++ b/pkg/database/ent/lock/where.go @@ -0,0 +1,185 @@ +// Code generated by ent, DO NOT EDIT. + +package lock + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.Lock { + return predicate.Lock(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.Lock { + return predicate.Lock(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Lock { + return predicate.Lock(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Lock { + return predicate.Lock(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Lock { + return predicate.Lock(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Lock { + return predicate.Lock(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. 
+func IDLTE(id int) predicate.Lock { + return predicate.Lock(sql.FieldLTE(FieldID, id)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldName, v)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldCreatedAt, v)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.Lock { + return predicate.Lock(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.Lock { + return predicate.Lock(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.Lock { + return predicate.Lock(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.Lock { + return predicate.Lock(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.Lock { + return predicate.Lock(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.Lock { + return predicate.Lock(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.Lock { + return predicate.Lock(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.Lock { + return predicate.Lock(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.Lock { + return predicate.Lock(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.Lock { + return predicate.Lock(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.Lock { + return predicate.Lock(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.Lock { + return predicate.Lock(sql.FieldContainsFold(FieldName, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.Lock { + return predicate.Lock(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.Lock { + return predicate.Lock(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. 
+func CreatedAtGT(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldLTE(FieldCreatedAt, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.Lock) predicate.Lock { + return predicate.Lock(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.Lock) predicate.Lock { + return predicate.Lock(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.Lock) predicate.Lock { + return predicate.Lock(sql.NotPredicates(p)) +} diff --git a/pkg/database/ent/lock_create.go b/pkg/database/ent/lock_create.go new file mode 100644 index 00000000000..e2c29c88324 --- /dev/null +++ b/pkg/database/ent/lock_create.go @@ -0,0 +1,215 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" +) + +// LockCreate is the builder for creating a Lock entity. +type LockCreate struct { + config + mutation *LockMutation + hooks []Hook +} + +// SetName sets the "name" field. +func (lc *LockCreate) SetName(s string) *LockCreate { + lc.mutation.SetName(s) + return lc +} + +// SetCreatedAt sets the "created_at" field. +func (lc *LockCreate) SetCreatedAt(t time.Time) *LockCreate { + lc.mutation.SetCreatedAt(t) + return lc +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (lc *LockCreate) SetNillableCreatedAt(t *time.Time) *LockCreate { + if t != nil { + lc.SetCreatedAt(*t) + } + return lc +} + +// Mutation returns the LockMutation object of the builder. +func (lc *LockCreate) Mutation() *LockMutation { + return lc.mutation +} + +// Save creates the Lock in the database. +func (lc *LockCreate) Save(ctx context.Context) (*Lock, error) { + lc.defaults() + return withHooks(ctx, lc.sqlSave, lc.mutation, lc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (lc *LockCreate) SaveX(ctx context.Context) *Lock { + v, err := lc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (lc *LockCreate) Exec(ctx context.Context) error { + _, err := lc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (lc *LockCreate) ExecX(ctx context.Context) { + if err := lc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (lc *LockCreate) defaults() { + if _, ok := lc.mutation.CreatedAt(); !ok { + v := lock.DefaultCreatedAt() + lc.mutation.SetCreatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (lc *LockCreate) check() error { + if _, ok := lc.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "Lock.name"`)} + } + if _, ok := lc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Lock.created_at"`)} + } + return nil +} + +func (lc *LockCreate) sqlSave(ctx context.Context) (*Lock, error) { + if err := lc.check(); err != nil { + return nil, err + } + _node, _spec := lc.createSpec() + if err := sqlgraph.CreateNode(ctx, lc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + lc.mutation.id = &_node.ID + lc.mutation.done = true + return _node, nil +} + +func (lc *LockCreate) createSpec() (*Lock, *sqlgraph.CreateSpec) { + var ( + _node = &Lock{config: lc.config} + _spec = sqlgraph.NewCreateSpec(lock.Table, sqlgraph.NewFieldSpec(lock.FieldID, field.TypeInt)) + ) + if value, ok := lc.mutation.Name(); ok { + _spec.SetField(lock.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := lc.mutation.CreatedAt(); ok { + _spec.SetField(lock.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + return _node, _spec +} + +// LockCreateBulk is the builder for creating many Lock entities in bulk. +type LockCreateBulk struct { + config + err error + builders []*LockCreate +} + +// Save creates the Lock entities in the database. +func (lcb *LockCreateBulk) Save(ctx context.Context) ([]*Lock, error) { + if lcb.err != nil { + return nil, lcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(lcb.builders)) + nodes := make([]*Lock, len(lcb.builders)) + mutators := make([]Mutator, len(lcb.builders)) + for i := range lcb.builders { + func(i int, root context.Context) { + builder := lcb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*LockMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, lcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, lcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, lcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (lcb *LockCreateBulk) SaveX(ctx context.Context) []*Lock { + v, err := lcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. 
+func (lcb *LockCreateBulk) Exec(ctx context.Context) error { + _, err := lcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (lcb *LockCreateBulk) ExecX(ctx context.Context) { + if err := lcb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/pkg/database/ent/lock_delete.go b/pkg/database/ent/lock_delete.go new file mode 100644 index 00000000000..2275c608f75 --- /dev/null +++ b/pkg/database/ent/lock_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// LockDelete is the builder for deleting a Lock entity. +type LockDelete struct { + config + hooks []Hook + mutation *LockMutation +} + +// Where appends a list predicates to the LockDelete builder. +func (ld *LockDelete) Where(ps ...predicate.Lock) *LockDelete { + ld.mutation.Where(ps...) + return ld +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (ld *LockDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, ld.sqlExec, ld.mutation, ld.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (ld *LockDelete) ExecX(ctx context.Context) int { + n, err := ld.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (ld *LockDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(lock.Table, sqlgraph.NewFieldSpec(lock.FieldID, field.TypeInt)) + if ps := ld.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, ld.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + ld.mutation.done = true + return affected, err +} + +// LockDeleteOne is the builder for deleting a single Lock entity. +type LockDeleteOne struct { + ld *LockDelete +} + +// Where appends a list predicates to the LockDelete builder. +func (ldo *LockDeleteOne) Where(ps ...predicate.Lock) *LockDeleteOne { + ldo.ld.mutation.Where(ps...) + return ldo +} + +// Exec executes the deletion query. +func (ldo *LockDeleteOne) Exec(ctx context.Context) error { + n, err := ldo.ld.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{lock.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (ldo *LockDeleteOne) ExecX(ctx context.Context) { + if err := ldo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/pkg/database/ent/lock_query.go b/pkg/database/ent/lock_query.go new file mode 100644 index 00000000000..75e5da48a94 --- /dev/null +++ b/pkg/database/ent/lock_query.go @@ -0,0 +1,526 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// LockQuery is the builder for querying Lock entities. +type LockQuery struct { + config + ctx *QueryContext + order []lock.OrderOption + inters []Interceptor + predicates []predicate.Lock + // intermediate query (i.e. traversal path). 
+ sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the LockQuery builder. +func (lq *LockQuery) Where(ps ...predicate.Lock) *LockQuery { + lq.predicates = append(lq.predicates, ps...) + return lq +} + +// Limit the number of records to be returned by this query. +func (lq *LockQuery) Limit(limit int) *LockQuery { + lq.ctx.Limit = &limit + return lq +} + +// Offset to start from. +func (lq *LockQuery) Offset(offset int) *LockQuery { + lq.ctx.Offset = &offset + return lq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (lq *LockQuery) Unique(unique bool) *LockQuery { + lq.ctx.Unique = &unique + return lq +} + +// Order specifies how the records should be ordered. +func (lq *LockQuery) Order(o ...lock.OrderOption) *LockQuery { + lq.order = append(lq.order, o...) + return lq +} + +// First returns the first Lock entity from the query. +// Returns a *NotFoundError when no Lock was found. +func (lq *LockQuery) First(ctx context.Context) (*Lock, error) { + nodes, err := lq.Limit(1).All(setContextOp(ctx, lq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{lock.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (lq *LockQuery) FirstX(ctx context.Context) *Lock { + node, err := lq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first Lock ID from the query. +// Returns a *NotFoundError when no Lock ID was found. +func (lq *LockQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = lq.Limit(1).IDs(setContextOp(ctx, lq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{lock.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (lq *LockQuery) FirstIDX(ctx context.Context) int { + id, err := lq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single Lock entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one Lock entity is found. +// Returns a *NotFoundError when no Lock entities are found. +func (lq *LockQuery) Only(ctx context.Context) (*Lock, error) { + nodes, err := lq.Limit(2).All(setContextOp(ctx, lq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{lock.Label} + default: + return nil, &NotSingularError{lock.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (lq *LockQuery) OnlyX(ctx context.Context) *Lock { + node, err := lq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only Lock ID in the query. +// Returns a *NotSingularError when more than one Lock ID is found. +// Returns a *NotFoundError when no entities are found. +func (lq *LockQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = lq.Limit(2).IDs(setContextOp(ctx, lq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{lock.Label} + default: + err = &NotSingularError{lock.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. 
+func (lq *LockQuery) OnlyIDX(ctx context.Context) int { + id, err := lq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Locks. +func (lq *LockQuery) All(ctx context.Context) ([]*Lock, error) { + ctx = setContextOp(ctx, lq.ctx, "All") + if err := lq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*Lock, *LockQuery]() + return withInterceptors[[]*Lock](ctx, lq, qr, lq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (lq *LockQuery) AllX(ctx context.Context) []*Lock { + nodes, err := lq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of Lock IDs. +func (lq *LockQuery) IDs(ctx context.Context) (ids []int, err error) { + if lq.ctx.Unique == nil && lq.path != nil { + lq.Unique(true) + } + ctx = setContextOp(ctx, lq.ctx, "IDs") + if err = lq.Select(lock.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (lq *LockQuery) IDsX(ctx context.Context) []int { + ids, err := lq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (lq *LockQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, lq.ctx, "Count") + if err := lq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, lq, querierCount[*LockQuery](), lq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (lq *LockQuery) CountX(ctx context.Context) int { + count, err := lq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (lq *LockQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, lq.ctx, "Exist") + switch _, err := lq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (lq *LockQuery) ExistX(ctx context.Context) bool { + exist, err := lq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the LockQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (lq *LockQuery) Clone() *LockQuery { + if lq == nil { + return nil + } + return &LockQuery{ + config: lq.config, + ctx: lq.ctx.Clone(), + order: append([]lock.OrderOption{}, lq.order...), + inters: append([]Interceptor{}, lq.inters...), + predicates: append([]predicate.Lock{}, lq.predicates...), + // clone intermediate query. + sql: lq.sql.Clone(), + path: lq.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// Name string `json:"name"` +// Count int `json:"count,omitempty"` +// } +// +// client.Lock.Query(). +// GroupBy(lock.FieldName). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (lq *LockQuery) GroupBy(field string, fields ...string) *LockGroupBy { + lq.ctx.Fields = append([]string{field}, fields...) 
+ grbuild := &LockGroupBy{build: lq} + grbuild.flds = &lq.ctx.Fields + grbuild.label = lock.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// Name string `json:"name"` +// } +// +// client.Lock.Query(). +// Select(lock.FieldName). +// Scan(ctx, &v) +func (lq *LockQuery) Select(fields ...string) *LockSelect { + lq.ctx.Fields = append(lq.ctx.Fields, fields...) + sbuild := &LockSelect{LockQuery: lq} + sbuild.label = lock.Label + sbuild.flds, sbuild.scan = &lq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a LockSelect configured with the given aggregations. +func (lq *LockQuery) Aggregate(fns ...AggregateFunc) *LockSelect { + return lq.Select().Aggregate(fns...) +} + +func (lq *LockQuery) prepareQuery(ctx context.Context) error { + for _, inter := range lq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, lq); err != nil { + return err + } + } + } + for _, f := range lq.ctx.Fields { + if !lock.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if lq.path != nil { + prev, err := lq.path(ctx) + if err != nil { + return err + } + lq.sql = prev + } + return nil +} + +func (lq *LockQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Lock, error) { + var ( + nodes = []*Lock{} + _spec = lq.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*Lock).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &Lock{config: lq.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, lq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (lq *LockQuery) sqlCount(ctx context.Context) (int, error) { + _spec := lq.querySpec() + _spec.Node.Columns = lq.ctx.Fields + if len(lq.ctx.Fields) > 0 { + _spec.Unique = lq.ctx.Unique != nil && *lq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, lq.driver, _spec) +} + +func (lq *LockQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(lock.Table, lock.Columns, sqlgraph.NewFieldSpec(lock.FieldID, field.TypeInt)) + _spec.From = lq.sql + if unique := lq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if lq.path != nil { + _spec.Unique = true + } + if fields := lq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, lock.FieldID) + for i := range fields { + if fields[i] != lock.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := lq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := lq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := lq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := lq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (lq *LockQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := 
sql.Dialect(lq.driver.Dialect()) + t1 := builder.Table(lock.Table) + columns := lq.ctx.Fields + if len(columns) == 0 { + columns = lock.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if lq.sql != nil { + selector = lq.sql + selector.Select(selector.Columns(columns...)...) + } + if lq.ctx.Unique != nil && *lq.ctx.Unique { + selector.Distinct() + } + for _, p := range lq.predicates { + p(selector) + } + for _, p := range lq.order { + p(selector) + } + if offset := lq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := lq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// LockGroupBy is the group-by builder for Lock entities. +type LockGroupBy struct { + selector + build *LockQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (lgb *LockGroupBy) Aggregate(fns ...AggregateFunc) *LockGroupBy { + lgb.fns = append(lgb.fns, fns...) + return lgb +} + +// Scan applies the selector query and scans the result into the given value. +func (lgb *LockGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, lgb.build.ctx, "GroupBy") + if err := lgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*LockQuery, *LockGroupBy](ctx, lgb.build, lgb, lgb.build.inters, v) +} + +func (lgb *LockGroupBy) sqlScan(ctx context.Context, root *LockQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(lgb.fns)) + for _, fn := range lgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*lgb.flds)+len(lgb.fns)) + for _, f := range *lgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*lgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := lgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// LockSelect is the builder for selecting fields of Lock entities. +type LockSelect struct { + *LockQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ls *LockSelect) Aggregate(fns ...AggregateFunc) *LockSelect { + ls.fns = append(ls.fns, fns...) + return ls +} + +// Scan applies the selector query and scans the result into the given value. +func (ls *LockSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ls.ctx, "Select") + if err := ls.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*LockQuery, *LockSelect](ctx, ls.LockQuery, ls, ls.inters, v) +} + +func (ls *LockSelect) sqlScan(ctx context.Context, root *LockQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ls.fns)) + for _, fn := range ls.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ls.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) 
+ } + rows := &sql.Rows{} + query, args := selector.Query() + if err := ls.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/pkg/database/ent/lock_update.go b/pkg/database/ent/lock_update.go new file mode 100644 index 00000000000..f4deda6e3a8 --- /dev/null +++ b/pkg/database/ent/lock_update.go @@ -0,0 +1,228 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// LockUpdate is the builder for updating Lock entities. +type LockUpdate struct { + config + hooks []Hook + mutation *LockMutation +} + +// Where appends a list predicates to the LockUpdate builder. +func (lu *LockUpdate) Where(ps ...predicate.Lock) *LockUpdate { + lu.mutation.Where(ps...) + return lu +} + +// SetName sets the "name" field. +func (lu *LockUpdate) SetName(s string) *LockUpdate { + lu.mutation.SetName(s) + return lu +} + +// SetCreatedAt sets the "created_at" field. +func (lu *LockUpdate) SetCreatedAt(t time.Time) *LockUpdate { + lu.mutation.SetCreatedAt(t) + return lu +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (lu *LockUpdate) SetNillableCreatedAt(t *time.Time) *LockUpdate { + if t != nil { + lu.SetCreatedAt(*t) + } + return lu +} + +// Mutation returns the LockMutation object of the builder. +func (lu *LockUpdate) Mutation() *LockMutation { + return lu.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (lu *LockUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, lu.sqlSave, lu.mutation, lu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (lu *LockUpdate) SaveX(ctx context.Context) int { + affected, err := lu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (lu *LockUpdate) Exec(ctx context.Context) error { + _, err := lu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (lu *LockUpdate) ExecX(ctx context.Context) { + if err := lu.Exec(ctx); err != nil { + panic(err) + } +} + +func (lu *LockUpdate) sqlSave(ctx context.Context) (n int, err error) { + _spec := sqlgraph.NewUpdateSpec(lock.Table, lock.Columns, sqlgraph.NewFieldSpec(lock.FieldID, field.TypeInt)) + if ps := lu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := lu.mutation.Name(); ok { + _spec.SetField(lock.FieldName, field.TypeString, value) + } + if value, ok := lu.mutation.CreatedAt(); ok { + _spec.SetField(lock.FieldCreatedAt, field.TypeTime, value) + } + if n, err = sqlgraph.UpdateNodes(ctx, lu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{lock.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + lu.mutation.done = true + return n, nil +} + +// LockUpdateOne is the builder for updating a single Lock entity. +type LockUpdateOne struct { + config + fields []string + hooks []Hook + mutation *LockMutation +} + +// SetName sets the "name" field. 
+func (luo *LockUpdateOne) SetName(s string) *LockUpdateOne { + luo.mutation.SetName(s) + return luo +} + +// SetCreatedAt sets the "created_at" field. +func (luo *LockUpdateOne) SetCreatedAt(t time.Time) *LockUpdateOne { + luo.mutation.SetCreatedAt(t) + return luo +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (luo *LockUpdateOne) SetNillableCreatedAt(t *time.Time) *LockUpdateOne { + if t != nil { + luo.SetCreatedAt(*t) + } + return luo +} + +// Mutation returns the LockMutation object of the builder. +func (luo *LockUpdateOne) Mutation() *LockMutation { + return luo.mutation +} + +// Where appends a list predicates to the LockUpdate builder. +func (luo *LockUpdateOne) Where(ps ...predicate.Lock) *LockUpdateOne { + luo.mutation.Where(ps...) + return luo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (luo *LockUpdateOne) Select(field string, fields ...string) *LockUpdateOne { + luo.fields = append([]string{field}, fields...) + return luo +} + +// Save executes the query and returns the updated Lock entity. +func (luo *LockUpdateOne) Save(ctx context.Context) (*Lock, error) { + return withHooks(ctx, luo.sqlSave, luo.mutation, luo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (luo *LockUpdateOne) SaveX(ctx context.Context) *Lock { + node, err := luo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (luo *LockUpdateOne) Exec(ctx context.Context) error { + _, err := luo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (luo *LockUpdateOne) ExecX(ctx context.Context) { + if err := luo.Exec(ctx); err != nil { + panic(err) + } +} + +func (luo *LockUpdateOne) sqlSave(ctx context.Context) (_node *Lock, err error) { + _spec := sqlgraph.NewUpdateSpec(lock.Table, lock.Columns, sqlgraph.NewFieldSpec(lock.FieldID, field.TypeInt)) + id, ok := luo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Lock.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := luo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, lock.FieldID) + for _, f := range fields { + if !lock.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != lock.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := luo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := luo.mutation.Name(); ok { + _spec.SetField(lock.FieldName, field.TypeString, value) + } + if value, ok := luo.mutation.CreatedAt(); ok { + _spec.SetField(lock.FieldCreatedAt, field.TypeTime, value) + } + _node = &Lock{config: luo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, luo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{lock.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + luo.mutation.done = true + return _node, nil +} diff --git a/pkg/database/ent/migrate/schema.go b/pkg/database/ent/migrate/schema.go index 375fd4e784a..c3ffed42239 100644 
--- a/pkg/database/ent/migrate/schema.go +++ b/pkg/database/ent/migrate/schema.go @@ -178,6 +178,18 @@ var ( }, }, } + // LocksColumns holds the columns for the "locks" table. + LocksColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "name", Type: field.TypeString, Unique: true}, + {Name: "created_at", Type: field.TypeTime}, + } + // LocksTable holds the schema information for the "locks" table. + LocksTable = &schema.Table{ + Name: "locks", + Columns: LocksColumns, + PrimaryKey: []*schema.Column{LocksColumns[0]}, + } // MachinesColumns holds the columns for the "machines" table. MachinesColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, @@ -237,6 +249,7 @@ var ( ConfigItemsTable, DecisionsTable, EventsTable, + LocksTable, MachinesTable, MetaTable, } diff --git a/pkg/database/ent/mutation.go b/pkg/database/ent/mutation.go index c5808d0d9b8..365824de739 100644 --- a/pkg/database/ent/mutation.go +++ b/pkg/database/ent/mutation.go @@ -16,6 +16,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" @@ -35,6 +36,7 @@ const ( TypeConfigItem = "ConfigItem" TypeDecision = "Decision" TypeEvent = "Event" + TypeLock = "Lock" TypeMachine = "Machine" TypeMeta = "Meta" ) @@ -6165,6 +6167,386 @@ func (m *EventMutation) ResetEdge(name string) error { return fmt.Errorf("unknown Event edge %s", name) } +// LockMutation represents an operation that mutates the Lock nodes in the graph. +type LockMutation struct { + config + op Op + typ string + id *int + name *string + created_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*Lock, error) + predicates []predicate.Lock +} + +var _ ent.Mutation = (*LockMutation)(nil) + +// lockOption allows management of the mutation configuration using functional options. +type lockOption func(*LockMutation) + +// newLockMutation creates new mutation for the Lock entity. +func newLockMutation(c config, op Op, opts ...lockOption) *LockMutation { + m := &LockMutation{ + config: c, + op: op, + typ: TypeLock, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withLockID sets the ID field of the mutation. +func withLockID(id int) lockOption { + return func(m *LockMutation) { + var ( + err error + once sync.Once + value *Lock + ) + m.oldValue = func(ctx context.Context) (*Lock, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Lock.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withLock sets the old Lock of the mutation. +func withLock(node *Lock) lockOption { + return func(m *LockMutation) { + m.oldValue = func(context.Context) (*Lock, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. 
+func (m LockMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m LockMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *LockMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *LockMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Lock.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetName sets the "name" field. +func (m *LockMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *LockMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the Lock entity. +// If the Lock object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *LockMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *LockMutation) ResetName() { + m.name = nil +} + +// SetCreatedAt sets the "created_at" field. +func (m *LockMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *LockMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the Lock entity. +// If the Lock object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *LockMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *LockMutation) ResetCreatedAt() { + m.created_at = nil +} + +// Where appends a list predicates to the LockMutation builder. +func (m *LockMutation) Where(ps ...predicate.Lock) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the LockMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *LockMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Lock, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *LockMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *LockMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Lock). +func (m *LockMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *LockMutation) Fields() []string { + fields := make([]string, 0, 2) + if m.name != nil { + fields = append(fields, lock.FieldName) + } + if m.created_at != nil { + fields = append(fields, lock.FieldCreatedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *LockMutation) Field(name string) (ent.Value, bool) { + switch name { + case lock.FieldName: + return m.Name() + case lock.FieldCreatedAt: + return m.CreatedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *LockMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case lock.FieldName: + return m.OldName(ctx) + case lock.FieldCreatedAt: + return m.OldCreatedAt(ctx) + } + return nil, fmt.Errorf("unknown Lock field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *LockMutation) SetField(name string, value ent.Value) error { + switch name { + case lock.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case lock.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + } + return fmt.Errorf("unknown Lock field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. 
+func (m *LockMutation) AddedFields() []string { + return nil +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *LockMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *LockMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown Lock numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *LockMutation) ClearedFields() []string { + return nil +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *LockMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *LockMutation) ClearField(name string) error { + return fmt.Errorf("unknown Lock nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *LockMutation) ResetField(name string) error { + switch name { + case lock.FieldName: + m.ResetName() + return nil + case lock.FieldCreatedAt: + m.ResetCreatedAt() + return nil + } + return fmt.Errorf("unknown Lock field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *LockMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *LockMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *LockMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *LockMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *LockMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *LockMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *LockMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown Lock unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *LockMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown Lock edge %s", name) +} + // MachineMutation represents an operation that mutates the Machine nodes in the graph. 
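// The LockFunc adapter added in hook/hook.go above pairs with the generated
// LockClient.Use so runtime hooks can be attached to the new Lock entity. A
// minimal illustrative sketch, not part of this patch (the client variable and
// the log message are assumptions), that traces every Lock mutation:
//
//	client.Lock.Use(func(next ent.Mutator) ent.Mutator {
//		return hook.LockFunc(func(ctx context.Context, m *ent.LockMutation) (ent.Value, error) {
//			if name, ok := m.Name(); ok {
//				log.Debugf("lock mutation on %q", name)
//			}
//			return next.Mutate(ctx, m) // run the wrapped mutation
//		})
//	})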
type MachineMutation struct { config diff --git a/pkg/database/ent/predicate/predicate.go b/pkg/database/ent/predicate/predicate.go index e95abcec343..ad2e6d3f327 100644 --- a/pkg/database/ent/predicate/predicate.go +++ b/pkg/database/ent/predicate/predicate.go @@ -21,6 +21,9 @@ type Decision func(*sql.Selector) // Event is the predicate function for event builders. type Event func(*sql.Selector) +// Lock is the predicate function for lock builders. +type Lock func(*sql.Selector) + // Machine is the predicate function for machine builders. type Machine func(*sql.Selector) diff --git a/pkg/database/ent/runtime.go b/pkg/database/ent/runtime.go index bceea37b3a7..87073074563 100644 --- a/pkg/database/ent/runtime.go +++ b/pkg/database/ent/runtime.go @@ -10,6 +10,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" "github.com/crowdsecurity/crowdsec/pkg/database/ent/schema" @@ -137,6 +138,12 @@ func init() { eventDescSerialized := eventFields[3].Descriptor() // event.SerializedValidator is a validator for the "serialized" field. It is called by the builders before save. event.SerializedValidator = eventDescSerialized.Validators[0].(func(string) error) + lockFields := schema.Lock{}.Fields() + _ = lockFields + // lockDescCreatedAt is the schema descriptor for created_at field. + lockDescCreatedAt := lockFields[1].Descriptor() + // lock.DefaultCreatedAt holds the default value on creation for the created_at field. + lock.DefaultCreatedAt = lockDescCreatedAt.Default.(func() time.Time) machineFields := schema.Machine{}.Fields() _ = machineFields // machineDescCreatedAt is the schema descriptor for created_at field. diff --git a/pkg/database/ent/schema/lock.go b/pkg/database/ent/schema/lock.go new file mode 100644 index 00000000000..de87efff3f7 --- /dev/null +++ b/pkg/database/ent/schema/lock.go @@ -0,0 +1,22 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +type Lock struct { + ent.Schema +} + +func (Lock) Fields() []ent.Field { + return []ent.Field{ + field.String("name").Unique().StructTag(`json:"name"`), + field.Time("created_at").Default(types.UtcNow).StructTag(`json:"created_at"`), + } +} + +func (Lock) Edges() []ent.Edge { + return nil +} diff --git a/pkg/database/ent/tx.go b/pkg/database/ent/tx.go index 65c2ed00a44..27b39c12502 100644 --- a/pkg/database/ent/tx.go +++ b/pkg/database/ent/tx.go @@ -22,6 +22,8 @@ type Tx struct { Decision *DecisionClient // Event is the client for interacting with the Event builders. Event *EventClient + // Lock is the client for interacting with the Lock builders. + Lock *LockClient // Machine is the client for interacting with the Machine builders. Machine *MachineClient // Meta is the client for interacting with the Meta builders. 
@@ -162,6 +164,7 @@ func (tx *Tx) init() { tx.ConfigItem = NewConfigItemClient(tx.config) tx.Decision = NewDecisionClient(tx.config) tx.Event = NewEventClient(tx.config) + tx.Lock = NewLockClient(tx.config) tx.Machine = NewMachineClient(tx.config) tx.Meta = NewMetaClient(tx.config) } diff --git a/pkg/database/lock.go b/pkg/database/lock.go new file mode 100644 index 00000000000..339226e8592 --- /dev/null +++ b/pkg/database/lock.go @@ -0,0 +1,67 @@ +package database + +import ( + "time" + + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +const ( + CAPIPullLockTimeout = 120 +) + +func (c *Client) AcquireLock(name string) error { + _, err := c.Ent.Lock.Create(). + SetName(name). + SetCreatedAt(types.UtcNow()). + Save(c.CTX) + if ent.IsConstraintError(err) { + return err + } + if err != nil { + return errors.Wrapf(InsertFail, "insert lock: %s", err) + } + return nil +} + +func (c *Client) ReleaseLock(name string) error { + _, err := c.Ent.Lock.Delete().Where(lock.NameEQ(name)).Exec(c.CTX) + if err != nil { + return errors.Wrapf(DeleteFail, "delete lock: %s", err) + } + return nil +} + +func (c *Client) ReleaseLockWithTimeout(name string, timeout int) error { + log.Debugf("(%s) releasing orphin locks", name) + _, err := c.Ent.Lock.Delete().Where( + lock.NameEQ(name), + lock.CreatedAtLT(time.Now().Add(-time.Duration(timeout)*time.Minute)), + ).Exec(c.CTX) + if err != nil { + return errors.Wrapf(DeleteFail, "delete lock: %s", err) + } + return nil +} + +func (c *Client) IsLocked(err error) bool { + return ent.IsConstraintError(err) +} + +func (c *Client) AcquirePullCAPILock() error { + lockName := "pullCAPI" + err := c.ReleaseLockWithTimeout(lockName, CAPIPullLockTimeout) + if err != nil { + log.Errorf("unable to release pullCAPI lock: %s", err) + } + return c.AcquireLock(lockName) +} + +func (c *Client) ReleasePullCAPILock() error { + return c.ReleaseLockWithTimeout("pullCAPI", CAPIPullLockTimeout) +} From 717fc97ca039a2fdf2afbdd73b2a8b417b48c69e Mon Sep 17 00:00:00 2001 From: "Thibault \"bui\" Koechlin" Date: Wed, 14 Feb 2024 13:38:40 +0100 Subject: [PATCH 13/20] add SetMeta and SetParsed helpers (#2845) * add SetMeta and SetParsed helpers --- pkg/types/event.go | 16 ++++++++ pkg/types/event_test.go | 82 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/pkg/types/event.go b/pkg/types/event.go index 074241918d8..c7b19fe3ca4 100644 --- a/pkg/types/event.go +++ b/pkg/types/event.go @@ -46,6 +46,22 @@ type Event struct { Meta map[string]string `yaml:"Meta,omitempty" json:"Meta,omitempty"` } +func (e *Event) SetMeta(key string, value string) bool { + if e.Meta == nil { + e.Meta = make(map[string]string) + } + e.Meta[key] = value + return true +} + +func (e *Event) SetParsed(key string, value string) bool { + if e.Parsed == nil { + e.Parsed = make(map[string]string) + } + e.Parsed[key] = value + return true +} + func (e *Event) GetType() string { if e.Type == OVFLW { return "overflow" diff --git a/pkg/types/event_test.go b/pkg/types/event_test.go index 14ca48cd2a8..a2fad9ebcc7 100644 --- a/pkg/types/event_test.go +++ b/pkg/types/event_test.go @@ -9,6 +9,88 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" ) +func TestSetParsed(t *testing.T) { + tests := []struct { + name string + evt *Event + key string + value string + expected bool + }{ + { + name: "SetParsed: Valid", + 
evt: &Event{}, + key: "test", + value: "test", + expected: true, + }, + { + name: "SetParsed: Existing map", + evt: &Event{Parsed: map[string]string{}}, + key: "test", + value: "test", + expected: true, + }, + { + name: "SetParsed: Existing map+key", + evt: &Event{Parsed: map[string]string{"test": "foobar"}}, + key: "test", + value: "test", + expected: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + tt.evt.SetParsed(tt.key, tt.value) + assert.Equal(t, tt.value, tt.evt.Parsed[tt.key]) + }) + } + +} + +func TestSetMeta(t *testing.T) { + tests := []struct { + name string + evt *Event + key string + value string + expected bool + }{ + { + name: "SetMeta: Valid", + evt: &Event{}, + key: "test", + value: "test", + expected: true, + }, + { + name: "SetMeta: Existing map", + evt: &Event{Meta: map[string]string{}}, + key: "test", + value: "test", + expected: true, + }, + { + name: "SetMeta: Existing map+key", + evt: &Event{Meta: map[string]string{"test": "foobar"}}, + key: "test", + value: "test", + expected: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + tt.evt.SetMeta(tt.key, tt.value) + assert.Equal(t, tt.value, tt.evt.GetMeta(tt.key)) + }) + } + +} + func TestParseIPSources(t *testing.T) { tests := []struct { name string From e976614645aba906a096f4bdf46e09709f71d096 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Thu, 15 Feb 2024 14:34:12 +0100 Subject: [PATCH 14/20] cscli metrics: rename buckets -> scenarios (#2848) * cscli metrics: rename buckets -> scenarios * update lint configuration * lint --- .golangci.yml | 6 +++++- cmd/crowdsec-cli/metrics.go | 20 ++++++++++++-------- cmd/crowdsec-cli/metrics_table.go | 18 ++++++++++++++---- pkg/exprhelpers/exprlib_test.go | 4 ++-- pkg/parser/README.md | 2 +- pkg/setup/README.md | 2 +- 6 files changed, 35 insertions(+), 17 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index e605ac079d4..29332447b61 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -73,6 +73,10 @@ linters-settings: - pkg: "github.com/pkg/errors" desc: "errors.Wrap() is deprecated in favor of fmt.Errorf()" + wsl: + # Allow blocks to end with comments + allow-trailing-comment: true + linters: enable-all: true disable: @@ -105,6 +109,7 @@ linters: # - durationcheck # check for two durations multiplied together # - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases # - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + # - execinquery # execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds # - exportloopref # checks for pointers to enclosing loop variables # - funlen # Tool for detection of long functions # - ginkgolinter # enforces standards of using ginkgo and gomega @@ -203,7 +208,6 @@ linters: # # Too strict / too many false positives (for now?) 
# - - execinquery # execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds - exhaustruct # Checks if all structure fields are initialized - forbidigo # Forbids identifiers - gochecknoglobals # check that no global variables exist diff --git a/cmd/crowdsec-cli/metrics.go b/cmd/crowdsec-cli/metrics.go index 6e23bcf12e4..0f92343868d 100644 --- a/cmd/crowdsec-cli/metrics.go +++ b/cmd/crowdsec-cli/metrics.go @@ -44,9 +44,8 @@ type ( ) var ( - ErrMissingConfig = errors.New("prometheus section missing, can't show metrics") + ErrMissingConfig = errors.New("prometheus section missing, can't show metrics") ErrMetricsDisabled = errors.New("prometheus is not enabled, can't show metrics") - ) type metricSection interface { @@ -59,7 +58,7 @@ type metricStore map[string]metricSection func NewMetricStore() metricStore { return metricStore{ "acquisition": statAcquis{}, - "buckets": statBucket{}, + "scenarios": statBucket{}, "parsers": statParser{}, "lapi": statLapi{}, "lapi-machine": statLapiMachine{}, @@ -110,7 +109,7 @@ func (ms metricStore) Fetch(url string) error { mAcquis := ms["acquisition"].(statAcquis) mParser := ms["parsers"].(statParser) - mBucket := ms["buckets"].(statBucket) + mBucket := ms["scenarios"].(statBucket) mLapi := ms["lapi"].(statLapi) mLapiMachine := ms["lapi-machine"].(statLapiMachine) mLapiBouncer := ms["lapi-bouncer"].(statLapiBouncer) @@ -361,7 +360,7 @@ cscli metrics --url http://lapi.local:6060/metrics show acquisition parsers cscli metrics list`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { return cli.show(nil, url, noUnit) }, } @@ -383,7 +382,7 @@ func (cli *cliMetrics) expandSectionGroups(args []string) []string { for _, section := range args { switch section { case "engine": - ret = append(ret, "acquisition", "parsers", "buckets", "stash", "whitelists") + ret = append(ret, "acquisition", "parsers", "scenarios", "stash", "whitelists") case "lapi": ret = append(ret, "alerts", "decisions", "lapi", "lapi-bouncer", "lapi-decisions", "lapi-machine") case "appsec": @@ -413,10 +412,13 @@ cscli metrics show cscli metrics show engine # Show some specific metrics, show empty tables, connect to a different url -cscli metrics show acquisition parsers buckets stash --url http://lapi.local:6060/metrics +cscli metrics show acquisition parsers scenarios stash --url http://lapi.local:6060/metrics + +# To list available metric types, use "cscli metrics list" +cscli metrics list; cscli metrics list -o json # Show metrics in json format -cscli metrics show acquisition parsers buckets stash -o json`, +cscli metrics show acquisition parsers scenarios stash -o json`, // Positional args are optional DisableAutoGenTag: true, RunE: func(_ *cobra.Command, args []string) error { @@ -467,12 +469,14 @@ func (cli *cliMetrics) list() error { if err != nil { return fmt.Errorf("failed to marshal metric types: %w", err) } + fmt.Println(string(x)) case "raw": x, err := yaml.Marshal(allMetrics) if err != nil { return fmt.Errorf("failed to marshal metric types: %w", err) } + fmt.Println(string(x)) } diff --git a/cmd/crowdsec-cli/metrics_table.go b/cmd/crowdsec-cli/metrics_table.go index da6ea3d9f1d..689929500ad 100644 --- a/cmd/crowdsec-cli/metrics_table.go +++ b/cmd/crowdsec-cli/metrics_table.go @@ -1,6 +1,7 @@ package main import ( + "errors" "fmt" "io" "sort" @@ -13,7 +14,7 @@ import ( ) // ErrNilTable means a nil pointer was passed 
instead of a table instance. This is a programming error. -var ErrNilTable = fmt.Errorf("nil table") +var ErrNilTable = errors.New("nil table") func lapiMetricsToTable(t *table.Table, stats map[string]map[string]map[string]int) int { // stats: machine -> route -> method -> count @@ -44,6 +45,7 @@ func lapiMetricsToTable(t *table.Table, stats map[string]map[string]map[string]i } t.AddRow(row...) + numRows++ } } @@ -82,6 +84,7 @@ func wlMetricsToTable(t *table.Table, stats map[string]map[string]map[string]int } t.AddRow(row...) + numRows++ } } @@ -120,6 +123,7 @@ func metricsToTable(t *table.Table, stats map[string]map[string]int, keys []stri } t.AddRow(row...) + numRows++ } @@ -127,7 +131,7 @@ func metricsToTable(t *table.Table, stats map[string]map[string]int, keys []stri } func (s statBucket) Description() (string, string) { - return "Bucket Metrics", + return "Scenario Metrics", `Measure events in different scenarios. Current count is the number of buckets during metrics collection. ` + `Overflows are past event-producing buckets, while Expired are the ones that didn’t receive enough events to Overflow.` } @@ -143,13 +147,13 @@ func (s statBucket) Process(bucket, metric string, val int) { func (s statBucket) Table(out io.Writer, noUnit bool, showEmpty bool) { t := newTable(out) t.SetRowLines(false) - t.SetHeaders("Bucket", "Current Count", "Overflows", "Instantiated", "Poured", "Expired") + t.SetHeaders("Scenario", "Current Count", "Overflows", "Instantiated", "Poured", "Expired") t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) keys := []string{"curr_count", "overflow", "instantiation", "pour", "underflow"} if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil { - log.Warningf("while collecting bucket stats: %s", err) + log.Warningf("while collecting scenario stats: %s", err) } else if numRows > 0 || showEmpty { title, _ := s.Description() renderTableTitle(out, "\n"+title+":") @@ -352,6 +356,7 @@ func (s statStash) Table(out io.Writer, noUnit bool, showEmpty bool) { strconv.Itoa(astats.Count), } t.AddRow(row...) + numRows++ } @@ -400,7 +405,9 @@ func (s statLapi) Table(out io.Writer, noUnit bool, showEmpty bool) { sl, strconv.Itoa(astats[sl]), } + t.AddRow(row...) 
+ numRows++ } } @@ -515,6 +522,7 @@ func (s statLapiDecision) Table(out io.Writer, noUnit bool, showEmpty bool) { strconv.Itoa(hits.Empty), strconv.Itoa(hits.NonEmpty), ) + numRows++ } @@ -560,6 +568,7 @@ func (s statDecision) Table(out io.Writer, noUnit bool, showEmpty bool) { action, strconv.Itoa(hits), ) + numRows++ } } @@ -594,6 +603,7 @@ func (s statAlert) Table(out io.Writer, noUnit bool, showEmpty bool) { scenario, strconv.Itoa(hits), ) + numRows++ } diff --git a/pkg/exprhelpers/exprlib_test.go b/pkg/exprhelpers/exprlib_test.go index 6b9cd15c73b..9d5a6556b25 100644 --- a/pkg/exprhelpers/exprlib_test.go +++ b/pkg/exprhelpers/exprlib_test.go @@ -200,7 +200,7 @@ func TestDistanceHelper(t *testing.T) { ret, err := expr.Run(vm, env) if test.valid { require.NoError(t, err) - assert.Equal(t, test.dist, ret) + assert.InDelta(t, test.dist, ret, 0.000001) } else { require.Error(t, err) } @@ -592,7 +592,7 @@ func TestAtof(t *testing.T) { require.NoError(t, err) output, err := expr.Run(program, test.env) require.NoError(t, err) - require.Equal(t, test.result, output) + require.InDelta(t, test.result, output, 0.000001) } } diff --git a/pkg/parser/README.md b/pkg/parser/README.md index 62a56e61820..0fcccc811e4 100644 --- a/pkg/parser/README.md +++ b/pkg/parser/README.md @@ -45,7 +45,7 @@ statics: > `filter: "Line.Src endsWith '/foobar'"` - - *optional* `filter` : an [expression](https://github.com/antonmedv/expr/blob/master/docs/Language-Definition.md) that will be evaluated against the runtime of a line (`Event`) + - *optional* `filter` : an [expression](https://github.com/antonmedv/expr/blob/master/docs/language-definition.md) that will be evaluated against the runtime of a line (`Event`) - if the `filter` is present and returns false, node is not evaluated - if `filter` is absent or present and returns true, node is evaluated diff --git a/pkg/setup/README.md b/pkg/setup/README.md index 3585ee8b141..9cdc7243975 100644 --- a/pkg/setup/README.md +++ b/pkg/setup/README.md @@ -129,7 +129,7 @@ services: and must all return true for a service to be detected (implied *and* clause, no short-circuit). A missing or empty `when:` section is evaluated as true. The [expression -engine](https://github.com/antonmedv/expr/blob/master/docs/Language-Definition.md) +engine](https://github.com/antonmedv/expr/blob/master/docs/language-definition.md) is the same one used by CrowdSec parser filters. You can force the detection of a process by using the `cscli setup detect... --force-process ` flag. It will always behave as if `` was running. 
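A note on the database lock helpers added above in pkg/database/lock.go: AcquireLock deliberately surfaces the unique-constraint error so a caller can treat it as "already locked" through IsLocked, while ReleaseLockWithTimeout only reaps locks older than the given timeout, i.e. locks left behind by a crashed process. A minimal caller-side sketch, assuming an already initialized *database.Client; the withLock wrapper and the lock name chosen by its caller are hypothetical and not part of the patch:

    package example // hypothetical illustration package, not part of the patches

    import (
    	log "github.com/sirupsen/logrus"

    	"github.com/crowdsecurity/crowdsec/pkg/database"
    )

    // withLock runs fn only if the named lock can be acquired, and releases it afterwards.
    // The wrapper itself is made up for illustration; AcquireLock, IsLocked and ReleaseLock
    // are the helpers added in pkg/database/lock.go.
    func withLock(db *database.Client, name string, fn func() error) error {
    	if err := db.AcquireLock(name); err != nil {
    		if db.IsLocked(err) {
    			// the unique constraint on the lock name means another process already holds it
    			log.Infof("lock %q is already taken, skipping", name)
    			return nil
    		}

    		return err
    	}

    	defer func() {
    		if err := db.ReleaseLock(name); err != nil {
    			log.Errorf("unable to release lock %q: %s", name, err)
    		}
    	}()

    	return fn()
    }

AcquirePullCAPILock and ReleasePullCAPILock apply the same pattern to the hard-coded "pullCAPI" lock, purging stale entries with CAPIPullLockTimeout before attempting to insert a new one.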
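Similarly, for the SetMeta and SetParsed helpers introduced in PATCH 13 above: they allocate Event.Meta and Event.Parsed on first use, so callers no longer need to nil-check the maps before writing. A minimal usage sketch, assuming the patched pkg/types package; the keys and values are made up for illustration:

    package main // standalone sketch, not part of the patches

    import (
    	"fmt"

    	"github.com/crowdsecurity/crowdsec/pkg/types"
    )

    func main() {
    	evt := types.Event{} // Meta and Parsed start out nil

    	// both helpers allocate the underlying map on first use
    	evt.SetMeta("source_ip", "192.0.2.1") // hypothetical key/value
    	evt.SetParsed("program", "sshd")      // hypothetical key/value

    	fmt.Println(evt.GetMeta("source_ip"), evt.Parsed["program"])
    }

As the "Existing map+key" test cases above show, writing to a key that is already present simply overwrites its value.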
From f3ea88f64ce7a594830558c84bc6f196ddddc323 Mon Sep 17 00:00:00 2001 From: Laurence Jones Date: Wed, 21 Feb 2024 13:40:38 +0000 Subject: [PATCH 15/20] Appsec unix socket (#2737) * Appsec socket * Patch detection of nil listenaddr * Allow TLS unix socket * Merge diff issue --- pkg/acquisition/modules/appsec/appsec.go | 55 ++++++++++++++++++------ 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/pkg/acquisition/modules/appsec/appsec.go b/pkg/acquisition/modules/appsec/appsec.go index 4e2ff0bd22b..a3c8c7dd8ee 100644 --- a/pkg/acquisition/modules/appsec/appsec.go +++ b/pkg/acquisition/modules/appsec/appsec.go @@ -4,7 +4,9 @@ import ( "context" "encoding/json" "fmt" + "net" "net/http" + "os" "sync" "time" @@ -34,6 +36,7 @@ var ( // configuration structure of the acquis for the application security engine type AppsecSourceConfig struct { ListenAddr string `yaml:"listen_addr"` + ListenSocket string `yaml:"listen_socket"` CertFilePath string `yaml:"cert_file"` KeyFilePath string `yaml:"key_file"` Path string `yaml:"path"` @@ -97,7 +100,7 @@ func (w *AppsecSource) UnmarshalConfig(yamlConfig []byte) error { return errors.Wrap(err, "Cannot parse appsec configuration") } - if w.config.ListenAddr == "" { + if w.config.ListenAddr == "" && w.config.ListenSocket == "" { w.config.ListenAddr = "127.0.0.1:7422" } @@ -123,7 +126,12 @@ func (w *AppsecSource) UnmarshalConfig(yamlConfig []byte) error { } if w.config.Name == "" { - w.config.Name = fmt.Sprintf("%s%s", w.config.ListenAddr, w.config.Path) + if w.config.ListenSocket != "" && w.config.ListenAddr == "" { + w.config.Name = w.config.ListenSocket + } + if w.config.ListenSocket == "" { + w.config.Name = fmt.Sprintf("%s%s", w.config.ListenAddr, w.config.Path) + } } csConfig := csconfig.GetConfig() @@ -251,23 +259,44 @@ func (w *AppsecSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) return runner.Run(t) }) } - - w.logger.Infof("Starting Appsec server on %s%s", w.config.ListenAddr, w.config.Path) t.Go(func() error { - var err error - if w.config.CertFilePath != "" && w.config.KeyFilePath != "" { - err = w.server.ListenAndServeTLS(w.config.CertFilePath, w.config.KeyFilePath) - } else { - err = w.server.ListenAndServe() + if w.config.ListenSocket != "" { + w.logger.Infof("creating unix socket %s", w.config.ListenSocket) + _ = os.RemoveAll(w.config.ListenSocket) + listener, err := net.Listen("unix", w.config.ListenSocket) + if err != nil { + return errors.Wrap(err, "Appsec server failed") + } + defer listener.Close() + if w.config.CertFilePath != "" && w.config.KeyFilePath != "" { + err = w.server.ServeTLS(listener, w.config.CertFilePath, w.config.KeyFilePath) + } else { + err = w.server.Serve(listener) + } + if err != nil && err != http.ErrServerClosed { + return errors.Wrap(err, "Appsec server failed") + } } - - if err != nil && err != http.ErrServerClosed { - return errors.Wrap(err, "Appsec server failed") + return nil + }) + t.Go(func() error { + var err error + if w.config.ListenAddr != "" { + w.logger.Infof("creating TCP server on %s", w.config.ListenAddr) + if w.config.CertFilePath != "" && w.config.KeyFilePath != "" { + err = w.server.ListenAndServeTLS(w.config.CertFilePath, w.config.KeyFilePath) + } else { + err = w.server.ListenAndServe() + } + + if err != nil && err != http.ErrServerClosed { + return errors.Wrap(err, "Appsec server failed") + } } return nil }) <-t.Dying() - w.logger.Infof("Stopping Appsec server on %s%s", w.config.ListenAddr, w.config.Path) + w.logger.Info("Shutting down Appsec server") //xx let's clean 
up the appsec runners :) appsec.AppsecRulesDetails = make(map[int]appsec.RulesDetails) w.server.Shutdown(context.TODO()) From 3e3df5e4c6e6deb1ef36bb406e86a7ebc8c30f06 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Thu, 22 Feb 2024 11:04:36 +0100 Subject: [PATCH 16/20] refact "cscli config", remove flag "cscli restore --old-backup" (#2832) * refact "cscli config show" * refact "cscli config backup" * refact "cscli confgi show-yaml" * refact "cscli config restore" * refact "cscli config feature-flags" * cscli restore: remove 'old-backup' option * lint (whitespace, wrapped errors) --- cmd/crowdsec-cli/config.go | 26 ++-- cmd/crowdsec-cli/config_backup.go | 99 ++++++------- cmd/crowdsec-cli/config_feature_flags.go | 25 ++-- cmd/crowdsec-cli/config_restore.go | 175 ++++++++--------------- cmd/crowdsec-cli/config_show.go | 37 +++-- cmd/crowdsec-cli/config_showyaml.go | 12 +- cmd/crowdsec-cli/main.go | 2 +- 7 files changed, 167 insertions(+), 209 deletions(-) diff --git a/cmd/crowdsec-cli/config.go b/cmd/crowdsec-cli/config.go index e60246db790..e88845798e2 100644 --- a/cmd/crowdsec-cli/config.go +++ b/cmd/crowdsec-cli/config.go @@ -4,19 +4,29 @@ import ( "github.com/spf13/cobra" ) -func NewConfigCmd() *cobra.Command { - cmdConfig := &cobra.Command{ +type cliConfig struct { + cfg configGetter +} + +func NewCLIConfig(cfg configGetter) *cliConfig { + return &cliConfig{ + cfg: cfg, + } +} + +func (cli *cliConfig) NewCommand() *cobra.Command { + cmd := &cobra.Command{ Use: "config [command]", Short: "Allows to view current config", Args: cobra.ExactArgs(0), DisableAutoGenTag: true, } - cmdConfig.AddCommand(NewConfigShowCmd()) - cmdConfig.AddCommand(NewConfigShowYAMLCmd()) - cmdConfig.AddCommand(NewConfigBackupCmd()) - cmdConfig.AddCommand(NewConfigRestoreCmd()) - cmdConfig.AddCommand(NewConfigFeatureFlagsCmd()) + cmd.AddCommand(cli.newShowCmd()) + cmd.AddCommand(cli.newShowYAMLCmd()) + cmd.AddCommand(cli.newBackupCmd()) + cmd.AddCommand(cli.newRestoreCmd()) + cmd.AddCommand(cli.newFeatureFlagsCmd()) - return cmdConfig + return cmd } diff --git a/cmd/crowdsec-cli/config_backup.go b/cmd/crowdsec-cli/config_backup.go index 9414fa51033..d1e4a393555 100644 --- a/cmd/crowdsec-cli/config_backup.go +++ b/cmd/crowdsec-cli/config_backup.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "errors" "fmt" "os" "path/filepath" @@ -13,8 +14,8 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func backupHub(dirPath string) error { - hub, err := require.Hub(csConfig, nil, nil) +func (cli *cliConfig) backupHub(dirPath string) error { + hub, err := require.Hub(cli.cfg(), nil, nil) if err != nil { return err } @@ -32,7 +33,7 @@ func backupHub(dirPath string) error { itemDirectory := fmt.Sprintf("%s/%s/", dirPath, itemType) if err = os.MkdirAll(itemDirectory, os.ModePerm); err != nil { - return fmt.Errorf("error while creating %s : %s", itemDirectory, err) + return fmt.Errorf("error while creating %s: %w", itemDirectory, err) } upstreamParsers := []string{} @@ -41,18 +42,18 @@ func backupHub(dirPath string) error { clog = clog.WithFields(log.Fields{ "file": v.Name, }) - if !v.State.Installed { //only backup installed ones - clog.Debugf("[%s] : not installed", k) + if !v.State.Installed { // only backup installed ones + clog.Debugf("[%s]: not installed", k) continue } - //for the local/tainted ones, we back up the full file + // for the local/tainted ones, we back up the full file if v.State.Tainted || v.State.IsLocal() || !v.State.UpToDate { - //we need to backup 
stages for parsers + // we need to backup stages for parsers if itemType == cwhub.PARSERS || itemType == cwhub.POSTOVERFLOWS { fstagedir := fmt.Sprintf("%s%s", itemDirectory, v.Stage) if err = os.MkdirAll(fstagedir, os.ModePerm); err != nil { - return fmt.Errorf("error while creating stage dir %s : %s", fstagedir, err) + return fmt.Errorf("error while creating stage dir %s: %w", fstagedir, err) } } @@ -60,7 +61,7 @@ func backupHub(dirPath string) error { tfile := fmt.Sprintf("%s%s/%s", itemDirectory, v.Stage, v.FileName) if err = CopyFile(v.State.LocalPath, tfile); err != nil { - return fmt.Errorf("failed copy %s %s to %s : %s", itemType, v.State.LocalPath, tfile, err) + return fmt.Errorf("failed copy %s %s to %s: %w", itemType, v.State.LocalPath, tfile, err) } clog.Infof("local/tainted saved %s to %s", v.State.LocalPath, tfile) @@ -68,21 +69,21 @@ func backupHub(dirPath string) error { continue } - clog.Debugf("[%s] : from hub, just backup name (up-to-date:%t)", k, v.State.UpToDate) + clog.Debugf("[%s]: from hub, just backup name (up-to-date:%t)", k, v.State.UpToDate) clog.Infof("saving, version:%s, up-to-date:%t", v.Version, v.State.UpToDate) upstreamParsers = append(upstreamParsers, v.Name) } - //write the upstream items + // write the upstream items upstreamParsersFname := fmt.Sprintf("%s/upstream-%s.json", itemDirectory, itemType) upstreamParsersContent, err := json.MarshalIndent(upstreamParsers, "", " ") if err != nil { - return fmt.Errorf("failed marshaling upstream parsers : %s", err) + return fmt.Errorf("failed marshaling upstream parsers: %w", err) } err = os.WriteFile(upstreamParsersFname, upstreamParsersContent, 0o644) if err != nil { - return fmt.Errorf("unable to write to %s %s : %s", itemType, upstreamParsersFname, err) + return fmt.Errorf("unable to write to %s %s: %w", itemType, upstreamParsersFname, err) } clog.Infof("Wrote %d entries for %s to %s", len(upstreamParsers), itemType, upstreamParsersFname) @@ -102,11 +103,13 @@ func backupHub(dirPath string) error { - Tainted/local/out-of-date scenarios, parsers, postoverflows and collections - Acquisition files (acquis.yaml, acquis.d/*.yaml) */ -func backupConfigToDirectory(dirPath string) error { +func (cli *cliConfig) backup(dirPath string) error { var err error + cfg := cli.cfg() + if dirPath == "" { - return fmt.Errorf("directory path can't be empty") + return errors.New("directory path can't be empty") } log.Infof("Starting configuration backup") @@ -121,10 +124,10 @@ func backupConfigToDirectory(dirPath string) error { return fmt.Errorf("while creating %s: %w", dirPath, err) } - if csConfig.ConfigPaths.SimulationFilePath != "" { + if cfg.ConfigPaths.SimulationFilePath != "" { backupSimulation := filepath.Join(dirPath, "simulation.yaml") - if err = CopyFile(csConfig.ConfigPaths.SimulationFilePath, backupSimulation); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", csConfig.ConfigPaths.SimulationFilePath, backupSimulation, err) + if err = CopyFile(cfg.ConfigPaths.SimulationFilePath, backupSimulation); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.ConfigPaths.SimulationFilePath, backupSimulation, err) } log.Infof("Saved simulation to %s", backupSimulation) @@ -134,22 +137,22 @@ func backupConfigToDirectory(dirPath string) error { - backup AcquisitionFilePath - backup the other files of acquisition directory */ - if csConfig.Crowdsec != nil && csConfig.Crowdsec.AcquisitionFilePath != "" { + if cfg.Crowdsec != nil && cfg.Crowdsec.AcquisitionFilePath != "" { backupAcquisition := 
filepath.Join(dirPath, "acquis.yaml") - if err = CopyFile(csConfig.Crowdsec.AcquisitionFilePath, backupAcquisition); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", csConfig.Crowdsec.AcquisitionFilePath, backupAcquisition, err) + if err = CopyFile(cfg.Crowdsec.AcquisitionFilePath, backupAcquisition); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.Crowdsec.AcquisitionFilePath, backupAcquisition, err) } } acquisBackupDir := filepath.Join(dirPath, "acquis") if err = os.Mkdir(acquisBackupDir, 0o700); err != nil { - return fmt.Errorf("error while creating %s: %s", acquisBackupDir, err) + return fmt.Errorf("error while creating %s: %w", acquisBackupDir, err) } - if csConfig.Crowdsec != nil && len(csConfig.Crowdsec.AcquisitionFiles) > 0 { - for _, acquisFile := range csConfig.Crowdsec.AcquisitionFiles { + if cfg.Crowdsec != nil && len(cfg.Crowdsec.AcquisitionFiles) > 0 { + for _, acquisFile := range cfg.Crowdsec.AcquisitionFiles { /*if it was the default one, it was already backup'ed*/ - if csConfig.Crowdsec.AcquisitionFilePath == acquisFile { + if cfg.Crowdsec.AcquisitionFilePath == acquisFile { continue } @@ -169,56 +172,48 @@ func backupConfigToDirectory(dirPath string) error { if ConfigFilePath != "" { backupMain := fmt.Sprintf("%s/config.yaml", dirPath) if err = CopyFile(ConfigFilePath, backupMain); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", ConfigFilePath, backupMain, err) + return fmt.Errorf("failed copy %s to %s: %w", ConfigFilePath, backupMain, err) } log.Infof("Saved default yaml to %s", backupMain) } - if csConfig.API != nil && csConfig.API.Server != nil && csConfig.API.Server.OnlineClient != nil && csConfig.API.Server.OnlineClient.CredentialsFilePath != "" { + if cfg.API != nil && cfg.API.Server != nil && cfg.API.Server.OnlineClient != nil && cfg.API.Server.OnlineClient.CredentialsFilePath != "" { backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath) - if err = CopyFile(csConfig.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", csConfig.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds, err) + if err = CopyFile(cfg.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds, err) } log.Infof("Saved online API credentials to %s", backupCAPICreds) } - if csConfig.API != nil && csConfig.API.Client != nil && csConfig.API.Client.CredentialsFilePath != "" { + if cfg.API != nil && cfg.API.Client != nil && cfg.API.Client.CredentialsFilePath != "" { backupLAPICreds := fmt.Sprintf("%s/local_api_credentials.yaml", dirPath) - if err = CopyFile(csConfig.API.Client.CredentialsFilePath, backupLAPICreds); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", csConfig.API.Client.CredentialsFilePath, backupLAPICreds, err) + if err = CopyFile(cfg.API.Client.CredentialsFilePath, backupLAPICreds); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.API.Client.CredentialsFilePath, backupLAPICreds, err) } log.Infof("Saved local API credentials to %s", backupLAPICreds) } - if csConfig.API != nil && csConfig.API.Server != nil && csConfig.API.Server.ProfilesPath != "" { + if cfg.API != nil && cfg.API.Server != nil && cfg.API.Server.ProfilesPath != "" { backupProfiles := fmt.Sprintf("%s/profiles.yaml", dirPath) - if err = CopyFile(csConfig.API.Server.ProfilesPath, backupProfiles); err != nil { - return 
fmt.Errorf("failed copy %s to %s: %s", csConfig.API.Server.ProfilesPath, backupProfiles, err) + if err = CopyFile(cfg.API.Server.ProfilesPath, backupProfiles); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.API.Server.ProfilesPath, backupProfiles, err) } log.Infof("Saved profiles to %s", backupProfiles) } - if err = backupHub(dirPath); err != nil { - return fmt.Errorf("failed to backup hub config: %s", err) - } - - return nil -} - -func runConfigBackup(cmd *cobra.Command, args []string) error { - if err := backupConfigToDirectory(args[0]); err != nil { - return fmt.Errorf("failed to backup config: %w", err) + if err = cli.backupHub(dirPath); err != nil { + return fmt.Errorf("failed to backup hub config: %w", err) } return nil } -func NewConfigBackupCmd() *cobra.Command { - cmdConfigBackup := &cobra.Command{ +func (cli *cliConfig) newBackupCmd() *cobra.Command { + cmd := &cobra.Command{ Use: `backup "directory"`, Short: "Backup current config", Long: `Backup the current crowdsec configuration including : @@ -232,8 +227,14 @@ func NewConfigBackupCmd() *cobra.Command { Example: `cscli config backup ./my-backup`, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: runConfigBackup, + RunE: func(_ *cobra.Command, args []string) error { + if err := cli.backup(args[0]); err != nil { + return fmt.Errorf("failed to backup config: %w", err) + } + + return nil + }, } - return cmdConfigBackup + return cmd } diff --git a/cmd/crowdsec-cli/config_feature_flags.go b/cmd/crowdsec-cli/config_feature_flags.go index fbba1f56736..d1dbe2b93b7 100644 --- a/cmd/crowdsec-cli/config_feature_flags.go +++ b/cmd/crowdsec-cli/config_feature_flags.go @@ -11,14 +11,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/fflag" ) -func runConfigFeatureFlags(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - showRetired, err := flags.GetBool("retired") - if err != nil { - return err - } - +func (cli *cliConfig) featureFlags(showRetired bool) error { green := color.New(color.FgGreen).SprintFunc() red := color.New(color.FgRed).SprintFunc() yellow := color.New(color.FgYellow).SprintFunc() @@ -121,18 +114,22 @@ func runConfigFeatureFlags(cmd *cobra.Command, args []string) error { return nil } -func NewConfigFeatureFlagsCmd() *cobra.Command { - cmdConfigFeatureFlags := &cobra.Command{ +func (cli *cliConfig) newFeatureFlagsCmd() *cobra.Command { + var showRetired bool + + cmd := &cobra.Command{ Use: "feature-flags", Short: "Displays feature flag status", Long: `Displays the supported feature flags and their current status.`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: runConfigFeatureFlags, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.featureFlags(showRetired) + }, } - flags := cmdConfigFeatureFlags.Flags() - flags.Bool("retired", false, "Show retired features") + flags := cmd.Flags() + flags.BoolVar(&showRetired, "retired", false, "Show retired features") - return cmdConfigFeatureFlags + return cmd } diff --git a/cmd/crowdsec-cli/config_restore.go b/cmd/crowdsec-cli/config_restore.go index 17d7494c60f..513f993ba80 100644 --- a/cmd/crowdsec-cli/config_restore.go +++ b/cmd/crowdsec-cli/config_restore.go @@ -3,25 +3,17 @@ package main import ( "encoding/json" "fmt" - "io" "os" "path/filepath" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "gopkg.in/yaml.v2" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -type OldAPICfg struct { - 
MachineID string `json:"machine_id"` - Password string `json:"password"` -} - -func restoreHub(dirPath string) error { +func (cli *cliConfig) restoreHub(dirPath string) error { hub, err := require.Hub(csConfig, require.RemoteHub(csConfig), nil) if err != nil { return err @@ -38,14 +30,14 @@ func restoreHub(dirPath string) error { file, err := os.ReadFile(upstreamListFN) if err != nil { - return fmt.Errorf("error while opening %s : %s", upstreamListFN, err) + return fmt.Errorf("error while opening %s: %w", upstreamListFN, err) } var upstreamList []string err = json.Unmarshal(file, &upstreamList) if err != nil { - return fmt.Errorf("error unmarshaling %s : %s", upstreamListFN, err) + return fmt.Errorf("error unmarshaling %s: %w", upstreamListFN, err) } for _, toinstall := range upstreamList { @@ -55,8 +47,7 @@ func restoreHub(dirPath string) error { continue } - err := item.Install(false, false) - if err != nil { + if err = item.Install(false, false); err != nil { log.Errorf("Error while installing %s : %s", toinstall, err) } } @@ -64,17 +55,17 @@ func restoreHub(dirPath string) error { /*restore the local and tainted items*/ files, err := os.ReadDir(itemDirectory) if err != nil { - return fmt.Errorf("failed enumerating files of %s : %s", itemDirectory, err) + return fmt.Errorf("failed enumerating files of %s: %w", itemDirectory, err) } for _, file := range files { - //this was the upstream data + // this was the upstream data if file.Name() == fmt.Sprintf("upstream-%s.json", itype) { continue } if itype == cwhub.PARSERS || itype == cwhub.POSTOVERFLOWS { - //we expect a stage here + // we expect a stage here if !file.IsDir() { continue } @@ -84,22 +75,23 @@ func restoreHub(dirPath string) error { log.Debugf("Found stage %s in %s, target directory : %s", stage, itype, stagedir) if err = os.MkdirAll(stagedir, os.ModePerm); err != nil { - return fmt.Errorf("error while creating stage directory %s : %s", stagedir, err) + return fmt.Errorf("error while creating stage directory %s: %w", stagedir, err) } // find items ifiles, err := os.ReadDir(itemDirectory + "/" + stage + "/") if err != nil { - return fmt.Errorf("failed enumerating files of %s : %s", itemDirectory+"/"+stage, err) + return fmt.Errorf("failed enumerating files of %s: %w", itemDirectory+"/"+stage, err) } - //finally copy item + + // finally copy item for _, tfile := range ifiles { log.Infof("Going to restore local/tainted [%s]", tfile.Name()) sourceFile := fmt.Sprintf("%s/%s/%s", itemDirectory, stage, tfile.Name()) destinationFile := fmt.Sprintf("%s%s", stagedir, tfile.Name()) if err = CopyFile(sourceFile, destinationFile); err != nil { - return fmt.Errorf("failed copy %s %s to %s : %s", itype, sourceFile, destinationFile, err) + return fmt.Errorf("failed copy %s %s to %s: %w", itype, sourceFile, destinationFile, err) } log.Infof("restored %s to %s", sourceFile, destinationFile) @@ -108,9 +100,11 @@ func restoreHub(dirPath string) error { log.Infof("Going to restore local/tainted [%s]", file.Name()) sourceFile := fmt.Sprintf("%s/%s", itemDirectory, file.Name()) destinationFile := fmt.Sprintf("%s/%s/%s", csConfig.ConfigPaths.ConfigDir, itype, file.Name()) + if err = CopyFile(sourceFile, destinationFile); err != nil { - return fmt.Errorf("failed copy %s %s to %s : %s", itype, sourceFile, destinationFile, err) + return fmt.Errorf("failed copy %s %s to %s: %w", itype, sourceFile, destinationFile, err) } + log.Infof("restored %s to %s", sourceFile, destinationFile) } } @@ -130,95 +124,60 @@ func restoreHub(dirPath string) error { - 
Tainted/local/out-of-date scenarios, parsers, postoverflows and collections - Acquisition files (acquis.yaml, acquis.d/*.yaml) */ -func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { +func (cli *cliConfig) restore(dirPath string) error { var err error - if !oldBackup { - backupMain := fmt.Sprintf("%s/config.yaml", dirPath) - if _, err = os.Stat(backupMain); err == nil { - if csConfig.ConfigPaths != nil && csConfig.ConfigPaths.ConfigDir != "" { - if err = CopyFile(backupMain, fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir)); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupMain, csConfig.ConfigPaths.ConfigDir, err) - } + backupMain := fmt.Sprintf("%s/config.yaml", dirPath) + if _, err = os.Stat(backupMain); err == nil { + if csConfig.ConfigPaths != nil && csConfig.ConfigPaths.ConfigDir != "" { + if err = CopyFile(backupMain, fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir)); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupMain, csConfig.ConfigPaths.ConfigDir, err) } } + } - // Now we have config.yaml, we should regenerate config struct to have rights paths etc - ConfigFilePath = fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir) - - log.Debug("Reloading configuration") + // Now we have config.yaml, we should regenerate config struct to have rights paths etc + ConfigFilePath = fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir) - csConfig, _, err = loadConfigFor("config") - if err != nil { - return fmt.Errorf("failed to reload configuration: %s", err) - } + log.Debug("Reloading configuration") - backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath) - if _, err = os.Stat(backupCAPICreds); err == nil { - if err = CopyFile(backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath, err) - } - } + csConfig, _, err = loadConfigFor("config") + if err != nil { + return fmt.Errorf("failed to reload configuration: %w", err) + } - backupLAPICreds := fmt.Sprintf("%s/local_api_credentials.yaml", dirPath) - if _, err = os.Stat(backupLAPICreds); err == nil { - if err = CopyFile(backupLAPICreds, csConfig.API.Client.CredentialsFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupLAPICreds, csConfig.API.Client.CredentialsFilePath, err) - } + backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath) + if _, err = os.Stat(backupCAPICreds); err == nil { + if err = CopyFile(backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath, err) } + } - backupProfiles := fmt.Sprintf("%s/profiles.yaml", dirPath) - if _, err = os.Stat(backupProfiles); err == nil { - if err = CopyFile(backupProfiles, csConfig.API.Server.ProfilesPath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupProfiles, csConfig.API.Server.ProfilesPath, err) - } + backupLAPICreds := fmt.Sprintf("%s/local_api_credentials.yaml", dirPath) + if _, err = os.Stat(backupLAPICreds); err == nil { + if err = CopyFile(backupLAPICreds, csConfig.API.Client.CredentialsFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupLAPICreds, csConfig.API.Client.CredentialsFilePath, err) } - } else { - var oldAPICfg OldAPICfg - backupOldAPICfg := fmt.Sprintf("%s/api_creds.json", 
dirPath) - - jsonFile, err := os.Open(backupOldAPICfg) - if err != nil { - log.Warningf("failed to open %s : %s", backupOldAPICfg, err) - } else { - byteValue, _ := io.ReadAll(jsonFile) - err = json.Unmarshal(byteValue, &oldAPICfg) - if err != nil { - return fmt.Errorf("failed to load json file %s : %s", backupOldAPICfg, err) - } + } - apiCfg := csconfig.ApiCredentialsCfg{ - Login: oldAPICfg.MachineID, - Password: oldAPICfg.Password, - URL: CAPIBaseURL, - } - apiConfigDump, err := yaml.Marshal(apiCfg) - if err != nil { - return fmt.Errorf("unable to dump api credentials: %s", err) - } - apiConfigDumpFile := fmt.Sprintf("%s/online_api_credentials.yaml", csConfig.ConfigPaths.ConfigDir) - if csConfig.API.Server.OnlineClient != nil && csConfig.API.Server.OnlineClient.CredentialsFilePath != "" { - apiConfigDumpFile = csConfig.API.Server.OnlineClient.CredentialsFilePath - } - err = os.WriteFile(apiConfigDumpFile, apiConfigDump, 0o600) - if err != nil { - return fmt.Errorf("write api credentials in '%s' failed: %s", apiConfigDumpFile, err) - } - log.Infof("Saved API credentials to %s", apiConfigDumpFile) + backupProfiles := fmt.Sprintf("%s/profiles.yaml", dirPath) + if _, err = os.Stat(backupProfiles); err == nil { + if err = CopyFile(backupProfiles, csConfig.API.Server.ProfilesPath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupProfiles, csConfig.API.Server.ProfilesPath, err) } } backupSimulation := fmt.Sprintf("%s/simulation.yaml", dirPath) if _, err = os.Stat(backupSimulation); err == nil { if err = CopyFile(backupSimulation, csConfig.ConfigPaths.SimulationFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupSimulation, csConfig.ConfigPaths.SimulationFilePath, err) + return fmt.Errorf("failed copy %s to %s: %w", backupSimulation, csConfig.ConfigPaths.SimulationFilePath, err) } } /*if there is a acquisition dir, restore its content*/ if csConfig.Crowdsec.AcquisitionDirPath != "" { if err = os.MkdirAll(csConfig.Crowdsec.AcquisitionDirPath, 0o700); err != nil { - return fmt.Errorf("error while creating %s : %s", csConfig.Crowdsec.AcquisitionDirPath, err) + return fmt.Errorf("error while creating %s: %w", csConfig.Crowdsec.AcquisitionDirPath, err) } } @@ -228,7 +187,7 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { log.Debugf("restoring backup'ed %s", backupAcquisition) if err = CopyFile(backupAcquisition, csConfig.Crowdsec.AcquisitionFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupAcquisition, csConfig.Crowdsec.AcquisitionFilePath, err) + return fmt.Errorf("failed copy %s to %s: %w", backupAcquisition, csConfig.Crowdsec.AcquisitionFilePath, err) } } @@ -244,7 +203,7 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { log.Debugf("restoring %s to %s", acquisFile, targetFname) if err = CopyFile(acquisFile, targetFname); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", acquisFile, targetFname, err) + return fmt.Errorf("failed copy %s to %s: %w", acquisFile, targetFname, err) } } } @@ -265,37 +224,22 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { } if err = CopyFile(acquisFile, targetFname); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", acquisFile, targetFname, err) + return fmt.Errorf("failed copy %s to %s: %w", acquisFile, targetFname, err) } log.Infof("Saved acquis %s to %s", acquisFile, targetFname) } } - if err = restoreHub(dirPath); err != nil { - return fmt.Errorf("failed to restore hub config : %s", err) + if 
err = cli.restoreHub(dirPath); err != nil { + return fmt.Errorf("failed to restore hub config: %w", err) } return nil } -func runConfigRestore(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - oldBackup, err := flags.GetBool("old-backup") - if err != nil { - return err - } - - if err := restoreConfigFromDirectory(args[0], oldBackup); err != nil { - return fmt.Errorf("failed to restore config from %s: %w", args[0], err) - } - - return nil -} - -func NewConfigRestoreCmd() *cobra.Command { - cmdConfigRestore := &cobra.Command{ +func (cli *cliConfig) newRestoreCmd() *cobra.Command { + cmd := &cobra.Command{ Use: `restore "directory"`, Short: `Restore config in backup "directory"`, Long: `Restore the crowdsec configuration from specified backup "directory" including: @@ -308,11 +252,16 @@ func NewConfigRestoreCmd() *cobra.Command { - Backup of API credentials (local API and online API)`, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: runConfigRestore, - } + RunE: func(_ *cobra.Command, args []string) error { + dirPath := args[0] - flags := cmdConfigRestore.Flags() - flags.BoolP("old-backup", "", false, "To use when you are upgrading crowdsec v0.X to v1.X and you need to restore backup from v0.X") + if err := cli.restore(dirPath); err != nil { + return fmt.Errorf("failed to restore config from %s: %w", dirPath, err) + } + + return nil + }, + } - return cmdConfigRestore + return cmd } diff --git a/cmd/crowdsec-cli/config_show.go b/cmd/crowdsec-cli/config_show.go index bab911cc340..634ca77410e 100644 --- a/cmd/crowdsec-cli/config_show.go +++ b/cmd/crowdsec-cli/config_show.go @@ -182,31 +182,26 @@ Central API: {{- end }} ` -func runConfigShow(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() +func (cli *cliConfig) show(key string) error { + cfg := cli.cfg() - if err := csConfig.LoadAPIClient(); err != nil { + if err := cfg.LoadAPIClient(); err != nil { log.Errorf("failed to load API client configuration: %s", err) // don't return, we can still show the configuration } - key, err := flags.GetString("key") - if err != nil { - return err - } - if key != "" { return showConfigKey(key) } - switch csConfig.Cscli.Output { + switch cfg.Cscli.Output { case "human": // The tests on .Enable look funny because the option has a true default which has // not been set yet (we don't really load the LAPI) and go templates don't dereference // pointers in boolean tests. Prefix notation is the cherry on top. 
funcs := template.FuncMap{ // can't use generics here - "ValueBool": func(b *bool) bool { return b!=nil && *b }, + "ValueBool": func(b *bool) bool { return b != nil && *b }, } tmp, err := template.New("config").Funcs(funcs).Parse(configShowTemplate) @@ -214,19 +209,19 @@ func runConfigShow(cmd *cobra.Command, args []string) error { return err } - err = tmp.Execute(os.Stdout, csConfig) + err = tmp.Execute(os.Stdout, cfg) if err != nil { return err } case "json": - data, err := json.MarshalIndent(csConfig, "", " ") + data, err := json.MarshalIndent(cfg, "", " ") if err != nil { return fmt.Errorf("failed to marshal configuration: %w", err) } fmt.Printf("%s\n", string(data)) case "raw": - data, err := yaml.Marshal(csConfig) + data, err := yaml.Marshal(cfg) if err != nil { return fmt.Errorf("failed to marshal configuration: %w", err) } @@ -237,18 +232,22 @@ func runConfigShow(cmd *cobra.Command, args []string) error { return nil } -func NewConfigShowCmd() *cobra.Command { - cmdConfigShow := &cobra.Command{ +func (cli *cliConfig) newShowCmd() *cobra.Command { + var key string + + cmd := &cobra.Command{ Use: "show", Short: "Displays current config", Long: `Displays the current cli configuration.`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: runConfigShow, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.show(key) + }, } - flags := cmdConfigShow.Flags() - flags.StringP("key", "", "", "Display only this value (Config.API.Server.ListenURI)") + flags := cmd.Flags() + flags.StringVarP(&key, "key", "", "", "Display only this value (Config.API.Server.ListenURI)") - return cmdConfigShow + return cmd } diff --git a/cmd/crowdsec-cli/config_showyaml.go b/cmd/crowdsec-cli/config_showyaml.go index 82bc67ffcb8..52daee6a65e 100644 --- a/cmd/crowdsec-cli/config_showyaml.go +++ b/cmd/crowdsec-cli/config_showyaml.go @@ -6,19 +6,21 @@ import ( "github.com/spf13/cobra" ) -func runConfigShowYAML(cmd *cobra.Command, args []string) error { +func (cli *cliConfig) showYAML() error { fmt.Println(mergedConfig) return nil } -func NewConfigShowYAMLCmd() *cobra.Command { - cmdConfigShow := &cobra.Command{ +func (cli *cliConfig) newShowYAMLCmd() *cobra.Command { + cmd := &cobra.Command{ Use: "show-yaml", Short: "Displays merged config.yaml + config.yaml.local", Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: runConfigShowYAML, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.showYAML() + }, } - return cmdConfigShow + return cmd } diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index 27ac17d554f..1f87390b636 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -231,7 +231,7 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall cmd.AddCommand(NewCLIDoc().NewCommand(cmd)) cmd.AddCommand(NewCLIVersion().NewCommand()) - cmd.AddCommand(NewConfigCmd()) + cmd.AddCommand(NewCLIConfig(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIHub(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIMetrics(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIDashboard(cli.cfg).NewCommand()) From 8da490f5930406180bef6f4b0b99e0b0dc86dff8 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Thu, 22 Feb 2024 11:42:33 +0100 Subject: [PATCH 17/20] refact pkg/apiclient (#2846) * extract resperr.go * extract method prepareRequest() * reset token inside mutex --- pkg/apiclient/auth_jwt.go | 37 +++++++++++++++++++++---------- pkg/apiclient/client.go | 36 ------------------------------ pkg/apiclient/resperr.go | 46 
+++++++++++++++++++++++++++++++++++++++ pkg/apiserver/apic.go | 1 - 4 files changed, 72 insertions(+), 48 deletions(-) create mode 100644 pkg/apiclient/resperr.go diff --git a/pkg/apiclient/auth_jwt.go b/pkg/apiclient/auth_jwt.go index 71b0e273105..2ead10cf6da 100644 --- a/pkg/apiclient/auth_jwt.go +++ b/pkg/apiclient/auth_jwt.go @@ -130,20 +130,24 @@ func (t *JWTTransport) refreshJwtToken() error { return nil } -// RoundTrip implements the RoundTripper interface. -func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI - // we use a mutex to avoid this - // We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request) +func (t *JWTTransport) needsTokenRefresh() bool { + return t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) +} + +// prepareRequest returns a copy of the request with the necessary authentication headers. +func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error) { + // In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless + // and will cause overload on CAPI. We use a mutex to avoid this. t.refreshTokenMutex.Lock() - if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) { - if err := t.refreshJwtToken(); err != nil { - t.refreshTokenMutex.Unlock() + defer t.refreshTokenMutex.Unlock() + // We bypass the refresh if we are requesting the login endpoint, as it does not require a token, + // and it leads to do 2 requests instead of one (refresh + actual login request). + if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && t.needsTokenRefresh() { + if err := t.refreshJwtToken(); err != nil { return nil, err } } - t.refreshTokenMutex.Unlock() if t.UserAgent != "" { req.Header.Add("User-Agent", t.UserAgent) @@ -151,6 +155,16 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token)) + return req, nil +} + +// RoundTrip implements the RoundTripper interface. +func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req, err := t.prepareRequest(req) + if err != nil { + return nil, err + } + if log.GetLevel() >= log.TraceLevel { //requestToDump := cloneRequest(req) dump, _ := httputil.DumpRequest(req, true) @@ -166,7 +180,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { if err != nil { // we had an error (network error for example, or 401 because token is refused), reset the token? - t.Token = "" + t.ResetToken() return resp, fmt.Errorf("performing jwt auth: %w", err) } @@ -189,7 +203,8 @@ func (t *JWTTransport) ResetToken() { t.refreshTokenMutex.Unlock() } -// transport() returns a round tripper that retries once when the status is unauthorized, and 5 times when the infrastructure is overloaded. +// transport() returns a round tripper that retries once when the status is unauthorized, +// and 5 times when the infrastructure is overloaded. 
func (t *JWTTransport) transport() http.RoundTripper { transport := t.Transport if transport == nil { diff --git a/pkg/apiclient/client.go b/pkg/apiclient/client.go index b183a8c7909..b487f68a698 100644 --- a/pkg/apiclient/client.go +++ b/pkg/apiclient/client.go @@ -4,9 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" - "encoding/json" "fmt" - "io" "net/http" "net/url" @@ -167,44 +165,10 @@ type Response struct { //... } -type ErrorResponse struct { - models.ErrorResponse -} - -func (e *ErrorResponse) Error() string { - err := fmt.Sprintf("API error: %s", *e.Message) - if len(e.Errors) > 0 { - err += fmt.Sprintf(" (%s)", e.Errors) - } - - return err -} - func newResponse(r *http.Response) *Response { return &Response{Response: r} } -func CheckResponse(r *http.Response) error { - if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 { - return nil - } - - errorResponse := &ErrorResponse{} - - data, err := io.ReadAll(r.Body) - if err == nil && len(data)>0 { - err := json.Unmarshal(data, errorResponse) - if err != nil { - return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err) - } - } else { - errorResponse.Message = new(string) - *errorResponse.Message = fmt.Sprintf("http code %d, no error message", r.StatusCode) - } - - return errorResponse -} - type ListOpts struct { //Page int //PerPage int diff --git a/pkg/apiclient/resperr.go b/pkg/apiclient/resperr.go new file mode 100644 index 00000000000..ff954a73609 --- /dev/null +++ b/pkg/apiclient/resperr.go @@ -0,0 +1,46 @@ +package apiclient + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +type ErrorResponse struct { + models.ErrorResponse +} + +func (e *ErrorResponse) Error() string { + err := fmt.Sprintf("API error: %s", *e.Message) + if len(e.Errors) > 0 { + err += fmt.Sprintf(" (%s)", e.Errors) + } + + return err +} + +// CheckResponse verifies the API response and builds an appropriate Go error if necessary. +func CheckResponse(r *http.Response) error { + if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 { + return nil + } + + ret := &ErrorResponse{} + + data, err := io.ReadAll(r.Body) + if err != nil || len(data) == 0 { + ret.Message = ptr.Of(fmt.Sprintf("http code %d, no error message", r.StatusCode)) + return ret + } + + if err := json.Unmarshal(data, ret); err != nil { + return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err) + } + + return ret +} diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 2fdb01144a0..f57ae685e45 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -539,7 +539,6 @@ func createAlertForDecision(decision *models.Decision) *models.Alert { scenario = *decision.Scenario scope = types.ListOrigin default: - // XXX: this or nil? scenario = "" scope = "" From 0df8f54fbbd08ab857e153229a43cf9e3c3f258e Mon Sep 17 00:00:00 2001 From: Laurence Jones Date: Thu, 22 Feb 2024 11:18:29 +0000 Subject: [PATCH 18/20] Add unix socket option to http plugin, we have to use this in conjunction with URL parameter as we dont know which path the user wants so if they would like to communicate over unix socket they need to use both, however, the hostname can be whatever they want. 
We could be a little smarter and actually parse the url, however, increasing code when a user can just define it correctly make no sense (#2764) --- cmd/notification-http/main.go | 42 +++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/cmd/notification-http/main.go b/cmd/notification-http/main.go index 340d462c175..382f30fea53 100644 --- a/cmd/notification-http/main.go +++ b/cmd/notification-http/main.go @@ -7,8 +7,10 @@ import ( "crypto/x509" "fmt" "io" + "net" "net/http" "os" + "strings" "github.com/crowdsecurity/crowdsec/pkg/protobufs" "github.com/hashicorp/go-hclog" @@ -19,6 +21,7 @@ import ( type PluginConfig struct { Name string `yaml:"name"` URL string `yaml:"url"` + UnixSocket string `yaml:"unix_socket"` Headers map[string]string `yaml:"headers"` SkipTLSVerification bool `yaml:"skip_tls_verification"` Method string `yaml:"method"` @@ -66,36 +69,40 @@ func getCertPool(caPath string) (*x509.CertPool, error) { return cp, nil } -func getTLSClient(tlsVerify bool, caPath, certPath, keyPath string) (*http.Client, error) { - var client *http.Client - - caCertPool, err := getCertPool(caPath) +func getTLSClient(c *PluginConfig) error { + caCertPool, err := getCertPool(c.CAPath) if err != nil { - return nil, err + return err } tlsConfig := &tls.Config{ RootCAs: caCertPool, - InsecureSkipVerify: tlsVerify, + InsecureSkipVerify: c.SkipTLSVerification, } - if certPath != "" && keyPath != "" { - logger.Info(fmt.Sprintf("Using client certificate '%s' and key '%s'", certPath, keyPath)) + if c.CertPath != "" && c.KeyPath != "" { + logger.Info(fmt.Sprintf("Using client certificate '%s' and key '%s'", c.CertPath, c.KeyPath)) - cert, err := tls.LoadX509KeyPair(certPath, keyPath) + cert, err := tls.LoadX509KeyPair(c.CertPath, c.KeyPath) if err != nil { - return nil, fmt.Errorf("unable to load client certificate '%s' and key '%s': %w", certPath, keyPath, err) + return fmt.Errorf("unable to load client certificate '%s' and key '%s': %w", c.CertPath, c.KeyPath, err) } tlsConfig.Certificates = []tls.Certificate{cert} } - - client = &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: tlsConfig, - }, + transport := &http.Transport{ + TLSClientConfig: tlsConfig, + } + if c.UnixSocket != "" { + logger.Info(fmt.Sprintf("Using socket '%s'", c.UnixSocket)) + transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", strings.TrimSuffix(c.UnixSocket, "/")) + } + } + c.Client = &http.Client{ + Transport: transport, } - return client, err + return nil } func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) { @@ -135,6 +142,7 @@ func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notific if resp.StatusCode < 200 || resp.StatusCode >= 300 { logger.Warn(fmt.Sprintf("HTTP server returned non 200 status code: %d", resp.StatusCode)) + logger.Debug(fmt.Sprintf("HTTP server returned body: %s", string(respData))) return &protobufs.Empty{}, nil } @@ -147,7 +155,7 @@ func (s *HTTPPlugin) Configure(ctx context.Context, config *protobufs.Config) (* if err != nil { return nil, err } - d.Client, err = getTLSClient(d.SkipTLSVerification, d.CAPath, d.CertPath, d.KeyPath) + err = getTLSClient(&d) if err != nil { return nil, err } From e34af358d7b96df49634f28696e9c1b1f01e097c Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Fri, 23 Feb 2024 10:37:04 +0100 Subject: [PATCH 19/20] refact cscli (globals) (#2854) * cscli 
capi: avoid globals, extract methods * cscli config restore: avoid global * cscli hubtest: avoid global * lint (whitespace, wrapped errors) --- cmd/crowdsec-cli/bouncers.go | 24 +-- cmd/crowdsec-cli/capi.go | 245 ++++++++++++++++------------- cmd/crowdsec-cli/config_restore.go | 54 ++++--- cmd/crowdsec-cli/hubtest.go | 170 +++++++++++--------- cmd/crowdsec-cli/main.go | 4 +- 5 files changed, 281 insertions(+), 216 deletions(-) diff --git a/cmd/crowdsec-cli/bouncers.go b/cmd/crowdsec-cli/bouncers.go index 717e9aef5fe..35f4320c520 100644 --- a/cmd/crowdsec-cli/bouncers.go +++ b/cmd/crowdsec-cli/bouncers.go @@ -3,6 +3,7 @@ package main import ( "encoding/csv" "encoding/json" + "errors" "fmt" "os" "slices" @@ -58,13 +59,16 @@ Note: This command requires database direct access, so is intended to be run on DisableAutoGenTag: true, PersistentPreRunE: func(_ *cobra.Command, _ []string) error { var err error - if err = require.LAPI(cli.cfg()); err != nil { + + cfg := cli.cfg() + + if err = require.LAPI(cfg); err != nil { return err } - cli.db, err = database.NewClient(cli.cfg().DbConfig) + cli.db, err = database.NewClient(cfg.DbConfig) if err != nil { - return fmt.Errorf("can't connect to the database: %s", err) + return fmt.Errorf("can't connect to the database: %w", err) } return nil @@ -84,7 +88,7 @@ func (cli *cliBouncers) list() error { bouncers, err := cli.db.ListBouncers() if err != nil { - return fmt.Errorf("unable to list bouncers: %s", err) + return fmt.Errorf("unable to list bouncers: %w", err) } switch cli.cfg().Cscli.Output { @@ -146,13 +150,13 @@ func (cli *cliBouncers) add(bouncerName string, key string) error { if key == "" { key, err = middlewares.GenerateAPIKey(keyLength) if err != nil { - return fmt.Errorf("unable to generate api key: %s", err) + return fmt.Errorf("unable to generate api key: %w", err) } } _, err = cli.db.CreateBouncer(bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType) if err != nil { - return fmt.Errorf("unable to create bouncer: %s", err) + return fmt.Errorf("unable to create bouncer: %w", err) } switch cli.cfg().Cscli.Output { @@ -165,7 +169,7 @@ func (cli *cliBouncers) add(bouncerName string, key string) error { case "json": j, err := json.Marshal(key) if err != nil { - return fmt.Errorf("unable to marshal api key") + return errors.New("unable to marshal api key") } fmt.Print(string(j)) @@ -191,7 +195,7 @@ cscli bouncers add MyBouncerName --key `, flags := cmd.Flags() flags.StringP("length", "l", "", "length of the api key") - flags.MarkDeprecated("length", "use --key instead") + _ = flags.MarkDeprecated("length", "use --key instead") flags.StringVarP(&key, "key", "k", "", "api key for the bouncer") return cmd @@ -218,7 +222,7 @@ func (cli *cliBouncers) delete(bouncers []string) error { for _, bouncerID := range bouncers { err := cli.db.DeleteBouncer(bouncerID) if err != nil { - return fmt.Errorf("unable to delete bouncer '%s': %s", bouncerID, err) + return fmt.Errorf("unable to delete bouncer '%s': %w", bouncerID, err) } log.Infof("bouncer '%s' deleted successfully", bouncerID) @@ -280,7 +284,7 @@ func (cli *cliBouncers) prune(duration time.Duration, force bool) error { deleted, err := cli.db.BulkDeleteBouncers(bouncers) if err != nil { - return fmt.Errorf("unable to prune bouncers: %s", err) + return fmt.Errorf("unable to prune bouncers: %w", err) } fmt.Fprintf(os.Stderr, "Successfully deleted %d bouncers\n", deleted) diff --git a/cmd/crowdsec-cli/capi.go b/cmd/crowdsec-cli/capi.go index 358d91ee215..e56a8a74707 100644 --- 
a/cmd/crowdsec-cli/capi.go +++ b/cmd/crowdsec-cli/capi.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "fmt" "net/url" "os" @@ -26,24 +27,29 @@ const ( CAPIURLPrefix = "v3" ) -type cliCapi struct{} +type cliCapi struct { + cfg configGetter +} -func NewCLICapi() *cliCapi { - return &cliCapi{} +func NewCLICapi(cfg configGetter) *cliCapi { + return &cliCapi{ + cfg: cfg, + } } -func (cli cliCapi) NewCommand() *cobra.Command { - var cmd = &cobra.Command{ +func (cli *cliCapi) NewCommand() *cobra.Command { + cmd := &cobra.Command{ Use: "capi [action]", Short: "Manage interaction with Central API (CAPI)", Args: cobra.MinimumNArgs(1), DisableAutoGenTag: true, PersistentPreRunE: func(_ *cobra.Command, _ []string) error { - if err := require.LAPI(csConfig); err != nil { + cfg := cli.cfg() + if err := require.LAPI(cfg); err != nil { return err } - if err := require.CAPI(csConfig); err != nil { + if err := require.CAPI(cfg); err != nil { return err } @@ -51,78 +57,92 @@ func (cli cliCapi) NewCommand() *cobra.Command { }, } - cmd.AddCommand(cli.NewRegisterCmd()) - cmd.AddCommand(cli.NewStatusCmd()) + cmd.AddCommand(cli.newRegisterCmd()) + cmd.AddCommand(cli.newStatusCmd()) return cmd } -func (cli cliCapi) NewRegisterCmd() *cobra.Command { +func (cli *cliCapi) register(capiUserPrefix string, outputFile string) error { + cfg := cli.cfg() + + capiUser, err := generateID(capiUserPrefix) + if err != nil { + return fmt.Errorf("unable to generate machine id: %w", err) + } + + password := strfmt.Password(generatePassword(passwordLength)) + + apiurl, err := url.Parse(types.CAPIBaseURL) + if err != nil { + return fmt.Errorf("unable to parse api url %s: %w", types.CAPIBaseURL, err) + } + + _, err = apiclient.RegisterClient(&apiclient.Config{ + MachineID: capiUser, + Password: password, + UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), + URL: apiurl, + VersionPrefix: CAPIURLPrefix, + }, nil) + + if err != nil { + return fmt.Errorf("api client register ('%s'): %w", types.CAPIBaseURL, err) + } + + log.Infof("Successfully registered to Central API (CAPI)") + + var dumpFile string + + switch { + case outputFile != "": + dumpFile = outputFile + case cfg.API.Server.OnlineClient.CredentialsFilePath != "": + dumpFile = cfg.API.Server.OnlineClient.CredentialsFilePath + default: + dumpFile = "" + } + + apiCfg := csconfig.ApiCredentialsCfg{ + Login: capiUser, + Password: password.String(), + URL: types.CAPIBaseURL, + } + + apiConfigDump, err := yaml.Marshal(apiCfg) + if err != nil { + return fmt.Errorf("unable to marshal api credentials: %w", err) + } + + if dumpFile != "" { + err = os.WriteFile(dumpFile, apiConfigDump, 0o600) + if err != nil { + return fmt.Errorf("write api credentials in '%s' failed: %w", dumpFile, err) + } + + log.Infof("Central API credentials written to '%s'", dumpFile) + } else { + fmt.Println(string(apiConfigDump)) + } + + log.Warning(ReloadMessage()) + + return nil +} + +func (cli *cliCapi) newRegisterCmd() *cobra.Command { var ( capiUserPrefix string - outputFile string + outputFile string ) - var cmd = &cobra.Command{ + cmd := &cobra.Command{ Use: "register", Short: "Register to Central API (CAPI)", Args: cobra.MinimumNArgs(0), DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { - var err error - capiUser, err := generateID(capiUserPrefix) - if err != nil { - return fmt.Errorf("unable to generate machine id: %s", err) - } - password := strfmt.Password(generatePassword(passwordLength)) - apiurl, err := url.Parse(types.CAPIBaseURL) - if err != nil { - 
return fmt.Errorf("unable to parse api url %s: %w", types.CAPIBaseURL, err) - } - _, err = apiclient.RegisterClient(&apiclient.Config{ - MachineID: capiUser, - Password: password, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiurl, - VersionPrefix: CAPIURLPrefix, - }, nil) - - if err != nil { - return fmt.Errorf("api client register ('%s'): %w", types.CAPIBaseURL, err) - } - log.Printf("Successfully registered to Central API (CAPI)") - - var dumpFile string - - if outputFile != "" { - dumpFile = outputFile - } else if csConfig.API.Server.OnlineClient.CredentialsFilePath != "" { - dumpFile = csConfig.API.Server.OnlineClient.CredentialsFilePath - } else { - dumpFile = "" - } - apiCfg := csconfig.ApiCredentialsCfg{ - Login: capiUser, - Password: password.String(), - URL: types.CAPIBaseURL, - } - apiConfigDump, err := yaml.Marshal(apiCfg) - if err != nil { - return fmt.Errorf("unable to marshal api credentials: %w", err) - } - if dumpFile != "" { - err = os.WriteFile(dumpFile, apiConfigDump, 0o600) - if err != nil { - return fmt.Errorf("write api credentials in '%s' failed: %w", dumpFile, err) - } - log.Printf("Central API credentials written to '%s'", dumpFile) - } else { - fmt.Println(string(apiConfigDump)) - } - - log.Warning(ReloadMessage()) - - return nil + return cli.register(capiUserPrefix, outputFile) }, } @@ -136,59 +156,66 @@ func (cli cliCapi) NewRegisterCmd() *cobra.Command { return cmd } -func (cli cliCapi) NewStatusCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "status", - Short: "Check status with the Central API (CAPI)", - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - if err := require.CAPIRegistered(csConfig); err != nil { - return err - } +func (cli *cliCapi) status() error { + cfg := cli.cfg() - password := strfmt.Password(csConfig.API.Server.OnlineClient.Credentials.Password) + if err := require.CAPIRegistered(cfg); err != nil { + return err + } - apiurl, err := url.Parse(csConfig.API.Server.OnlineClient.Credentials.URL) - if err != nil { - return fmt.Errorf("parsing api url ('%s'): %w", csConfig.API.Server.OnlineClient.Credentials.URL, err) - } + password := strfmt.Password(cfg.API.Server.OnlineClient.Credentials.Password) - hub, err := require.Hub(csConfig, nil, nil) - if err != nil { - return err - } + apiurl, err := url.Parse(cfg.API.Server.OnlineClient.Credentials.URL) + if err != nil { + return fmt.Errorf("parsing api url ('%s'): %w", cfg.API.Server.OnlineClient.Credentials.URL, err) + } - scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) - if err != nil { - return fmt.Errorf("failed to get scenarios: %w", err) - } + hub, err := require.Hub(cfg, nil, nil) + if err != nil { + return err + } - if len(scenarios) == 0 { - return fmt.Errorf("no scenarios installed, abort") - } + scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) + if err != nil { + return fmt.Errorf("failed to get scenarios: %w", err) + } - Client, err = apiclient.NewDefaultClient(apiurl, CAPIURLPrefix, fmt.Sprintf("crowdsec/%s", version.String()), nil) - if err != nil { - return fmt.Errorf("init default client: %w", err) - } + if len(scenarios) == 0 { + return errors.New("no scenarios installed, abort") + } - t := models.WatcherAuthRequest{ - MachineID: &csConfig.API.Server.OnlineClient.Credentials.Login, - Password: &password, - Scenarios: scenarios, - } + Client, err = apiclient.NewDefaultClient(apiurl, CAPIURLPrefix, fmt.Sprintf("crowdsec/%s", version.String()), nil) + if err != nil { + 
return fmt.Errorf("init default client: %w", err) + } - log.Infof("Loaded credentials from %s", csConfig.API.Server.OnlineClient.CredentialsFilePath) - log.Infof("Trying to authenticate with username %s on %s", csConfig.API.Server.OnlineClient.Credentials.Login, apiurl) + t := models.WatcherAuthRequest{ + MachineID: &cfg.API.Server.OnlineClient.Credentials.Login, + Password: &password, + Scenarios: scenarios, + } - _, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) - if err != nil { - return fmt.Errorf("failed to authenticate to Central API (CAPI): %w", err) - } - log.Infof("You can successfully interact with Central API (CAPI)") + log.Infof("Loaded credentials from %s", cfg.API.Server.OnlineClient.CredentialsFilePath) + log.Infof("Trying to authenticate with username %s on %s", cfg.API.Server.OnlineClient.Credentials.Login, apiurl) - return nil + _, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) + if err != nil { + return fmt.Errorf("failed to authenticate to Central API (CAPI): %w", err) + } + + log.Info("You can successfully interact with Central API (CAPI)") + + return nil +} + +func (cli *cliCapi) newStatusCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "status", + Short: "Check status with the Central API (CAPI)", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.status() }, } diff --git a/cmd/crowdsec-cli/config_restore.go b/cmd/crowdsec-cli/config_restore.go index 513f993ba80..ee7179b73c5 100644 --- a/cmd/crowdsec-cli/config_restore.go +++ b/cmd/crowdsec-cli/config_restore.go @@ -14,7 +14,9 @@ import ( ) func (cli *cliConfig) restoreHub(dirPath string) error { - hub, err := require.Hub(csConfig, require.RemoteHub(csConfig), nil) + cfg := cli.cfg() + + hub, err := require.Hub(cfg, require.RemoteHub(cfg), nil) if err != nil { return err } @@ -71,7 +73,7 @@ func (cli *cliConfig) restoreHub(dirPath string) error { } stage := file.Name() - stagedir := fmt.Sprintf("%s/%s/%s/", csConfig.ConfigPaths.ConfigDir, itype, stage) + stagedir := fmt.Sprintf("%s/%s/%s/", cfg.ConfigPaths.ConfigDir, itype, stage) log.Debugf("Found stage %s in %s, target directory : %s", stage, itype, stagedir) if err = os.MkdirAll(stagedir, os.ModePerm); err != nil { @@ -99,7 +101,7 @@ func (cli *cliConfig) restoreHub(dirPath string) error { } else { log.Infof("Going to restore local/tainted [%s]", file.Name()) sourceFile := fmt.Sprintf("%s/%s", itemDirectory, file.Name()) - destinationFile := fmt.Sprintf("%s/%s/%s", csConfig.ConfigPaths.ConfigDir, itype, file.Name()) + destinationFile := fmt.Sprintf("%s/%s/%s", cfg.ConfigPaths.ConfigDir, itype, file.Name()) if err = CopyFile(sourceFile, destinationFile); err != nil { return fmt.Errorf("failed copy %s %s to %s: %w", itype, sourceFile, destinationFile, err) @@ -127,17 +129,19 @@ func (cli *cliConfig) restoreHub(dirPath string) error { func (cli *cliConfig) restore(dirPath string) error { var err error + cfg := cli.cfg() + backupMain := fmt.Sprintf("%s/config.yaml", dirPath) if _, err = os.Stat(backupMain); err == nil { - if csConfig.ConfigPaths != nil && csConfig.ConfigPaths.ConfigDir != "" { - if err = CopyFile(backupMain, fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir)); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", backupMain, csConfig.ConfigPaths.ConfigDir, err) + if cfg.ConfigPaths != nil && cfg.ConfigPaths.ConfigDir != "" { + if err = CopyFile(backupMain, fmt.Sprintf("%s/config.yaml", cfg.ConfigPaths.ConfigDir)); err != 
nil { + return fmt.Errorf("failed copy %s to %s: %w", backupMain, cfg.ConfigPaths.ConfigDir, err) } } } // Now we have config.yaml, we should regenerate config struct to have rights paths etc - ConfigFilePath = fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir) + ConfigFilePath = fmt.Sprintf("%s/config.yaml", cfg.ConfigPaths.ConfigDir) log.Debug("Reloading configuration") @@ -146,38 +150,40 @@ func (cli *cliConfig) restore(dirPath string) error { return fmt.Errorf("failed to reload configuration: %w", err) } + cfg = cli.cfg() + backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath) if _, err = os.Stat(backupCAPICreds); err == nil { - if err = CopyFile(backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath, err) + if err = CopyFile(backupCAPICreds, cfg.API.Server.OnlineClient.CredentialsFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupCAPICreds, cfg.API.Server.OnlineClient.CredentialsFilePath, err) } } backupLAPICreds := fmt.Sprintf("%s/local_api_credentials.yaml", dirPath) if _, err = os.Stat(backupLAPICreds); err == nil { - if err = CopyFile(backupLAPICreds, csConfig.API.Client.CredentialsFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", backupLAPICreds, csConfig.API.Client.CredentialsFilePath, err) + if err = CopyFile(backupLAPICreds, cfg.API.Client.CredentialsFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupLAPICreds, cfg.API.Client.CredentialsFilePath, err) } } backupProfiles := fmt.Sprintf("%s/profiles.yaml", dirPath) if _, err = os.Stat(backupProfiles); err == nil { - if err = CopyFile(backupProfiles, csConfig.API.Server.ProfilesPath); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", backupProfiles, csConfig.API.Server.ProfilesPath, err) + if err = CopyFile(backupProfiles, cfg.API.Server.ProfilesPath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupProfiles, cfg.API.Server.ProfilesPath, err) } } backupSimulation := fmt.Sprintf("%s/simulation.yaml", dirPath) if _, err = os.Stat(backupSimulation); err == nil { - if err = CopyFile(backupSimulation, csConfig.ConfigPaths.SimulationFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", backupSimulation, csConfig.ConfigPaths.SimulationFilePath, err) + if err = CopyFile(backupSimulation, cfg.ConfigPaths.SimulationFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupSimulation, cfg.ConfigPaths.SimulationFilePath, err) } } /*if there is a acquisition dir, restore its content*/ - if csConfig.Crowdsec.AcquisitionDirPath != "" { - if err = os.MkdirAll(csConfig.Crowdsec.AcquisitionDirPath, 0o700); err != nil { - return fmt.Errorf("error while creating %s: %w", csConfig.Crowdsec.AcquisitionDirPath, err) + if cfg.Crowdsec.AcquisitionDirPath != "" { + if err = os.MkdirAll(cfg.Crowdsec.AcquisitionDirPath, 0o700); err != nil { + return fmt.Errorf("error while creating %s: %w", cfg.Crowdsec.AcquisitionDirPath, err) } } @@ -186,8 +192,8 @@ func (cli *cliConfig) restore(dirPath string) error { if _, err = os.Stat(backupAcquisition); err == nil { log.Debugf("restoring backup'ed %s", backupAcquisition) - if err = CopyFile(backupAcquisition, csConfig.Crowdsec.AcquisitionFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", backupAcquisition, csConfig.Crowdsec.AcquisitionFilePath, err) + if err = 
CopyFile(backupAcquisition, cfg.Crowdsec.AcquisitionFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupAcquisition, cfg.Crowdsec.AcquisitionFilePath, err) } } @@ -195,7 +201,7 @@ func (cli *cliConfig) restore(dirPath string) error { acquisBackupDir := filepath.Join(dirPath, "acquis", "*.yaml") if acquisFiles, err := filepath.Glob(acquisBackupDir); err == nil { for _, acquisFile := range acquisFiles { - targetFname, err := filepath.Abs(csConfig.Crowdsec.AcquisitionDirPath + "/" + filepath.Base(acquisFile)) + targetFname, err := filepath.Abs(cfg.Crowdsec.AcquisitionDirPath + "/" + filepath.Base(acquisFile)) if err != nil { return fmt.Errorf("while saving %s to %s: %w", acquisFile, targetFname, err) } @@ -208,12 +214,12 @@ func (cli *cliConfig) restore(dirPath string) error { } } - if csConfig.Crowdsec != nil && len(csConfig.Crowdsec.AcquisitionFiles) > 0 { - for _, acquisFile := range csConfig.Crowdsec.AcquisitionFiles { + if cfg.Crowdsec != nil && len(cfg.Crowdsec.AcquisitionFiles) > 0 { + for _, acquisFile := range cfg.Crowdsec.AcquisitionFiles { log.Infof("backup filepath from dir -> %s", acquisFile) // if it was the default one, it has already been backed up - if csConfig.Crowdsec.AcquisitionFilePath == acquisFile { + if cfg.Crowdsec.AcquisitionFilePath == acquisFile { log.Infof("skip this one") continue } diff --git a/cmd/crowdsec-cli/hubtest.go b/cmd/crowdsec-cli/hubtest.go index 1860540e7dc..8f5ab087370 100644 --- a/cmd/crowdsec-cli/hubtest.go +++ b/cmd/crowdsec-cli/hubtest.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "errors" "fmt" "math" "os" @@ -20,21 +21,29 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/hubtest" ) -var HubTest hubtest.HubTest -var HubAppsecTests hubtest.HubTest -var hubPtr *hubtest.HubTest -var isAppsecTest bool +var ( + HubTest hubtest.HubTest + HubAppsecTests hubtest.HubTest + hubPtr *hubtest.HubTest + isAppsecTest bool +) -type cliHubTest struct{} +type cliHubTest struct { + cfg configGetter +} -func NewCLIHubTest() *cliHubTest { - return &cliHubTest{} +func NewCLIHubTest(cfg configGetter) *cliHubTest { + return &cliHubTest{ + cfg: cfg, + } } -func (cli cliHubTest) NewCommand() *cobra.Command { - var hubPath string - var crowdsecPath string - var cscliPath string +func (cli *cliHubTest) NewCommand() *cobra.Command { + var ( + hubPath string + crowdsecPath string + cscliPath string + ) cmd := &cobra.Command{ Use: "hubtest", @@ -53,11 +62,13 @@ func (cli cliHubTest) NewCommand() *cobra.Command { if err != nil { return fmt.Errorf("unable to load appsec specific hubtest: %+v", err) } - /*commands will use the hubPtr, will point to the default hubTest object, or the one dedicated to appsec tests*/ + + // commands will use the hubPtr, will point to the default hubTest object, or the one dedicated to appsec tests hubPtr = &HubTest if isAppsecTest { hubPtr = &HubAppsecTests } + return nil }, } @@ -79,13 +90,16 @@ func (cli cliHubTest) NewCommand() *cobra.Command { return cmd } -func (cli cliHubTest) NewCreateCmd() *cobra.Command { +func (cli *cliHubTest) NewCreateCmd() *cobra.Command { + var ( + ignoreParsers bool + labels map[string]string + logType string + ) + parsers := []string{} postoverflows := []string{} scenarios := []string{} - var ignoreParsers bool - var labels map[string]string - var logType string cmd := &cobra.Command{ Use: "create", @@ -107,7 +121,7 @@ cscli hubtest create my-scenario-test --parsers crowdsecurity/nginx --scenarios } if logType == "" { - return fmt.Errorf("please provide a type (--type) for 
the test") + return errors.New("please provide a type (--type) for the test") } if err := os.MkdirAll(testPath, os.ModePerm); err != nil { @@ -118,7 +132,7 @@ cscli hubtest create my-scenario-test --parsers crowdsecurity/nginx --scenarios configFileData := &hubtest.HubTestItemConfig{} if logType == "appsec" { - //create empty nuclei template file + // create empty nuclei template file nucleiFileName := fmt.Sprintf("%s.yaml", testName) nucleiFilePath := filepath.Join(testPath, nucleiFileName) nucleiFile, err := os.OpenFile(nucleiFilePath, os.O_RDWR|os.O_CREATE, 0755) @@ -128,7 +142,7 @@ cscli hubtest create my-scenario-test --parsers crowdsecurity/nginx --scenarios ntpl := template.Must(template.New("nuclei").Parse(hubtest.TemplateNucleiFile)) if ntpl == nil { - return fmt.Errorf("unable to parse nuclei template") + return errors.New("unable to parse nuclei template") } ntpl.ExecuteTemplate(nucleiFile, "nuclei", struct{ TestName string }{TestName: testName}) nucleiFile.Close() @@ -188,24 +202,24 @@ cscli hubtest create my-scenario-test --parsers crowdsecurity/nginx --scenarios fmt.Printf(" Parser assertion file : %s (please fill it with assertion)\n", parserAssertFilePath) fmt.Printf(" Scenario assertion file : %s (please fill it with assertion)\n", scenarioAssertFilePath) fmt.Printf(" Configuration File : %s (please fill it with parsers, scenarios...)\n", configFilePath) - } fd, err := os.Create(configFilePath) if err != nil { - return fmt.Errorf("open: %s", err) + return fmt.Errorf("open: %w", err) } data, err := yaml.Marshal(configFileData) if err != nil { - return fmt.Errorf("marshal: %s", err) + return fmt.Errorf("marshal: %w", err) } _, err = fd.Write(data) if err != nil { - return fmt.Errorf("write: %s", err) + return fmt.Errorf("write: %w", err) } if err := fd.Close(); err != nil { - return fmt.Errorf("close: %s", err) + return fmt.Errorf("close: %w", err) } + return nil }, } @@ -219,20 +233,25 @@ cscli hubtest create my-scenario-test --parsers crowdsecurity/nginx --scenarios return cmd } -func (cli cliHubTest) NewRunCmd() *cobra.Command { - var noClean bool - var runAll bool - var forceClean bool - var NucleiTargetHost string - var AppSecHost string - var cmd = &cobra.Command{ +func (cli *cliHubTest) NewRunCmd() *cobra.Command { + var ( + noClean bool + runAll bool + forceClean bool + NucleiTargetHost string + AppSecHost string + ) + + cmd := &cobra.Command{ Use: "run", Short: "run [test_name]", DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, args []string) error { + cfg := cli.cfg() + if !runAll && len(args) == 0 { printHelp(cmd) - return fmt.Errorf("please provide test to run or --all flag") + return errors.New("please provide test to run or --all flag") } hubPtr.NucleiTargetHost = NucleiTargetHost hubPtr.AppSecHost = AppSecHost @@ -244,7 +263,7 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { for _, testName := range args { _, err := hubPtr.LoadTestItem(testName) if err != nil { - return fmt.Errorf("unable to load test '%s': %s", testName, err) + return fmt.Errorf("unable to load test '%s': %w", testName, err) } } } @@ -252,7 +271,7 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { // set timezone to avoid DST issues os.Setenv("TZ", "UTC") for _, test := range hubPtr.Tests { - if csConfig.Cscli.Output == "human" { + if cfg.Cscli.Output == "human" { log.Infof("Running test '%s'", test.Name) } err := test.Run() @@ -264,6 +283,8 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { return nil }, PersistentPostRunE: func(_ *cobra.Command, _ []string) error { + cfg 
:= cli.cfg() + success := true testResult := make(map[string]bool) for _, test := range hubPtr.Tests { @@ -280,7 +301,7 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { } if !noClean { if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) } } fmt.Printf("\nPlease fill your assert file(s) for test '%s', exiting\n", test.Name) @@ -288,18 +309,18 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { } testResult[test.Name] = test.Success if test.Success { - if csConfig.Cscli.Output == "human" { + if cfg.Cscli.Output == "human" { log.Infof("Test '%s' passed successfully (%d assertions)\n", test.Name, test.ParserAssert.NbAssert+test.ScenarioAssert.NbAssert) } if !noClean { if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) } } } else { success = false cleanTestEnv := false - if csConfig.Cscli.Output == "human" { + if cfg.Cscli.Output == "human" { if len(test.ParserAssert.Fails) > 0 { fmt.Println() log.Errorf("Parser test '%s' failed (%d errors)\n", test.Name, len(test.ParserAssert.Fails)) @@ -330,20 +351,20 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { Default: true, } if err := survey.AskOne(prompt, &cleanTestEnv); err != nil { - return fmt.Errorf("unable to ask to remove runtime folder: %s", err) + return fmt.Errorf("unable to ask to remove runtime folder: %w", err) } } } if cleanTestEnv || forceClean { if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) } } } } - switch csConfig.Cscli.Output { + switch cfg.Cscli.Output { case "human": hubTestResultTable(color.Output, testResult) case "json": @@ -359,11 +380,11 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { } jsonStr, err := json.Marshal(jsonResult) if err != nil { - return fmt.Errorf("unable to json test result: %s", err) + return fmt.Errorf("unable to json test result: %w", err) } fmt.Println(string(jsonStr)) default: - return fmt.Errorf("only human/json output modes are supported") + return errors.New("only human/json output modes are supported") } if !success { @@ -383,7 +404,7 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { return cmd } -func (cli cliHubTest) NewCleanCmd() *cobra.Command { +func (cli *cliHubTest) NewCleanCmd() *cobra.Command { var cmd = &cobra.Command{ Use: "clean", Short: "clean [test_name]", @@ -393,10 +414,10 @@ func (cli cliHubTest) NewCleanCmd() *cobra.Command { for _, testName := range args { test, err := hubPtr.LoadTestItem(testName) if err != nil { - return fmt.Errorf("unable to load test '%s': %s", testName, err) + return fmt.Errorf("unable to load test '%s': %w", testName, err) } if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) } } @@ -407,7 +428,7 @@ func (cli cliHubTest) NewCleanCmd() *cobra.Command { return cmd } -func (cli cliHubTest) NewInfoCmd() *cobra.Command { +func (cli *cliHubTest) NewInfoCmd() *cobra.Command { cmd := &cobra.Command{ Use: "info", Short: "info [test_name]", @@ -417,7 +438,7 @@ func (cli cliHubTest) NewInfoCmd() *cobra.Command { for _, testName := range args { test, err := hubPtr.LoadTestItem(testName) if err != nil { - return fmt.Errorf("unable 
to load test '%s': %s", testName, err) + return fmt.Errorf("unable to load test '%s': %w", testName, err) } fmt.Println() fmt.Printf(" Test name : %s\n", test.Name) @@ -440,17 +461,19 @@ func (cli cliHubTest) NewInfoCmd() *cobra.Command { return cmd } -func (cli cliHubTest) NewListCmd() *cobra.Command { +func (cli *cliHubTest) NewListCmd() *cobra.Command { cmd := &cobra.Command{ Use: "list", Short: "list", DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := hubPtr.LoadAllTests(); err != nil { - return fmt.Errorf("unable to load all tests: %s", err) + return fmt.Errorf("unable to load all tests: %w", err) } - switch csConfig.Cscli.Output { + switch cfg.Cscli.Output { case "human": hubTestListTable(color.Output, hubPtr.Tests) case "json": @@ -460,7 +483,7 @@ func (cli cliHubTest) NewListCmd() *cobra.Command { } fmt.Println(string(j)) default: - return fmt.Errorf("only human/json output modes are supported") + return errors.New("only human/json output modes are supported") } return nil @@ -470,18 +493,22 @@ func (cli cliHubTest) NewListCmd() *cobra.Command { return cmd } -func (cli cliHubTest) NewCoverageCmd() *cobra.Command { - var showParserCov bool - var showScenarioCov bool - var showOnlyPercent bool - var showAppsecCov bool +func (cli *cliHubTest) NewCoverageCmd() *cobra.Command { + var ( + showParserCov bool + showScenarioCov bool + showOnlyPercent bool + showAppsecCov bool + ) cmd := &cobra.Command{ Use: "coverage", Short: "coverage", DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { - //for this one we explicitly don't do for appsec + cfg := cli.cfg() + + // for this one we explicitly don't do for appsec if err := HubTest.LoadAllTests(); err != nil { return fmt.Errorf("unable to load all tests: %+v", err) } @@ -499,7 +526,7 @@ func (cli cliHubTest) NewCoverageCmd() *cobra.Command { if showParserCov || showAll { parserCoverage, err = HubTest.GetParsersCoverage() if err != nil { - return fmt.Errorf("while getting parser coverage: %s", err) + return fmt.Errorf("while getting parser coverage: %w", err) } parserTested := 0 for _, test := range parserCoverage { @@ -513,7 +540,7 @@ func (cli cliHubTest) NewCoverageCmd() *cobra.Command { if showScenarioCov || showAll { scenarioCoverage, err = HubTest.GetScenariosCoverage() if err != nil { - return fmt.Errorf("while getting scenario coverage: %s", err) + return fmt.Errorf("while getting scenario coverage: %w", err) } scenarioTested := 0 @@ -529,7 +556,7 @@ func (cli cliHubTest) NewCoverageCmd() *cobra.Command { if showAppsecCov || showAll { appsecRuleCoverage, err = HubTest.GetAppsecCoverage() if err != nil { - return fmt.Errorf("while getting scenario coverage: %s", err) + return fmt.Errorf("while getting scenario coverage: %w", err) } appsecRuleTested := 0 @@ -542,19 +569,20 @@ func (cli cliHubTest) NewCoverageCmd() *cobra.Command { } if showOnlyPercent { - if showAll { + switch { + case showAll: fmt.Printf("parsers=%d%%\nscenarios=%d%%\nappsec_rules=%d%%", parserCoveragePercent, scenarioCoveragePercent, appsecRuleCoveragePercent) - } else if showParserCov { + case showParserCov: fmt.Printf("parsers=%d%%", parserCoveragePercent) - } else if showScenarioCov { + case showScenarioCov: fmt.Printf("scenarios=%d%%", scenarioCoveragePercent) - } else if showAppsecCov { + case showAppsecCov: fmt.Printf("appsec_rules=%d%%", appsecRuleCoveragePercent) } os.Exit(0) } - switch csConfig.Cscli.Output { + switch cfg.Cscli.Output { case "human": if showParserCov || showAll { 
hubTestParserCoverageTable(color.Output, parserCoverage) @@ -595,7 +623,7 @@ func (cli cliHubTest) NewCoverageCmd() *cobra.Command { } fmt.Printf("%s", dump) default: - return fmt.Errorf("only human/json output modes are supported") + return errors.New("only human/json output modes are supported") } return nil @@ -610,7 +638,7 @@ func (cli cliHubTest) NewCoverageCmd() *cobra.Command { return cmd } -func (cli cliHubTest) NewEvalCmd() *cobra.Command { +func (cli *cliHubTest) NewEvalCmd() *cobra.Command { var evalExpression string cmd := &cobra.Command{ @@ -647,7 +675,7 @@ func (cli cliHubTest) NewEvalCmd() *cobra.Command { return cmd } -func (cli cliHubTest) NewExplainCmd() *cobra.Command { +func (cli *cliHubTest) NewExplainCmd() *cobra.Command { cmd := &cobra.Command{ Use: "explain", Short: "explain [test_name]", @@ -666,7 +694,7 @@ func (cli cliHubTest) NewExplainCmd() *cobra.Command { } if err = test.ParserAssert.LoadTest(test.ParserResultFile); err != nil { - return fmt.Errorf("unable to load parser result after run: %s", err) + return fmt.Errorf("unable to load parser result after run: %w", err) } } @@ -677,7 +705,7 @@ func (cli cliHubTest) NewExplainCmd() *cobra.Command { } if err = test.ScenarioAssert.LoadTest(test.ScenarioResultFile, test.BucketPourResultFile); err != nil { - return fmt.Errorf("unable to load scenario result after run: %s", err) + return fmt.Errorf("unable to load scenario result after run: %w", err) } } opts := dumps.DumpOpts{} diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index 1f87390b636..446901e4aa9 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -240,12 +240,12 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall cmd.AddCommand(NewCLISimulation(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIBouncers(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIMachines(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLICapi().NewCommand()) + cmd.AddCommand(NewCLICapi(cli.cfg).NewCommand()) cmd.AddCommand(NewCLILapi(cli.cfg).NewCommand()) cmd.AddCommand(NewCompletionCmd()) cmd.AddCommand(NewCLIConsole(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIExplain(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIHubTest().NewCommand()) + cmd.AddCommand(NewCLIHubTest(cli.cfg).NewCommand()) cmd.AddCommand(NewCLINotifications(cli.cfg).NewCommand()) cmd.AddCommand(NewCLISupport().NewCommand()) cmd.AddCommand(NewCLIPapi(cli.cfg).NewCommand()) From 4bf640c6e86185b506fde7332a338ccf2eb711ca Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Fri, 23 Feb 2024 14:03:50 +0100 Subject: [PATCH 20/20] refact pkg/apiserver (auth helpers) (#2856) --- pkg/apiserver/controllers/v1/alerts.go | 5 +--- pkg/apiserver/controllers/v1/heartbeat.go | 5 +--- pkg/apiserver/controllers/v1/metrics.go | 34 ++++++++++------------- pkg/apiserver/controllers/v1/utils.go | 32 +++++++++++++++++---- pkg/apiserver/middlewares/v1/api_key.go | 11 ++------ pkg/apiserver/middlewares/v1/jwt.go | 8 +++--- 6 files changed, 50 insertions(+), 45 deletions(-) diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index e7d106d72a3..ad183e4ba80 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -9,7 +9,6 @@ import ( "strings" "time" - jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" "github.com/google/uuid" @@ -143,9 +142,7 @@ func normalizeScope(scope string) string { func (c *Controller) CreateAlert(gctx 
*gin.Context) { var input models.AddAlertsRequest - claims := jwt.ExtractClaims(gctx) - // TBD: use defined rather than hardcoded key to find back owner - machineID := claims["id"].(string) + machineID, _ := getMachineIDFromContext(gctx) if err := gctx.ShouldBindJSON(&input); err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) diff --git a/pkg/apiserver/controllers/v1/heartbeat.go b/pkg/apiserver/controllers/v1/heartbeat.go index b19b450f0d5..e1231eaa9ec 100644 --- a/pkg/apiserver/controllers/v1/heartbeat.go +++ b/pkg/apiserver/controllers/v1/heartbeat.go @@ -3,14 +3,11 @@ package v1 import ( "net/http" - jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" ) func (c *Controller) HeartBeat(gctx *gin.Context) { - claims := jwt.ExtractClaims(gctx) - // TBD: use defined rather than hardcoded key to find back owner - machineID := claims["id"].(string) + machineID, _ := getMachineIDFromContext(gctx) if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil { c.HandleDBErrors(gctx, err) diff --git a/pkg/apiserver/controllers/v1/metrics.go b/pkg/apiserver/controllers/v1/metrics.go index 13ccf9ac94f..ddb38512a11 100644 --- a/pkg/apiserver/controllers/v1/metrics.go +++ b/pkg/apiserver/controllers/v1/metrics.go @@ -3,7 +3,6 @@ package v1 import ( "time" - jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" "github.com/prometheus/client_golang/prometheus" ) @@ -66,32 +65,29 @@ var LapiResponseTime = prometheus.NewHistogramVec( []string{"endpoint", "method"}) func PrometheusBouncersHasEmptyDecision(c *gin.Context) { - name, ok := c.Get("BOUNCER_NAME") - if ok { + bouncer, _ := getBouncerFromContext(c) + if bouncer != nil { LapiNilDecisions.With(prometheus.Labels{ - "bouncer": name.(string)}).Inc() + "bouncer": bouncer.Name}).Inc() } } func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) { - name, ok := c.Get("BOUNCER_NAME") - if ok { + bouncer, _ := getBouncerFromContext(c) + if bouncer != nil { LapiNonNilDecisions.With(prometheus.Labels{ - "bouncer": name.(string)}).Inc() + "bouncer": bouncer.Name}).Inc() } } func PrometheusMachinesMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - claims := jwt.ExtractClaims(c) - if claims != nil { - if rawID, ok := claims["id"]; ok { - machineID := rawID.(string) - LapiMachineHits.With(prometheus.Labels{ - "machine": machineID, - "route": c.Request.URL.Path, - "method": c.Request.Method}).Inc() - } + machineID, _ := getMachineIDFromContext(c) + if machineID != "" { + LapiMachineHits.With(prometheus.Labels{ + "machine": machineID, + "route": c.Request.URL.Path, + "method": c.Request.Method}).Inc() } c.Next() @@ -100,10 +96,10 @@ func PrometheusMachinesMiddleware() gin.HandlerFunc { func PrometheusBouncersMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - name, ok := c.Get("BOUNCER_NAME") - if ok { + bouncer, _ := getBouncerFromContext(c) + if bouncer != nil { LapiBouncerHits.With(prometheus.Labels{ - "bouncer": name.(string), + "bouncer": bouncer.Name, "route": c.Request.URL.Path, "method": c.Request.Method}).Inc() } diff --git a/pkg/apiserver/controllers/v1/utils.go b/pkg/apiserver/controllers/v1/utils.go index 6afd005132a..6f14dd9204e 100644 --- a/pkg/apiserver/controllers/v1/utils.go +++ b/pkg/apiserver/controllers/v1/utils.go @@ -1,30 +1,50 @@ package v1 import ( - "fmt" + "errors" "net/http" + jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" + middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" 
"github.com/crowdsecurity/crowdsec/pkg/database/ent" ) -const bouncerContextKey = "bouncer_info" - func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) { - bouncerInterface, exist := ctx.Get(bouncerContextKey) + bouncerInterface, exist := ctx.Get(middlewares.BouncerContextKey) if !exist { - return nil, fmt.Errorf("bouncer not found") + return nil, errors.New("bouncer not found") } bouncerInfo, ok := bouncerInterface.(*ent.Bouncer) if !ok { - return nil, fmt.Errorf("bouncer not found") + return nil, errors.New("bouncer not found") } return bouncerInfo, nil } +func getMachineIDFromContext(ctx *gin.Context) (string, error) { + claims := jwt.ExtractClaims(ctx) + if claims == nil { + return "", errors.New("failed to extract claims") + } + + rawID, ok := claims[middlewares.MachineIDKey] + if !ok { + return "", errors.New("MachineID not found in claims") + } + + id, ok := rawID.(string) + if !ok { + // should never happen + return "", errors.New("failed to cast machineID to string") + } + + return id, nil +} + func (c *Controller) AbortRemoteIf(option bool) gin.HandlerFunc { return func(gctx *gin.Context) { incomingIP := gctx.ClientIP() diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index 41ee15b4417..4e273371bfe 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -18,9 +18,9 @@ import ( const ( APIKeyHeader = "X-Api-Key" - bouncerContextKey = "bouncer_info" - // max allowed by bcrypt 72 = 54 bytes in base64 + BouncerContextKey = "bouncer_info" dummyAPIKeySize = 54 + // max allowed by bcrypt 72 = 54 bytes in base64 ) type APIKey struct { @@ -159,11 +159,6 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { "name": bouncer.Name, }) - // maybe we want to store the whole bouncer object in the context instead, this would avoid another db query - // in StreamDecision - c.Set("BOUNCER_NAME", bouncer.Name) - c.Set("BOUNCER_HASHED_KEY", bouncer.APIKey) - if bouncer.IPAddress == "" { if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) @@ -203,7 +198,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { } } - c.Set(bouncerContextKey, bouncer) + c.Set(BouncerContextKey, bouncer) c.Next() } } diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index ed4ad107b96..6fe053713bc 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -22,7 +22,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var identityKey = "id" +const MachineIDKey = "id" type JWT struct { Middleware *jwt.GinJWTMiddleware @@ -33,7 +33,7 @@ type JWT struct { func PayloadFunc(data interface{}) jwt.MapClaims { if value, ok := data.(*models.WatcherAuthRequest); ok { return jwt.MapClaims{ - identityKey: &value.MachineID, + MachineIDKey: &value.MachineID, } } @@ -42,7 +42,7 @@ func PayloadFunc(data interface{}) jwt.MapClaims { func IdentityHandler(c *gin.Context) interface{} { claims := jwt.ExtractClaims(c) - machineID := claims[identityKey].(string) + machineID := claims[MachineIDKey].(string) return &models.WatcherAuthRequest{ MachineID: &machineID, @@ -307,7 +307,7 @@ func NewJWT(dbClient *database.Client) (*JWT, error) { Key: secret, Timeout: time.Hour, MaxRefresh: time.Hour, - IdentityKey: identityKey, + IdentityKey: MachineIDKey, PayloadFunc: PayloadFunc, IdentityHandler: IdentityHandler, Authenticator: jwtMiddleware.Authenticator,