From 4521a98ecc34cf4beeee653f1eb5914f1879f42d Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Mon, 17 Jun 2024 10:39:50 +0200 Subject: [PATCH] db: don't set machine heartbeat until first connection (#3019) * db: don't set machine heartbeat until first connection * cscli machines prune: if hearbeat is not set, look at creation date * lint --- cmd/crowdsec-cli/machines.go | 2 +- pkg/database/ent/machine/machine.go | 2 -- pkg/database/ent/machine_create.go | 4 --- pkg/database/ent/runtime.go | 4 --- pkg/database/ent/schema/machine.go | 2 +- pkg/database/machines.go | 45 ++++++++++++++++++++--------- test/bats/30_machines.bats | 7 +++++ 7 files changed, 41 insertions(+), 25 deletions(-) diff --git a/cmd/crowdsec-cli/machines.go b/cmd/crowdsec-cli/machines.go index 7beaa5c7fdd..20933dc28e5 100644 --- a/cmd/crowdsec-cli/machines.go +++ b/cmd/crowdsec-cli/machines.go @@ -414,7 +414,7 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b } if !notValidOnly { - if pending, err := cli.db.QueryLastValidatedHeartbeatLT(time.Now().UTC().Add(-duration)); err == nil { + if pending, err := cli.db.QueryMachinesInactiveSince(time.Now().UTC().Add(-duration)); err == nil { machines = append(machines, pending...) } } diff --git a/pkg/database/ent/machine/machine.go b/pkg/database/ent/machine/machine.go index 46ea6deb03d..d7dece9f8ef 100644 --- a/pkg/database/ent/machine/machine.go +++ b/pkg/database/ent/machine/machine.go @@ -87,8 +87,6 @@ var ( UpdateDefaultUpdatedAt func() time.Time // DefaultLastPush holds the default value on creation for the "last_push" field. DefaultLastPush func() time.Time - // DefaultLastHeartbeat holds the default value on creation for the "last_heartbeat" field. - DefaultLastHeartbeat func() time.Time // ScenariosValidator is a validator for the "scenarios" field. It is called by the builders before save. ScenariosValidator func(string) error // DefaultIsValidated holds the default value on creation for the "isValidated" field. diff --git a/pkg/database/ent/machine_create.go b/pkg/database/ent/machine_create.go index 8d4bfb74b2a..2e4cf9f1500 100644 --- a/pkg/database/ent/machine_create.go +++ b/pkg/database/ent/machine_create.go @@ -227,10 +227,6 @@ func (mc *MachineCreate) defaults() { v := machine.DefaultLastPush() mc.mutation.SetLastPush(v) } - if _, ok := mc.mutation.LastHeartbeat(); !ok { - v := machine.DefaultLastHeartbeat() - mc.mutation.SetLastHeartbeat(v) - } if _, ok := mc.mutation.IsValidated(); !ok { v := machine.DefaultIsValidated mc.mutation.SetIsValidated(v) diff --git a/pkg/database/ent/runtime.go b/pkg/database/ent/runtime.go index b4da6dfb9db..8d50d916029 100644 --- a/pkg/database/ent/runtime.go +++ b/pkg/database/ent/runtime.go @@ -142,10 +142,6 @@ func init() { machineDescLastPush := machineFields[2].Descriptor() // machine.DefaultLastPush holds the default value on creation for the last_push field. machine.DefaultLastPush = machineDescLastPush.Default.(func() time.Time) - // machineDescLastHeartbeat is the schema descriptor for last_heartbeat field. - machineDescLastHeartbeat := machineFields[3].Descriptor() - // machine.DefaultLastHeartbeat holds the default value on creation for the last_heartbeat field. - machine.DefaultLastHeartbeat = machineDescLastHeartbeat.Default.(func() time.Time) // machineDescScenarios is the schema descriptor for scenarios field. machineDescScenarios := machineFields[7].Descriptor() // machine.ScenariosValidator is a validator for the "scenarios" field. It is called by the builders before save. diff --git a/pkg/database/ent/schema/machine.go b/pkg/database/ent/schema/machine.go index 997a2041453..7b4d97ed35c 100644 --- a/pkg/database/ent/schema/machine.go +++ b/pkg/database/ent/schema/machine.go @@ -4,6 +4,7 @@ import ( "entgo.io/ent" "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -25,7 +26,6 @@ func (Machine) Fields() []ent.Field { Default(types.UtcNow). Nillable().Optional(), field.Time("last_heartbeat"). - Default(types.UtcNow). Nillable().Optional(), field.String("machineId"). Unique(). diff --git a/pkg/database/machines.go b/pkg/database/machines.go index 7a64c1d4d6e..18fd32fdd84 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -13,8 +13,10 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -const CapiMachineID = types.CAPIOrigin -const CapiListsMachineID = types.ListOrigin +const ( + CapiMachineID = types.CAPIOrigin + CapiListsMachineID = types.ListOrigin +) func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) @@ -30,6 +32,7 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA if err != nil { return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) } + if len(machineExist) > 0 { if force { _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(c.CTX) @@ -37,12 +40,15 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA c.Log.Warningf("CreateMachine : %s", err) return nil, errors.Wrapf(UpdateFail, "machine '%s'", *machineID) } + machine, err := c.QueryMachineByID(*machineID) if err != nil { return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) } + return machine, nil } + return nil, errors.Wrapf(UserExists, "user '%s'", *machineID) } @@ -54,7 +60,6 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA SetIsValidated(isValidated). SetAuthType(authType). Save(c.CTX) - if err != nil { c.Log.Warningf("CreateMachine : %s", err) return nil, errors.Wrapf(InsertFail, "creating machine '%s'", *machineID) @@ -72,6 +77,7 @@ func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) { c.Log.Warningf("QueryMachineByID : %s", err) return &ent.Machine{}, errors.Wrapf(UserNotExists, "user '%s'", machineID) } + return machine, nil } @@ -80,6 +86,7 @@ func (c *Client) ListMachines() ([]*ent.Machine, error) { if err != nil { return nil, errors.Wrapf(QueryFail, "listing machines: %s", err) } + return machines, nil } @@ -88,21 +95,21 @@ func (c *Client) ValidateMachine(machineID string) error { if err != nil { return errors.Wrapf(UpdateFail, "validating machine: %s", err) } + if rets == 0 { - return fmt.Errorf("machine not found") + return errors.New("machine not found") } + return nil } func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) { - var machines []*ent.Machine - var err error - - machines, err = c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(c.CTX) + machines, err := c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(c.CTX) if err != nil { c.Log.Warningf("QueryPendingMachine : %s", err) return nil, errors.Wrapf(QueryFail, "querying pending machines: %s", err) } + return machines, nil } @@ -116,7 +123,7 @@ func (c *Client) DeleteWatcher(name string) error { } if nbDeleted == 0 { - return fmt.Errorf("machine doesn't exist") + return errors.New("machine doesn't exist") } return nil @@ -127,10 +134,12 @@ func (c *Client) BulkDeleteWatchers(machines []*ent.Machine) (int, error) { for i, b := range machines { ids[i] = b.ID } + nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(c.CTX) if err != nil { return nbDeleted, err } + return nbDeleted, nil } @@ -139,6 +148,7 @@ func (c *Client) UpdateMachineLastHeartBeat(machineID string) error { if err != nil { return errors.Wrapf(UpdateFail, "updating machine last_heartbeat: %s", err) } + return nil } @@ -150,6 +160,7 @@ func (c *Client) UpdateMachineScenarios(scenarios string, ID int) error { if err != nil { return fmt.Errorf("unable to update machine in database: %s", err) } + return nil } @@ -160,6 +171,7 @@ func (c *Client) UpdateMachineIP(ipAddr string, ID int) error { if err != nil { return fmt.Errorf("unable to update machine IP in database: %s", err) } + return nil } @@ -170,6 +182,7 @@ func (c *Client) UpdateMachineVersion(ipAddr string, ID int) error { if err != nil { return fmt.Errorf("unable to update machine version in database: %s", err) } + return nil } @@ -178,17 +191,23 @@ func (c *Client) IsMachineRegistered(machineID string) (bool, error) { if err != nil { return false, err } + if len(exist) == 1 { return true, nil } + if len(exist) > 1 { - return false, fmt.Errorf("more than one item with the same machineID in database") + return false, errors.New("more than one item with the same machineID in database") } return false, nil - } -func (c *Client) QueryLastValidatedHeartbeatLT(t time.Time) ([]*ent.Machine, error) { - return c.Ent.Machine.Query().Where(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)).All(c.CTX) +func (c *Client) QueryMachinesInactiveSince(t time.Time) ([]*ent.Machine, error) { + return c.Ent.Machine.Query().Where( + machine.Or( + machine.And(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)), + machine.And(machine.LastHeartbeatIsNil(), machine.CreatedAtLT(t)), + ), + ).All(c.CTX) } diff --git a/test/bats/30_machines.bats b/test/bats/30_machines.bats index 415e5f8693f..1d65151b6c8 100644 --- a/test/bats/30_machines.bats +++ b/test/bats/30_machines.bats @@ -62,6 +62,13 @@ teardown() { assert_output 1 } +@test "heartbeat is initially null" { + rune -0 cscli machines add foo --auto --file /dev/null + rune -0 cscli machines list -o json + rune -0 yq '.[] | select(.machineId == "foo") | .last_heartbeat' <(output) + assert_output null +} + @test "register, validate and then remove a machine" { rune -0 cscli lapi register --machine CiTestMachineRegister -f /dev/null -o human assert_stderr --partial "Successfully registered to Local API (LAPI)"