Skip to content

Commit

Permalink
cscli machines: extract list(), avoid globals (dbClient, csConfig)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc committed Jan 24, 2024
1 parent 5e723f7 commit 8896b38
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 69 deletions.
2 changes: 1 addition & 1 deletion cmd/crowdsec-cli/bouncers.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (cli *cliBouncers) list() error {
enc.SetIndent("", " ")

if err := enc.Encode(bouncers); err != nil {
return fmt.Errorf("failed to unmarshal: %w", err)
return fmt.Errorf("failed to marshal: %w", err)
}

return nil
Expand Down
141 changes: 76 additions & 65 deletions cmd/crowdsec-cli/machines.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/csv"
"encoding/json"
"fmt"
"io"
"math/big"
"os"
"strings"
Expand Down Expand Up @@ -101,50 +100,15 @@ func getLastHeartbeat(m *ent.Machine) (string, bool) {
return hb, true
}

func getAgents(out io.Writer, dbClient *database.Client) error {
machines, err := dbClient.ListMachines()
if err != nil {
return fmt.Errorf("unable to list machines: %s", err)
}

switch csConfig.Cscli.Output {
case "human":
getAgentsTable(out, machines)
case "json":
enc := json.NewEncoder(out)
enc.SetIndent("", " ")
if err := enc.Encode(machines); err != nil {
return fmt.Errorf("failed to marshal")
}
return nil
case "raw":
csvwriter := csv.NewWriter(out)
err := csvwriter.Write([]string{"machine_id", "ip_address", "updated_at", "validated", "version", "auth_type", "last_heartbeat"})
if err != nil {
return fmt.Errorf("failed to write header: %s", err)
}
for _, m := range machines {
validated := "false"
if m.IsValidated {
validated = "true"
}
hb, _ := getLastHeartbeat(m)
err := csvwriter.Write([]string{m.MachineId, m.IpAddress, m.UpdatedAt.Format(time.RFC3339), validated, m.Version, m.AuthType, hb})
if err != nil {
return fmt.Errorf("failed to write raw output: %w", err)
}
}
csvwriter.Flush()
default:
return fmt.Errorf("unknown output '%s'", csConfig.Cscli.Output)
}
return nil
type cliMachines struct{
db *database.Client
cfg func() *csconfig.Config
}

type cliMachines struct{}

func NewCLIMachines() *cliMachines {
return &cliMachines{}
func NewCLIMachines(getconfig func() *csconfig.Config) *cliMachines {
return &cliMachines{
cfg: getconfig,
}
}

func (cli *cliMachines) NewCommand() *cobra.Command {
Expand All @@ -159,10 +123,10 @@ Note: This command requires database direct access, so is intended to be run on
Aliases: []string{"machine"},
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
var err error
if err = require.LAPI(csConfig); err != nil {
if err = require.LAPI(cli.cfg()); err != nil {
return err
}
dbClient, err = database.NewClient(csConfig.DbConfig)
cli.db, err = database.NewClient(cli.cfg().DbConfig)
if err != nil {
return fmt.Errorf("unable to create new database client: %s", err)
}
Expand All @@ -179,6 +143,55 @@ Note: This command requires database direct access, so is intended to be run on
return cmd
}

func (cli *cliMachines) list() error {
out := color.Output

machines, err := cli.db.ListMachines()
if err != nil {
return fmt.Errorf("unable to list machines: %s", err)
}

switch cli.cfg().Cscli.Output {
case "human":
getAgentsTable(out, machines)
case "json":
enc := json.NewEncoder(out)
enc.SetIndent("", " ")

if err := enc.Encode(machines); err != nil {
return fmt.Errorf("failed to marshal")
}

return nil
case "raw":
csvwriter := csv.NewWriter(out)

err := csvwriter.Write([]string{"machine_id", "ip_address", "updated_at", "validated", "version", "auth_type", "last_heartbeat"})
if err != nil {
return fmt.Errorf("failed to write header: %s", err)
}

for _, m := range machines {
validated := "false"
if m.IsValidated {
validated = "true"
}

hb, _ := getLastHeartbeat(m)

if err := csvwriter.Write([]string{m.MachineId, m.IpAddress, m.UpdatedAt.Format(time.RFC3339), validated, m.Version, m.AuthType, hb}); err != nil {
return fmt.Errorf("failed to write raw output: %w", err)
}
}

csvwriter.Flush()
default:
return fmt.Errorf("unknown output '%s'", cli.cfg().Cscli.Output)
}

return nil
}

func (cli *cliMachines) NewListCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "list",
Expand All @@ -188,12 +201,7 @@ func (cli *cliMachines) NewListCmd() *cobra.Command {
Args: cobra.NoArgs,
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
err := getAgents(color.Output, dbClient)
if err != nil {
return fmt.Errorf("unable to list machines: %s", err)
}

return nil
return cli.list()
},
}

Expand Down Expand Up @@ -256,15 +264,18 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri
machineID = args[0]
}

clientCfg := cli.cfg().API.Client
serverCfg := cli.cfg().API.Server

/*check if file already exists*/
if dumpFile == "" && csConfig.API.Client != nil && csConfig.API.Client.CredentialsFilePath != "" {
credFile := csConfig.API.Client.CredentialsFilePath
if dumpFile == "" && clientCfg != nil && clientCfg.CredentialsFilePath != "" {
credFile := clientCfg.CredentialsFilePath
// use the default only if the file does not exist
_, err = os.Stat(credFile)

switch {
case os.IsNotExist(err) || force:
dumpFile = csConfig.API.Client.CredentialsFilePath
dumpFile = credFile
case err != nil:
return fmt.Errorf("unable to stat '%s': %s", credFile, err)
default:
Expand All @@ -291,18 +302,18 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri

password := strfmt.Password(machinePassword)

_, err = dbClient.CreateMachine(&machineID, &password, "", true, force, types.PasswordAuthType)
_, err = cli.db.CreateMachine(&machineID, &password, "", true, force, types.PasswordAuthType)
if err != nil {
return fmt.Errorf("unable to create machine: %s", err)
}

fmt.Printf("Machine '%s' successfully added to the local API.\n", machineID)

if apiURL == "" {
if csConfig.API.Client != nil && csConfig.API.Client.Credentials != nil && csConfig.API.Client.Credentials.URL != "" {
apiURL = csConfig.API.Client.Credentials.URL
} else if csConfig.API.Server != nil && csConfig.API.Server.ListenURI != "" {
apiURL = "http://" + csConfig.API.Server.ListenURI
if clientCfg != nil && clientCfg.Credentials != nil && clientCfg.Credentials.URL != "" {
apiURL = clientCfg.Credentials.URL
} else if serverCfg != nil && serverCfg.ListenURI != "" {
apiURL = "http://" + serverCfg.ListenURI
} else {
return fmt.Errorf("unable to dump an api URL. Please provide it in your configuration or with the -u parameter")
}
Expand Down Expand Up @@ -333,7 +344,7 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri
}

func (cli *cliMachines) deleteValid(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
machines, err := dbClient.ListMachines()
machines, err := cli.db.ListMachines()
if err != nil {
cobra.CompError("unable to list machines " + err.Error())
}
Expand All @@ -351,7 +362,7 @@ func (cli *cliMachines) deleteValid(cmd *cobra.Command, args []string, toComplet

func (cli *cliMachines) delete(machines []string) error {
for _, machineID := range machines {
err := dbClient.DeleteWatcher(machineID)
err := cli.db.DeleteWatcher(machineID)
if err != nil {
log.Errorf("unable to delete machine '%s': %s", machineID, err)
return nil
Expand Down Expand Up @@ -392,12 +403,12 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b
}

machines := []*ent.Machine{}
if pending, err := dbClient.QueryPendingMachine(); err == nil {
if pending, err := cli.db.QueryPendingMachine(); err == nil {
machines = append(machines, pending...)
}

if !notValidOnly {
if pending, err := dbClient.QueryLastValidatedHeartbeatLT(time.Now().UTC().Add(duration)); err == nil {
if pending, err := cli.db.QueryLastValidatedHeartbeatLT(time.Now().UTC().Add(duration)); err == nil {
machines = append(machines, pending...)
}
}
Expand All @@ -420,7 +431,7 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b
}
}

deleted, err := dbClient.BulkDeleteWatchers(machines)
deleted, err := cli.db.BulkDeleteWatchers(machines)
if err != nil {
return fmt.Errorf("unable to prune machines: %s", err)
}
Expand Down Expand Up @@ -462,7 +473,7 @@ cscli machines prune --not-validated-only --force`,
}

func (cli *cliMachines) validate(machineID string) error {
if err := dbClient.ValidateMachine(machineID); err != nil {
if err := cli.db.ValidateMachine(machineID); err != nil {
return fmt.Errorf("unable to validate machine '%s': %s", machineID, err)
}
log.Infof("machine '%s' validated successfully", machineID)
Expand Down
2 changes: 1 addition & 1 deletion cmd/crowdsec-cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall
cmd.AddCommand(NewCLIAlerts().NewCommand())
cmd.AddCommand(NewCLISimulation().NewCommand())
cmd.AddCommand(NewCLIBouncers(getconfig).NewCommand())
cmd.AddCommand(NewCLIMachines().NewCommand())
cmd.AddCommand(NewCLIMachines(getconfig).NewCommand())
cmd.AddCommand(NewCLICapi().NewCommand())
cmd.AddCommand(NewLapiCmd())
cmd.AddCommand(NewCompletionCmd())
Expand Down
5 changes: 3 additions & 2 deletions cmd/crowdsec-cli/support.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,11 @@ func collectBouncers(dbClient *database.Client) ([]byte, error) {

func collectAgents(dbClient *database.Client) ([]byte, error) {
out := bytes.NewBuffer(nil)
err := getAgents(out, dbClient)
machines, err := dbClient.ListMachines()
if err != nil {
return nil, err
return nil, fmt.Errorf("unable to list machines: %s", err)
}
getAgentsTable(out, machines)
return out.Bytes(), nil
}

Expand Down

0 comments on commit 8896b38

Please sign in to comment.