From a754cc756b2ea09c7001ec015850cfd98ff36962 Mon Sep 17 00:00:00 2001 From: Aleksandr Maus Date: Thu, 20 Jul 2023 09:23:47 -0400 Subject: [PATCH] Tamper protected Endpoint uninstall - combined PR (#2781) * Pass uninstall token from Agent uninstall command to Endpoint uninstall command * Support singed UNENROLL action and UNENROLL forwarding to Endpoint * Tamper protected Endpoint integration removal * Rename optional_actions to proxied_actions * Cleanup * Address code review, add missing comments for the public structures and variables * Change uninstall to fail the Agent uninstall if service (Endpoint) uninstall fails * Fix comments typos * Final touches after all testing * Implement tamper protected agent upgrade * Address code review feedback * Make linter happy * Added missing copyright header * Feature flag * Code fix after remerge with main changes * Restore logging statement that was unnecessarily changed * Address code review feedback * Implement UNENROLL/UPGRADE actions dispatch with backoff and timeout * Adjust backoff retry, more unit tests coverage * Make linter happy * Updated spec doc * Rename backoffActionDispatcher to proxiedActionsNotifier, and making notify private * Address code review * Fix typo for linter * Address code review comments * Remove unused funcs * Fix unit test * Updated components specs doc, with more clarification for the proxied actions * Update uninstall-token flag description --- ...1-Tamper-protected-Endpoint-uninstall.yaml | 31 ++ docs/component-specs.md | 11 + .../handlers/handler_action_unenroll.go | 32 +- .../handlers/handler_action_unenroll_test.go | 211 ++++++++++ .../handlers/handler_action_upgrade.go | 32 +- .../actions/handlers/handler_helpers.go | 174 ++++++++ .../actions/handlers/handler_helpers_test.go | 217 ++++++++++ .../application/coordinator/coordinator.go | 22 +- .../coordinator/coordinator_test.go | 4 +- .../coordinator/mocks/runtime_manager.go | 380 ++++++++++++++++++ .../pkg/agent/application/managed_mode.go | 1 + internal/pkg/agent/cmd/install.go | 4 +- internal/pkg/agent/cmd/uninstall.go | 4 +- internal/pkg/agent/install/install.go | 13 +- internal/pkg/agent/install/uninstall.go | 38 +- .../pkg/agent/storage/store/action_store.go | 7 +- internal/pkg/fleetapi/action.go | 42 +- internal/pkg/fleetapi/action_test.go | 57 +++ pkg/component/component.go | 107 +++++ pkg/component/component_test.go | 156 ++++++- pkg/component/input_spec.go | 15 +- pkg/component/runtime/command.go | 14 +- pkg/component/runtime/failed.go | 2 +- pkg/component/runtime/manager.go | 16 +- pkg/component/runtime/manager_test.go | 68 ++-- pkg/component/runtime/runtime.go | 6 +- pkg/component/runtime/service.go | 204 ++++++++-- pkg/component/runtime/service_test.go | 177 ++++++++ pkg/features/features.go | 42 +- specs/endpoint-security.spec.yml | 10 + 30 files changed, 1973 insertions(+), 124 deletions(-) create mode 100644 changelog/fragments/1688069371-Tamper-protected-Endpoint-uninstall.yaml create mode 100644 internal/pkg/agent/application/actions/handlers/handler_action_unenroll_test.go create mode 100644 internal/pkg/agent/application/actions/handlers/handler_helpers.go create mode 100644 internal/pkg/agent/application/actions/handlers/handler_helpers_test.go create mode 100644 internal/pkg/agent/application/coordinator/mocks/runtime_manager.go create mode 100644 pkg/component/runtime/service_test.go diff --git a/changelog/fragments/1688069371-Tamper-protected-Endpoint-uninstall.yaml b/changelog/fragments/1688069371-Tamper-protected-Endpoint-uninstall.yaml new file mode 100644 index 00000000000..49ae49c0457 --- /dev/null +++ b/changelog/fragments/1688069371-Tamper-protected-Endpoint-uninstall.yaml @@ -0,0 +1,31 @@ +# Kind can be one of: +# - breaking-change: a change to previously-documented behavior +# - deprecation: functionality that is being removed in a later release +# - bug-fix: fixes a problem in a previous version +# - enhancement: extends functionality but does not break or fix existing behavior +# - feature: new functionality +# - known-issue: problems that we are aware of in a given version +# - security: impacts on the security of a product or a user’s deployment. +# - upgrade: important information for someone upgrading from a prior version +# - other: does not fit into any of the other categories +kind: feature + +# Change summary; a 80ish characters long description of the change. +summary: Tamper protected Endpoint uninstall + +# Long description; in case the summary is not enough to describe the change +# this field accommodate a description without length limits. +# NOTE: This field will be rendered only for breaking-change and known-issue kinds at the moment. +description: | + Add new `--uninstall-token` to allow uninstall when Endpoint protection is enabled. + Enable unenroll and upgrade actions to complete successfully when Endpoint protection is enabled. + Enable Endpoint integration removal when Endpoint protection is enabled. + +# Affected component; a word indicating the component this changeset affects. +component: + +# PR URL; optional; the PR number that added the changeset. +# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added. +# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number. +# Please provide it if you are adding a fragment for a different PR. +pr: https://github.com/elastic/elastic-agent/pull/2781 diff --git a/docs/component-specs.md b/docs/component-specs.md index ac29df86a31..212e30a0158 100644 --- a/docs/component-specs.md +++ b/docs/component-specs.md @@ -53,6 +53,17 @@ The platforms this input or shipper supports. Must contain one or more of the fo The output types this input or shipper supports. If this is an input, then inputs of this type can only target (non-shipper) output types in this list. If this is a shipper, then this shipper can only implement output types in this list. +### `proxied_actions` (list of strings) + +The action types that should be forwarded to the corresponding component. Currently these actions types are sent ("proxied") to the components in parallel. The agent action handler awaits for actions responses. If any of the proxied actions fail, the action is considered failed by the agent. Inital application for this was forwarding the Agent actions such as UNENROLL and UPGRADE to Endpoint service as a part of the Agent/Endpoint tamper protection feature. + +Example for Endpoint spec: +``` +proxied_actions: + - UNENROLL + - UPGRADE +``` + ### `shippers` (list of strings, input only) The shipper types this input supports. Inputs of this type can target any output type supported by the shippers in this list, as long as the output policy includes `shipper.enabled: true`. If an input supports more than one shipper implementing the same output type, then Agent will prefer the one that appears first in this list. diff --git a/internal/pkg/agent/application/actions/handlers/handler_action_unenroll.go b/internal/pkg/agent/application/actions/handlers/handler_action_unenroll.go index b4c4863b062..0fb9acb8da7 100644 --- a/internal/pkg/agent/application/actions/handlers/handler_action_unenroll.go +++ b/internal/pkg/agent/application/actions/handlers/handler_action_unenroll.go @@ -14,6 +14,7 @@ import ( "github.com/elastic/elastic-agent/internal/pkg/fleetapi" "github.com/elastic/elastic-agent/internal/pkg/fleetapi/acker" "github.com/elastic/elastic-agent/pkg/core/logger" + "github.com/elastic/elastic-agent/pkg/features" ) const ( @@ -32,23 +33,29 @@ type stateStore interface { // For it to be operational again it needs to be either enrolled or reconfigured. type Unenroll struct { log *logger.Logger + coord actionCoordinator ch chan coordinator.ConfigChange closers []context.CancelFunc stateStore stateStore + + tamperProtectionFn func() bool // allows to inject the flag for tests, defaults to features.TamperProtection } // NewUnenroll creates a new Unenroll handler. func NewUnenroll( log *logger.Logger, + coord actionCoordinator, ch chan coordinator.ConfigChange, closers []context.CancelFunc, stateStore stateStore, ) *Unenroll { return &Unenroll{ - log: log, - ch: ch, - closers: closers, - stateStore: stateStore, + log: log, + coord: coord, + ch: ch, + closers: closers, + stateStore: stateStore, + tamperProtectionFn: features.TamperProtection, } } @@ -60,11 +67,28 @@ func (h *Unenroll) Handle(ctx context.Context, a fleetapi.Action, acker acker.Ac return fmt.Errorf("invalid type, expected ActionUnenroll and received %T", a) } + if h.tamperProtectionFn() { + // Find inputs that want to receive UNENROLL action + // Endpoint needs to receive a signed UNENROLL action in order to be able to uncontain itself + state := h.coord.State() + ucs := findMatchingUnitsByActionType(state, a.Type()) + if len(ucs) > 0 { + err := notifyUnitsOfProxiedAction(ctx, h.log, action, ucs, h.coord.PerformAction) + if err != nil { + return err + } + } else { + // Log and continue + h.log.Debugf("No components running for %v action type", a.Type()) + } + } + if action.IsDetected { // not from Fleet; so we set it to nil so policyChange doesn't ack it a = nil } + // Generate empty policy change, this removing all the running components unenrollPolicy := newPolicyChange(ctx, config.New(), a, acker, true) h.ch <- unenrollPolicy diff --git a/internal/pkg/agent/application/actions/handlers/handler_action_unenroll_test.go b/internal/pkg/agent/application/actions/handlers/handler_action_unenroll_test.go new file mode 100644 index 00000000000..61e4ddfc9bb --- /dev/null +++ b/internal/pkg/agent/application/actions/handlers/handler_action_unenroll_test.go @@ -0,0 +1,211 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package handlers + +import ( + "context" + "testing" + + "github.com/elastic/elastic-agent-client/v7/pkg/client" + "github.com/elastic/elastic-agent-client/v7/pkg/proto" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/coordinator" + "github.com/elastic/elastic-agent/internal/pkg/fleetapi" + "github.com/elastic/elastic-agent/pkg/component" + "github.com/elastic/elastic-agent/pkg/component/runtime" + "github.com/elastic/elastic-agent/pkg/core/logger" + + "github.com/stretchr/testify/require" +) + +func makeComponentState(name string, proxiedActions []string) runtime.ComponentComponentState { + return runtime.ComponentComponentState{ + Component: component.Component{ + InputType: name, + Units: []component.Unit{ + { + Type: client.UnitTypeInput, + Config: &proto.UnitExpectedConfig{Type: name}, + }, + }, + InputSpec: &component.InputRuntimeSpec{ + Spec: component.InputSpec{ + Name: name, + ProxiedActions: proxiedActions, + }, + }, + }, + } +} + +type MockActionCoordinator struct { + st coordinator.State + performedActions int +} + +func (c *MockActionCoordinator) State() coordinator.State { + return c.st +} + +func (c *MockActionCoordinator) PerformAction(ctx context.Context, comp component.Component, unit component.Unit, name string, params map[string]interface{}) (map[string]interface{}, error) { + c.performedActions++ + return nil, nil +} + +func (c *MockActionCoordinator) Clear() { + c.performedActions = 0 +} + +type MockAcker struct { + Acked []fleetapi.Action +} + +func (m *MockAcker) Ack(_ context.Context, action fleetapi.Action) error { + m.Acked = append(m.Acked, action) + return nil +} + +func (m *MockAcker) Commit(_ context.Context) error { + return nil +} + +func (m *MockAcker) Clear() { + m.Acked = nil +} + +func TestActionUnenrollHandler(t *testing.T) { + ctx, cn := context.WithCancel(context.Background()) + defer cn() + + log, _ := logger.New("", false) + coord := &MockActionCoordinator{} + acker := &MockAcker{} + + action := &fleetapi.ActionUnenroll{ + ActionID: "c80e9219-70bf-43d3-b8cd-b5131a771751", + ActionType: "UNENROLL", + } + goodSigned := &fleetapi.Signed{ + Data: "eyJAdGltZXN0YW1wIjoiMjAyMy0wNS0yMlQxNzoxOToyOC40NjNaIiwiZXhwaXJhdGlvbiI6IjIwMjMtMDYtMjFUMTc6MTk6MjguNDYzWiIsImFnZW50cyI6WyI3ZjY0YWI2NC1hNmM0LTQ2ZTMtODIyYS0zODUxZGVkYTJmY2UiXSwiYWN0aW9uX2lkIjoiNGYwODQ2MGYtMDE0Yy00ZDllLWJmOGEtY2FhNjQyNzRhZGU0IiwidHlwZSI6IlVORU5ST0xMIiwidHJhY2VwYXJlbnQiOiIwMC1iOTBkYTlmOGNjNzdhODk0OTc0ZWIxZTIzMGNmNjc2Yy1lOTNlNzk4YTU4ODg2MDVhLTAxIn0=", + Signature: "MEUCIAxxsi9ff1zyV0+4fsJLqbP8Qb83tedU5iIFldtxEzEfAiEA0KUsrL7q+Fv7z6Boux3dY2P4emGi71jsMGanIZ552bM=", + } + action.Signed = goodSigned + + ch := make(chan coordinator.ConfigChange, 1) + go func() { + for { + select { + case <-ctx.Done(): + return + case policyChange := <-ch: + _ = policyChange.Ack() + } + } + }() + + handler := NewUnenroll(log, coord, ch, nil, nil) + + getTamperProtectionFunc := func(enabled bool) func() bool { + return func() bool { + return enabled + } + } + + tests := []struct { + name string + st coordinator.State + wantErr error // Handler error + wantPerformedActions int + tamperProtectionFn func() bool + }{ + { + name: "no running components", + }, + { + name: "endpoint no dispatch", + st: func() coordinator.State { + return coordinator.State{ + Components: []runtime.ComponentComponentState{ + makeComponentState("endpoint", nil), + }, + } + }(), + }, + { + name: "endpoint with UNENROLL, tamper protection feature flag disabled", + st: func() coordinator.State { + return coordinator.State{ + Components: []runtime.ComponentComponentState{ + makeComponentState("endpoint", []string{"UNENROLL"}), + makeComponentState("osquery", nil), + }, + } + }(), + wantPerformedActions: 0, + }, + { + name: "endpoint with UNENROLL, tamper protection feature flag enabled", + st: func() coordinator.State { + return coordinator.State{ + Components: []runtime.ComponentComponentState{ + makeComponentState("endpoint", []string{"UNENROLL"}), + makeComponentState("osquery", nil), + }, + } + }(), + tamperProtectionFn: getTamperProtectionFunc(true), + wantPerformedActions: 1, + }, + { + name: "more than one UNENROLL dispatch, tamper protection feature flag disabled", + st: func() coordinator.State { + return coordinator.State{ + Components: []runtime.ComponentComponentState{ + makeComponentState("endpoint", []string{"UNENROLL"}), + makeComponentState("foobar", []string{"UNENROLL", "FOOBAR"}), + makeComponentState("osquery", nil), + }, + } + }(), + wantPerformedActions: 0, + }, + { + name: "more than one UNENROLL dispatch, tamper protection feature flag enabled", + st: func() coordinator.State { + return coordinator.State{ + Components: []runtime.ComponentComponentState{ + makeComponentState("endpoint", []string{"UNENROLL"}), + makeComponentState("foobar", []string{"UNENROLL", "FOOBAR"}), + makeComponentState("osquery", nil), + }, + } + }(), + tamperProtectionFn: getTamperProtectionFunc(true), + wantPerformedActions: 2, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer acker.Clear() + defer coord.Clear() + + coord.st = tc.st + + if tc.tamperProtectionFn == nil { + handler.tamperProtectionFn = getTamperProtectionFunc(false) + } else { + handler.tamperProtectionFn = tc.tamperProtectionFn + } + + err := handler.Handle(ctx, action, acker) + + require.ErrorIs(t, err, tc.wantErr) + if tc.wantErr == nil { + require.Len(t, acker.Acked, 1) + } + require.Equal(t, tc.wantPerformedActions, coord.performedActions) + }) + } +} diff --git a/internal/pkg/agent/application/actions/handlers/handler_action_upgrade.go b/internal/pkg/agent/application/actions/handlers/handler_action_upgrade.go index 360de1e7d84..0214dfa4061 100644 --- a/internal/pkg/agent/application/actions/handlers/handler_action_upgrade.go +++ b/internal/pkg/agent/application/actions/handlers/handler_action_upgrade.go @@ -10,10 +10,10 @@ import ( "fmt" "sync" - "github.com/elastic/elastic-agent/internal/pkg/agent/application/coordinator" "github.com/elastic/elastic-agent/internal/pkg/fleetapi" "github.com/elastic/elastic-agent/internal/pkg/fleetapi/acker" "github.com/elastic/elastic-agent/pkg/core/logger" + "github.com/elastic/elastic-agent/pkg/features" ) // Upgrade is a handler for UPGRADE action. @@ -21,17 +21,20 @@ import ( // from repository specified by fleet. type Upgrade struct { log *logger.Logger - coord *coordinator.Coordinator + coord upgradeCoordinator bkgActions []fleetapi.Action bkgCancel context.CancelFunc bkgMutex sync.Mutex + + tamperProtectionFn func() bool // allows to inject the flag for tests, defaults to features.TamperProtection } // NewUpgrade creates a new Upgrade handler. -func NewUpgrade(log *logger.Logger, coord *coordinator.Coordinator) *Upgrade { +func NewUpgrade(log *logger.Logger, coord upgradeCoordinator) *Upgrade { return &Upgrade{ - log: log, - coord: coord, + log: log, + coord: coord, + tamperProtectionFn: features.TamperProtection, } } @@ -51,6 +54,25 @@ func (h *Upgrade) Handle(ctx context.Context, a fleetapi.Action, ack acker.Acker if !runAsync { return nil } + + if h.tamperProtectionFn() { + // Find inputs that want to receive UPGRADE action + // Endpoint needs to receive a signed UPGRADE action in order to be able to uncontain itself + state := h.coord.State() + ucs := findMatchingUnitsByActionType(state, a.Type()) + if len(ucs) > 0 { + h.log.Debugf("handlerUpgrade: proxy/dispatch action '%+v'", a) + err := notifyUnitsOfProxiedAction(ctx, h.log, action, ucs, h.coord.PerformAction) + h.log.Debugf("handlerUpgrade: after action dispatched '%+v', err: %v", a, err) + if err != nil { + return err + } + } else { + // Log and continue + h.log.Debugf("No components running for %v action type", a.Type()) + } + } + go func() { h.log.Infof("starting upgrade to version %s in background", action.Version) if err := h.coord.Upgrade(asyncCtx, action.Version, action.SourceURI, action, false); err != nil { diff --git a/internal/pkg/agent/application/actions/handlers/handler_helpers.go b/internal/pkg/agent/application/actions/handlers/handler_helpers.go new file mode 100644 index 00000000000..c10a8fbf678 --- /dev/null +++ b/internal/pkg/agent/application/actions/handlers/handler_helpers.go @@ -0,0 +1,174 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package handlers + +import ( + "context" + "time" + + "golang.org/x/sync/errgroup" + + "github.com/elastic/elastic-agent-client/v7/pkg/client" + "github.com/elastic/elastic-agent-libs/logp" + + "github.com/elastic/elastic-agent/internal/pkg/agent/application/coordinator" + "github.com/elastic/elastic-agent/internal/pkg/agent/errors" + "github.com/elastic/elastic-agent/internal/pkg/core/backoff" + "github.com/elastic/elastic-agent/internal/pkg/fleetapi" + "github.com/elastic/elastic-agent/pkg/component" + "github.com/elastic/elastic-agent/pkg/component/runtime" +) + +type actionCoordinator interface { + State() coordinator.State + PerformAction(ctx context.Context, comp component.Component, unit component.Unit, name string, params map[string]interface{}) (map[string]interface{}, error) +} + +type upgradeCoordinator interface { + actionCoordinator + Upgrade(ctx context.Context, version string, sourceURI string, action *fleetapi.ActionUpgrade, skipVerifyOverride bool, pgpBytes ...string) error +} + +type performActionFunc func(context.Context, component.Component, component.Unit, string, map[string]interface{}) (map[string]interface{}, error) + +type dispatchableAction interface { + MarshalMap() (map[string]interface{}, error) + Type() string +} + +type unitWithComponent struct { + unit component.Unit + component component.Component +} + +func findMatchingUnitsByActionType(state coordinator.State, typ string) []unitWithComponent { + ucs := make([]unitWithComponent, 0) + for _, comp := range state.Components { + if comp.Component.InputSpec != nil && contains(comp.Component.InputSpec.Spec.ProxiedActions, typ) { + name := comp.Component.InputType + for _, unit := range comp.Component.Units { + // All input units should match the component input type, but let's be cautious + if unit.Type == client.UnitTypeInput && unit.Config != nil && unit.Config.Type == name { + ucs = append(ucs, unitWithComponent{unit, comp.Component}) + } + } + } + } + return ucs +} + +func contains[T comparable](arr []T, val T) bool { + for _, v := range arr { + if v == val { + return true + } + } + return false +} + +type proxiedActionsNotifier struct { + log *logp.Logger + performAction performActionFunc + + timeout time.Duration + minBackoff time.Duration + maxBackoff time.Duration +} + +const ( + defaultActionDispatcherTimeout = 40 * time.Second + defaultActionDispatcherBackoffMin = 500 * time.Millisecond + defaultActionDispatcherBackoffMax = 10 * time.Second +) + +func newProxiedActionsNotifier(log *logp.Logger, performAction performActionFunc) proxiedActionsNotifier { + return proxiedActionsNotifier{ + log: log, + performAction: performAction, + timeout: defaultActionDispatcherTimeout, + minBackoff: defaultActionDispatcherBackoffMin, + maxBackoff: defaultActionDispatcherBackoffMax, + } +} + +func (d proxiedActionsNotifier) notify(ctx context.Context, action dispatchableAction, ucs []unitWithComponent) error { + if action == nil { + return nil + } + + // Deserialize the action into map[string]interface{} for dispatching over to the apps + params, err := action.MarshalMap() + if err != nil { + return err + } + + actionType := action.Type() + + g, ctx := errgroup.WithContext(ctx) + + dispatch := func(uc unitWithComponent) error { + if uc.unit.Config == nil { + return nil + } + d.log.Debugf("Dispatch %v action to %v", actionType, uc.unit.Config.Type) + res, err := d.performAction(ctx, uc.component, uc.unit, uc.unit.Config.Type, params) + if err != nil { + d.log.Debugf("%v failed to dispatch to %v, err: %v", actionType, uc.component.ID, err) + // ErrNoUnit means that the unit is not longer available + // This can happen if the policy change updated state while the action proxying was retried + // Stop retrying proxying action to that unit return nil + if errors.Is(err, runtime.ErrNoUnit) { + d.log.Debugf("%v unit is not longer managed by runtime, possibly due to policy change", uc.component.ID) + return nil + } + return err + } + + strErr := readMapString(res, "error", "") + if strErr != "" { + d.log.Debugf("%v failed for %v, err: %v", actionType, uc.component.ID, strErr) + return errors.New(strErr) + } + return nil + } + + dispatchWithBackoff := func(uc unitWithComponent) error { + ctx, cn := context.WithTimeout(ctx, d.timeout) + defer cn() + + attempt := 1 + backExp := backoff.NewExpBackoff(ctx.Done(), d.minBackoff, d.maxBackoff) + start := time.Now() + + for { + err := dispatch(uc) + if err != nil { + if backExp.Wait() { + d.log.Debugf("%v action dispatch to %v with backoff attempt: %v, after %v since start", actionType, uc.component.ID, attempt, time.Since(start)) + attempt++ + continue + } + return err + } + return nil + } + } + + // Iterate through the components and dispatch the action is the action type is listed in the proxied_actions + for _, uc := range ucs { + // Send the action to the target unit via g.Go to collect any resulting errors + target := uc + g.Go(func() error { + return dispatchWithBackoff(target) + }) + } + + return g.Wait() +} + +// notifyUnitsOfProxiedAction dispatches actions to the units/components in parallel, with exponential backoff and timeout +func notifyUnitsOfProxiedAction(ctx context.Context, log *logp.Logger, action dispatchableAction, ucs []unitWithComponent, performAction performActionFunc) error { + return newProxiedActionsNotifier(log, performAction).notify(ctx, action, ucs) +} diff --git a/internal/pkg/agent/application/actions/handlers/handler_helpers_test.go b/internal/pkg/agent/application/actions/handlers/handler_helpers_test.go new file mode 100644 index 00000000000..c2edd1a13c3 --- /dev/null +++ b/internal/pkg/agent/application/actions/handlers/handler_helpers_test.go @@ -0,0 +1,217 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package handlers + +import ( + "context" + "math" + "testing" + "time" + + "github.com/elastic/elastic-agent-client/v7/pkg/proto" + "github.com/elastic/elastic-agent-libs/logp" + + "github.com/elastic/elastic-agent/internal/pkg/agent/errors" + "github.com/elastic/elastic-agent/pkg/component" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +type testAction struct { + typ string + data map[string]interface{} +} + +func (a testAction) Type() string { + return a.typ +} + +func (a testAction) MarshalMap() (map[string]interface{}, error) { + return a.data, nil +} + +func TestNotifyUnitsOfProxiedAction(t *testing.T) { + ctx, cn := context.WithCancel(context.Background()) + defer cn() + log := logp.NewLogger("testing") + + happyPerformAction := func(context.Context, component.Component, component.Unit, string, map[string]interface{}) (map[string]interface{}, error) { + return nil, nil + } + + tests := []struct { + Name string + Action dispatchableAction + UCs []unitWithComponent + performAction performActionFunc + }{ + { + Name: "nil action", + }, + { + Name: "no components", + Action: testAction{ + typ: "UNENROLL", + }, + }, + { + Name: "one component", + Action: testAction{ + typ: "UNENROLL", + }, + UCs: []unitWithComponent{ + { + component: component.Component{}, + unit: component.Unit{ + Config: &proto.UnitExpectedConfig{ + Type: "endpoint", + }, + }, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.Name, func(t *testing.T) { + if tc.performAction == nil { + tc.performAction = happyPerformAction + } + + err := notifyUnitsOfProxiedAction(ctx, log, tc.Action, tc.UCs, tc.performAction) + if err != nil { + t.Fatal(err) + } + }) + } +} + +type mockActionPerformer struct { + calledTimes int + shouldFailTimes int +} + +var errPerformActionFail = errors.New("perform action failed") + +func (p *mockActionPerformer) PerformAction(context.Context, component.Component, component.Unit, string, map[string]interface{}) (map[string]interface{}, error) { + p.calledTimes++ + if p.calledTimes < p.shouldFailTimes { + return nil, errPerformActionFail + } + return nil, nil +} + +type ActionPerformer interface { + PerformAction(context.Context, component.Component, component.Unit, string, map[string]interface{}) (map[string]interface{}, error) +} + +func TestProxiedActionsNotifier(t *testing.T) { + ctx, cn := context.WithCancel(context.Background()) + defer cn() + log := logp.NewLogger("testing") + + tests := []struct { + RetryTimeout time.Duration + RetryMinBackoff time.Duration + + Name string + Action dispatchableAction + UCs []unitWithComponent + ActionPerformer mockActionPerformer + WantActionPerformCalledTimes int + WantErr error + }{ + { + Name: "nil action", + }, + { + Name: "no components", + Action: testAction{ + typ: "UNENROLL", + }, + }, + { + Name: "one component, success", + Action: testAction{ + typ: "UNENROLL", + }, + UCs: []unitWithComponent{ + { + component: component.Component{}, + unit: component.Unit{ + Config: &proto.UnitExpectedConfig{ + Type: "endpoint", + }, + }, + }, + }, + WantActionPerformCalledTimes: 1, + }, + { + Name: "one component, failing, timeout", + Action: testAction{ + typ: "UNENROLL", + }, + ActionPerformer: mockActionPerformer{shouldFailTimes: math.MaxInt64}, + RetryMinBackoff: 300 * time.Millisecond, + RetryTimeout: 400 * time.Millisecond, + UCs: []unitWithComponent{ + { + component: component.Component{}, + unit: component.Unit{ + Config: &proto.UnitExpectedConfig{ + Type: "endpoint", + }, + }, + }, + }, + WantActionPerformCalledTimes: 1, + WantErr: errPerformActionFail, + }, + { + Name: "one component, retrying, succeeds", + Action: testAction{ + typ: "UNENROLL", + }, + ActionPerformer: mockActionPerformer{shouldFailTimes: 2}, + RetryMinBackoff: 200 * time.Millisecond, + RetryTimeout: 1 * time.Second, + UCs: []unitWithComponent{ + { + component: component.Component{}, + unit: component.Unit{ + Config: &proto.UnitExpectedConfig{ + Type: "endpoint", + }, + }, + }, + }, + WantActionPerformCalledTimes: 2, + }, + } + + for _, tc := range tests { + t.Run(tc.Name, func(t *testing.T) { + d := newProxiedActionsNotifier(log, tc.ActionPerformer.PerformAction) + if tc.RetryTimeout != 0 { + d.timeout = tc.RetryTimeout + } + if tc.RetryMinBackoff != 0 { + d.minBackoff = tc.RetryMinBackoff + } + err := d.notify(ctx, tc.Action, tc.UCs) + diff := cmp.Diff(tc.WantErr, err, cmpopts.EquateErrors()) + if diff != "" { + t.Fatal(diff) + } + + diff = cmp.Diff(tc.WantActionPerformCalledTimes, tc.ActionPerformer.calledTimes) + if diff != "" { + t.Fatal(diff) + } + }) + } +} diff --git a/internal/pkg/agent/application/coordinator/coordinator.go b/internal/pkg/agent/application/coordinator/coordinator.go index 79620c225f7..36a3f77c674 100644 --- a/internal/pkg/agent/application/coordinator/coordinator.go +++ b/internal/pkg/agent/application/coordinator/coordinator.go @@ -92,7 +92,7 @@ type RuntimeManager interface { Runner // Update updates the current components model. - Update([]component.Component) error + Update(model component.Model) error // State returns the current components model state. State() []runtime.ComponentComponentState @@ -1002,9 +1002,25 @@ func (c *Coordinator) process(ctx context.Context) (err error) { return err } + signed, err := component.SignedFromPolicy(c.derivedConfig) + if err != nil { + if !errors.Is(err, component.ErrNotFound) { + c.logger.Errorf("Failed to parse \"signed\" properties: %v", err) + return err + } + + // Some "signed" properties are not found, continue. + c.logger.Debugf("Continue with missing \"signed\" properties: %v", err) + } + + model := component.Model{ + Components: c.componentModel, + Signed: signed, + } + c.logger.Info("Updating running component model") - c.logger.With("components", c.componentModel).Debug("Updating running component model") - err = c.runtimeMgr.Update(c.componentModel) + c.logger.With("components", model.Components).Debug("Updating running component model") + err = c.runtimeMgr.Update(model) if err != nil { return err } diff --git a/internal/pkg/agent/application/coordinator/coordinator_test.go b/internal/pkg/agent/application/coordinator/coordinator_test.go index 2fac5cd4b15..f5439483be6 100644 --- a/internal/pkg/agent/application/coordinator/coordinator_test.go +++ b/internal/pkg/agent/application/coordinator/coordinator_test.go @@ -708,9 +708,9 @@ func (r *fakeRuntimeManager) Run(ctx context.Context) error { func (r *fakeRuntimeManager) Errors() <-chan error { return nil } -func (r *fakeRuntimeManager) Update(components []component.Component) error { +func (r *fakeRuntimeManager) Update(model component.Model) error { if r.updateCallback != nil { - return r.updateCallback(components) + return r.updateCallback(model.Components) } return nil } diff --git a/internal/pkg/agent/application/coordinator/mocks/runtime_manager.go b/internal/pkg/agent/application/coordinator/mocks/runtime_manager.go new file mode 100644 index 00000000000..dd583f673fd --- /dev/null +++ b/internal/pkg/agent/application/coordinator/mocks/runtime_manager.go @@ -0,0 +1,380 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +// Code generated by mockery v2.20.0. DO NOT EDIT. + +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package mocks + +import ( + context "context" + + component "github.com/elastic/elastic-agent/pkg/component" + + mock "github.com/stretchr/testify/mock" + + runtime "github.com/elastic/elastic-agent/pkg/component/runtime" +) + +// RuntimeManager is an autogenerated mock type for the RuntimeManager type +type RuntimeManager struct { + mock.Mock +} + +type RuntimeManager_Expecter struct { + mock *mock.Mock +} + +func (_m *RuntimeManager) EXPECT() *RuntimeManager_Expecter { + return &RuntimeManager_Expecter{mock: &_m.Mock} +} + +// Errors provides a mock function with given fields: +func (_m *RuntimeManager) Errors() <-chan error { + ret := _m.Called() + + var r0 <-chan error + if rf, ok := ret.Get(0).(func() <-chan error); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan error) + } + } + + return r0 +} + +// RuntimeManager_Errors_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Errors' +type RuntimeManager_Errors_Call struct { + *mock.Call +} + +// Errors is a helper method to define mock.On call +func (_e *RuntimeManager_Expecter) Errors() *RuntimeManager_Errors_Call { + return &RuntimeManager_Errors_Call{Call: _e.mock.On("Errors")} +} + +func (_c *RuntimeManager_Errors_Call) Run(run func()) *RuntimeManager_Errors_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *RuntimeManager_Errors_Call) Return(_a0 <-chan error) *RuntimeManager_Errors_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *RuntimeManager_Errors_Call) RunAndReturn(run func() <-chan error) *RuntimeManager_Errors_Call { + _c.Call.Return(run) + return _c +} + +// PerformAction provides a mock function with given fields: ctx, comp, unit, name, params +func (_m *RuntimeManager) PerformAction(ctx context.Context, comp component.Component, unit component.Unit, name string, params map[string]interface{}) (map[string]interface{}, error) { + ret := _m.Called(ctx, comp, unit, name, params) + + var r0 map[string]interface{} + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, component.Component, component.Unit, string, map[string]interface{}) (map[string]interface{}, error)); ok { + return rf(ctx, comp, unit, name, params) + } + if rf, ok := ret.Get(0).(func(context.Context, component.Component, component.Unit, string, map[string]interface{}) map[string]interface{}); ok { + r0 = rf(ctx, comp, unit, name, params) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]interface{}) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, component.Component, component.Unit, string, map[string]interface{}) error); ok { + r1 = rf(ctx, comp, unit, name, params) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RuntimeManager_PerformAction_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PerformAction' +type RuntimeManager_PerformAction_Call struct { + *mock.Call +} + +// PerformAction is a helper method to define mock.On call +// - ctx context.Context +// - comp component.Component +// - unit component.Unit +// - name string +// - params map[string]interface{} +func (_e *RuntimeManager_Expecter) PerformAction(ctx interface{}, comp interface{}, unit interface{}, name interface{}, params interface{}) *RuntimeManager_PerformAction_Call { + return &RuntimeManager_PerformAction_Call{Call: _e.mock.On("PerformAction", ctx, comp, unit, name, params)} +} + +func (_c *RuntimeManager_PerformAction_Call) Run(run func(ctx context.Context, comp component.Component, unit component.Unit, name string, params map[string]interface{})) *RuntimeManager_PerformAction_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(component.Component), args[2].(component.Unit), args[3].(string), args[4].(map[string]interface{})) + }) + return _c +} + +func (_c *RuntimeManager_PerformAction_Call) Return(_a0 map[string]interface{}, _a1 error) *RuntimeManager_PerformAction_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *RuntimeManager_PerformAction_Call) RunAndReturn(run func(context.Context, component.Component, component.Unit, string, map[string]interface{}) (map[string]interface{}, error)) *RuntimeManager_PerformAction_Call { + _c.Call.Return(run) + return _c +} + +// PerformDiagnostics provides a mock function with given fields: _a0, _a1 +func (_m *RuntimeManager) PerformDiagnostics(_a0 context.Context, _a1 ...runtime.ComponentUnitDiagnosticRequest) []runtime.ComponentUnitDiagnostic { + _va := make([]interface{}, len(_a1)) + for _i := range _a1 { + _va[_i] = _a1[_i] + } + var _ca []interface{} + _ca = append(_ca, _a0) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 []runtime.ComponentUnitDiagnostic + if rf, ok := ret.Get(0).(func(context.Context, ...runtime.ComponentUnitDiagnosticRequest) []runtime.ComponentUnitDiagnostic); ok { + r0 = rf(_a0, _a1...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]runtime.ComponentUnitDiagnostic) + } + } + + return r0 +} + +// RuntimeManager_PerformDiagnostics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PerformDiagnostics' +type RuntimeManager_PerformDiagnostics_Call struct { + *mock.Call +} + +// PerformDiagnostics is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 ...runtime.ComponentUnitDiagnosticRequest +func (_e *RuntimeManager_Expecter) PerformDiagnostics(_a0 interface{}, _a1 ...interface{}) *RuntimeManager_PerformDiagnostics_Call { + return &RuntimeManager_PerformDiagnostics_Call{Call: _e.mock.On("PerformDiagnostics", + append([]interface{}{_a0}, _a1...)...)} +} + +func (_c *RuntimeManager_PerformDiagnostics_Call) Run(run func(_a0 context.Context, _a1 ...runtime.ComponentUnitDiagnosticRequest)) *RuntimeManager_PerformDiagnostics_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]runtime.ComponentUnitDiagnosticRequest, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(runtime.ComponentUnitDiagnosticRequest) + } + } + run(args[0].(context.Context), variadicArgs...) + }) + return _c +} + +func (_c *RuntimeManager_PerformDiagnostics_Call) Return(_a0 []runtime.ComponentUnitDiagnostic) *RuntimeManager_PerformDiagnostics_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *RuntimeManager_PerformDiagnostics_Call) RunAndReturn(run func(context.Context, ...runtime.ComponentUnitDiagnosticRequest) []runtime.ComponentUnitDiagnostic) *RuntimeManager_PerformDiagnostics_Call { + _c.Call.Return(run) + return _c +} + +// Run provides a mock function with given fields: _a0 +func (_m *RuntimeManager) Run(_a0 context.Context) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RuntimeManager_Run_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Run' +type RuntimeManager_Run_Call struct { + *mock.Call +} + +// Run is a helper method to define mock.On call +// - _a0 context.Context +func (_e *RuntimeManager_Expecter) Run(_a0 interface{}) *RuntimeManager_Run_Call { + return &RuntimeManager_Run_Call{Call: _e.mock.On("Run", _a0)} +} + +func (_c *RuntimeManager_Run_Call) Run(run func(_a0 context.Context)) *RuntimeManager_Run_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *RuntimeManager_Run_Call) Return(_a0 error) *RuntimeManager_Run_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *RuntimeManager_Run_Call) RunAndReturn(run func(context.Context) error) *RuntimeManager_Run_Call { + _c.Call.Return(run) + return _c +} + +// State provides a mock function with given fields: +func (_m *RuntimeManager) State() []runtime.ComponentComponentState { + ret := _m.Called() + + var r0 []runtime.ComponentComponentState + if rf, ok := ret.Get(0).(func() []runtime.ComponentComponentState); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]runtime.ComponentComponentState) + } + } + + return r0 +} + +// RuntimeManager_State_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'State' +type RuntimeManager_State_Call struct { + *mock.Call +} + +// State is a helper method to define mock.On call +func (_e *RuntimeManager_Expecter) State() *RuntimeManager_State_Call { + return &RuntimeManager_State_Call{Call: _e.mock.On("State")} +} + +func (_c *RuntimeManager_State_Call) Run(run func()) *RuntimeManager_State_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *RuntimeManager_State_Call) Return(_a0 []runtime.ComponentComponentState) *RuntimeManager_State_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *RuntimeManager_State_Call) RunAndReturn(run func() []runtime.ComponentComponentState) *RuntimeManager_State_Call { + _c.Call.Return(run) + return _c +} + +// SubscribeAll provides a mock function with given fields: _a0 +func (_m *RuntimeManager) SubscribeAll(_a0 context.Context) *runtime.SubscriptionAll { + ret := _m.Called(_a0) + + var r0 *runtime.SubscriptionAll + if rf, ok := ret.Get(0).(func(context.Context) *runtime.SubscriptionAll); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*runtime.SubscriptionAll) + } + } + + return r0 +} + +// RuntimeManager_SubscribeAll_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SubscribeAll' +type RuntimeManager_SubscribeAll_Call struct { + *mock.Call +} + +// SubscribeAll is a helper method to define mock.On call +// - _a0 context.Context +func (_e *RuntimeManager_Expecter) SubscribeAll(_a0 interface{}) *RuntimeManager_SubscribeAll_Call { + return &RuntimeManager_SubscribeAll_Call{Call: _e.mock.On("SubscribeAll", _a0)} +} + +func (_c *RuntimeManager_SubscribeAll_Call) Run(run func(_a0 context.Context)) *RuntimeManager_SubscribeAll_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *RuntimeManager_SubscribeAll_Call) Return(_a0 *runtime.SubscriptionAll) *RuntimeManager_SubscribeAll_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *RuntimeManager_SubscribeAll_Call) RunAndReturn(run func(context.Context) *runtime.SubscriptionAll) *RuntimeManager_SubscribeAll_Call { + _c.Call.Return(run) + return _c +} + +// Update provides a mock function with given fields: _a0 +func (_m *RuntimeManager) Update(_a0 component.Model) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(component.Model) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RuntimeManager_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type RuntimeManager_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - _a0 component.Model +func (_e *RuntimeManager_Expecter) Update(_a0 interface{}) *RuntimeManager_Update_Call { + return &RuntimeManager_Update_Call{Call: _e.mock.On("Update", _a0)} +} + +func (_c *RuntimeManager_Update_Call) Run(run func(_a0 component.Model)) *RuntimeManager_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(component.Model)) + }) + return _c +} + +func (_c *RuntimeManager_Update_Call) Return(_a0 error) *RuntimeManager_Update_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *RuntimeManager_Update_Call) RunAndReturn(run func(component.Model) error) *RuntimeManager_Update_Call { + _c.Call.Return(run) + return _c +} + +type mockConstructorTestingTNewRuntimeManager interface { + mock.TestingT + Cleanup(func()) +} + +// NewRuntimeManager creates a new instance of RuntimeManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewRuntimeManager(t mockConstructorTestingTNewRuntimeManager) *RuntimeManager { + mock := &RuntimeManager{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/pkg/agent/application/managed_mode.go b/internal/pkg/agent/application/managed_mode.go index b90951343b1..1357f3dfeff 100644 --- a/internal/pkg/agent/application/managed_mode.go +++ b/internal/pkg/agent/application/managed_mode.go @@ -341,6 +341,7 @@ func (m *managedConfigManager) initDispatcher(canceller context.CancelFunc) *han &fleetapi.ActionUnenroll{}, handlers.NewUnenroll( m.log, + m.coord, m.ch, []context.CancelFunc{canceller}, m.stateStore, diff --git a/internal/pkg/agent/cmd/install.go b/internal/pkg/agent/cmd/install.go index 3a536f75077..f9420074c0a 100644 --- a/internal/pkg/agent/cmd/install.go +++ b/internal/pkg/agent/cmd/install.go @@ -188,7 +188,7 @@ func installCmd(streams *cli.IOStreams, cmd *cobra.Command) error { defer func() { if err != nil { - _ = install.Uninstall(cfgFile, topPath) + _ = install.Uninstall(cfgFile, topPath, "") } }() @@ -222,7 +222,7 @@ func installCmd(streams *cli.IOStreams, cmd *cobra.Command) error { if err != nil { if status != install.PackageInstall { var exitErr *exec.ExitError - _ = install.Uninstall(cfgFile, topPath) + _ = install.Uninstall(cfgFile, topPath, "") if err != nil && errors.As(err, &exitErr) { return fmt.Errorf("enroll command failed with exit code: %d", exitErr.ExitCode()) } diff --git a/internal/pkg/agent/cmd/uninstall.go b/internal/pkg/agent/cmd/uninstall.go index 4b1b31b273a..9a96e412302 100644 --- a/internal/pkg/agent/cmd/uninstall.go +++ b/internal/pkg/agent/cmd/uninstall.go @@ -34,6 +34,7 @@ Unless -f is used this command will ask confirmation before performing removal. } cmd.Flags().BoolP("force", "f", false, "Force overwrite the current and do not prompt for confirmation") + cmd.Flags().String("uninstall-token", "", "Uninstall token required for protected agent uninstall") return cmd } @@ -55,6 +56,7 @@ func uninstallCmd(streams *cli.IOStreams, cmd *cobra.Command) error { } force, _ := cmd.Flags().GetBool("force") + uninstallToken, _ := cmd.Flags().GetString("uninstall-token") if status == install.Broken { if !force { fmt.Fprintf(streams.Out, "Elastic Agent is installed but currently broken: %s\n", reason) @@ -78,7 +80,7 @@ func uninstallCmd(streams *cli.IOStreams, cmd *cobra.Command) error { } } - err = install.Uninstall(paths.ConfigFile(), paths.Top()) + err = install.Uninstall(paths.ConfigFile(), paths.Top(), uninstallToken) if err != nil { return err } diff --git a/internal/pkg/agent/install/install.go b/internal/pkg/agent/install/install.go index 70a5b780391..c21e9e97ffa 100644 --- a/internal/pkg/agent/install/install.go +++ b/internal/pkg/agent/install/install.go @@ -28,10 +28,17 @@ func Install(cfgFile, topPath string) error { return errors.New(err, "failed to discover the source directory for installation", errors.TypeFilesystem) } - // uninstall current installation - err = Uninstall(cfgFile, topPath) + // Uninstall current installation + // + // There is no uninstall token for "install" command. + // Uninstall will fail on protected agent. + // The protected Agent will need to be uninstalled first before it can be installed. + err = Uninstall(cfgFile, topPath, "") if err != nil { - return err + return errors.New( + err, + fmt.Sprintf("failed to uninstall Agent at (%s)", filepath.Dir(topPath)), + errors.M("directory", filepath.Dir(topPath))) } // ensure parent directory exists, copy source into install path diff --git a/internal/pkg/agent/install/uninstall.go b/internal/pkg/agent/install/uninstall.go index 7b90990adc1..05c7a7ad7f7 100644 --- a/internal/pkg/agent/install/uninstall.go +++ b/internal/pkg/agent/install/uninstall.go @@ -26,16 +26,18 @@ import ( "github.com/elastic/elastic-agent/pkg/component" comprt "github.com/elastic/elastic-agent/pkg/component/runtime" "github.com/elastic/elastic-agent/pkg/core/logger" + "github.com/elastic/elastic-agent/pkg/features" ) // Uninstall uninstalls persistently Elastic Agent on the system. -func Uninstall(cfgFile, topPath string) error { +func Uninstall(cfgFile, topPath, uninstallToken string) error { // uninstall the current service svc, err := newService(topPath) if err != nil { return err } status, _ := svc.Status() + if status == service.StatusRunning { err := svc.Stop() if err != nil { @@ -45,12 +47,25 @@ func Uninstall(cfgFile, topPath string) error { errors.M("service", paths.ServiceName)) } } - _ = svc.Uninstall() - if err := uninstallComponents(context.Background(), cfgFile); err != nil { + // Uninstall components first + if err := uninstallComponents(context.Background(), cfgFile, uninstallToken); err != nil { + // If service status was running it was stopped to uninstall the components. + // If the components uninstall failed start the service again + if status == service.StatusRunning { + if startErr := svc.Start(); startErr != nil { + return errors.New( + err, + fmt.Sprintf("failed to restart service (%s), after failed components uninstall: %v", paths.ServiceName, startErr), + errors.M("service", paths.ServiceName)) + } + } return err } + // Uninstall service only after components were uninstalled successfully + _ = svc.Uninstall() + // remove, if present on platform if paths.ShellWrapperPath != "" { err = os.Remove(paths.ShellWrapperPath) @@ -152,7 +167,7 @@ func delayedRemoval(path string) { } -func uninstallComponents(ctx context.Context, cfgFile string) error { +func uninstallComponents(ctx context.Context, cfgFile string, uninstallToken string) error { log, err := logger.NewWithLogpLevel("", logp.ErrorLevel, false) if err != nil { return err @@ -188,6 +203,11 @@ func uninstallComponents(ctx context.Context, cfgFile string) error { return nil } + // Need to read the features from config on uninstall, in order to set the tamper protection feature flag correctly + if err := features.Apply(cfg); err != nil { + return fmt.Errorf("could not parse and apply feature flags config: %w", err) + } + // check caps so we don't try uninstalling things that were already // prevented from installing caps, err := capabilities.LoadFile(paths.AgentCapabilitiesPath(), log) @@ -201,19 +221,23 @@ func uninstallComponents(ctx context.Context, cfgFile string) error { // This component is not active continue } - if err := uninstallServiceComponent(ctx, log, comp); err != nil { + if err := uninstallServiceComponent(ctx, log, comp, uninstallToken); err != nil { os.Stderr.WriteString(fmt.Sprintf("failed to uninstall component %q: %s\n", comp.ID, err)) + // The decision was made to change the behaviour and leave the Agent installed if Endpoint uninstall fails + // https://github.com/elastic/elastic-agent/pull/2708#issuecomment-1574251911 + // Thus returning error here. + return err } } return nil } -func uninstallServiceComponent(ctx context.Context, log *logp.Logger, comp component.Component) error { +func uninstallServiceComponent(ctx context.Context, log *logp.Logger, comp component.Component, uninstallToken string) error { // Do not use infinite retries when uninstalling from the command line. If the uninstall needs to be // retried the entire uninstall command can be retried. Retries may complete asynchronously with the // execution of the uninstall command, leading to bugs like https://github.com/elastic/elastic-agent/issues/3060. - return comprt.UninstallService(ctx, log, comp) + return comprt.UninstallService(ctx, log, comp, uninstallToken) } func serviceComponentsFromConfig(specs component.RuntimeSpecs, cfg *config.Config) ([]component.Component, error) { diff --git a/internal/pkg/agent/storage/store/action_store.go b/internal/pkg/agent/storage/store/action_store.go index ad0aef0734c..ea0b2eb3c8b 100644 --- a/internal/pkg/agent/storage/store/action_store.go +++ b/internal/pkg/agent/storage/store/action_store.go @@ -145,9 +145,10 @@ var _ actionPolicyChangeSerializer = actionPolicyChangeSerializer(fleetapi.Actio // actionUnenrollSerializer is a struct that adds a YAML serialization, type actionUnenrollSerializer struct { - ActionID string `yaml:"action_id"` - ActionType string `yaml:"action_type"` - IsDetected bool `yaml:"is_detected"` + ActionID string `yaml:"action_id"` + ActionType string `yaml:"action_type"` + IsDetected bool `yaml:"is_detected"` + Signed *fleetapi.Signed `yaml:"signed,omitempty"` } // add a guards between the serializer structs and the original struct. diff --git a/internal/pkg/fleetapi/action.go b/internal/pkg/fleetapi/action.go index 1675f323d57..df70677d716 100644 --- a/internal/pkg/fleetapi/action.go +++ b/internal/pkg/fleetapi/action.go @@ -224,14 +224,15 @@ func (a *ActionPolicyChange) AckEvent() AckEvent { // ActionUpgrade is a request for agent to upgrade. type ActionUpgrade struct { - ActionID string `yaml:"action_id"` - ActionType string `yaml:"type"` - ActionStartTime string `json:"start_time" yaml:"start_time,omitempty"` // TODO change to time.Time in unmarshal - ActionExpiration string `json:"expiration" yaml:"expiration,omitempty"` - Version string `json:"version" yaml:"version,omitempty"` - SourceURI string `json:"source_uri,omitempty" yaml:"source_uri,omitempty"` - Retry int `json:"retry_attempt,omitempty" yaml:"retry_attempt,omitempty"` - Err error + ActionID string `yaml:"action_id" mapstructure:"id"` + ActionType string `yaml:"type" mapstructure:"type"` + ActionStartTime string `json:"start_time" yaml:"start_time,omitempty" mapstructure:"-"` // TODO change to time.Time in unmarshal + ActionExpiration string `json:"expiration" yaml:"expiration,omitempty" mapstructure:"-"` + Version string `json:"version" yaml:"version,omitempty" mapstructure:"-"` + SourceURI string `json:"source_uri,omitempty" yaml:"source_uri,omitempty" mapstructure:"-"` + Retry int `json:"retry_attempt,omitempty" yaml:"retry_attempt,omitempty" mapstructure:"-"` + Signed *Signed `json:"signed,omitempty" yaml:"signed,omitempty" mapstructure:"signed,omitempty"` + Err error `json:"-" yaml:"-" mapstructure:"-"` } func (a *ActionUpgrade) String() string { @@ -322,11 +323,19 @@ func (a *ActionUpgrade) SetStartTime(t time.Time) { a.ActionStartTime = t.Format(time.RFC3339) } +// MarshalMap marshals ActionUpgrade into a corresponding map +func (a *ActionUpgrade) MarshalMap() (map[string]interface{}, error) { + var res map[string]interface{} + err := mapstructure.Decode(a, &res) + return res, err +} + // ActionUnenroll is a request for agent to unhook from fleet. type ActionUnenroll struct { - ActionID string `yaml:"action_id"` - ActionType string `yaml:"type"` - IsDetected bool `json:"is_detected,omitempty" yaml:"is_detected,omitempty"` + ActionID string `yaml:"action_id" mapstructure:"id"` + ActionType string `yaml:"type" mapstructure:"type"` + IsDetected bool `json:"is_detected,omitempty" yaml:"is_detected,omitempty" mapstructure:"-"` + Signed *Signed `json:"signed,omitempty" mapstructure:"signed,omitempty"` } func (a *ActionUnenroll) String() string { @@ -352,6 +361,13 @@ func (a *ActionUnenroll) AckEvent() AckEvent { return newAckEvent(a.ActionID, a.ActionType) } +// MarshalMap marshals ActionUnenroll into a corresponding map +func (a *ActionUnenroll) MarshalMap() (map[string]interface{}, error) { + var res map[string]interface{} + err := mapstructure.Decode(a, &res) + return res, err +} + // ActionSettings is a request to change agent settings. type ActionSettings struct { ActionID string `yaml:"action_id"` @@ -562,6 +578,7 @@ func (a *Actions) UnmarshalJSON(data []byte) error { action = &ActionUnenroll{ ActionID: response.ActionID, ActionType: response.ActionType, + Signed: response.Signed, } case ActionTypeUpgrade: action = &ActionUpgrade{ @@ -569,6 +586,7 @@ func (a *Actions) UnmarshalJSON(data []byte) error { ActionType: response.ActionType, ActionStartTime: response.ActionStartTime, ActionExpiration: response.ActionExpiration, + Signed: response.Signed, } if err := json.Unmarshal(response.Data, action); err != nil { @@ -656,11 +674,13 @@ func (a *Actions) UnmarshalYAML(unmarshal func(interface{}) error) error { InputType: n.InputType, Timeout: n.Timeout, Data: n.Data, + Signed: n.Signed, } case ActionTypeUnenroll: action = &ActionUnenroll{ ActionID: n.ActionID, ActionType: n.ActionType, + Signed: n.Signed, } case ActionTypeUpgrade: action = &ActionUpgrade{ diff --git a/internal/pkg/fleetapi/action_test.go b/internal/pkg/fleetapi/action_test.go index 6a8dae3b31a..ac83f31852b 100644 --- a/internal/pkg/fleetapi/action_test.go +++ b/internal/pkg/fleetapi/action_test.go @@ -153,3 +153,60 @@ func TestActionsUnmarshalJSON(t *testing.T) { assert.Equal(t, 1, action.Retry) }) } + +func TestActionUnenrollMarshalMap(t *testing.T) { + action := ActionUnenroll{ + ActionID: "164a6819-5c58-40f7-a33c-821c98ab0a8c", + ActionType: "UNENROLL", + Signed: &Signed{ + Data: "eyJAdGltZXN0YW1wIjoiMjAy", + Signature: "MEQCIGxsrI742xKL6OSI", + }, + } + + m, err := action.MarshalMap() + if err != nil { + t.Fatal(err) + } + + diff := cmp.Diff(m, map[string]interface{}{ + "id": "164a6819-5c58-40f7-a33c-821c98ab0a8c", + "type": "UNENROLL", + "signed": map[string]interface{}{ + "data": "eyJAdGltZXN0YW1wIjoiMjAy", + "signature": "MEQCIGxsrI742xKL6OSI", + }, + }) + + if diff != "" { + t.Fatal(diff) + } +} + +func TestActionUpgradeMarshalMap(t *testing.T) { + action := ActionUpgrade{ + ActionID: "164a6819-5c58-40f7-a33c-821c98ab0a8c", + ActionType: "UPGRADE", + Signed: &Signed{ + Data: "eyJAdGltZXN0YW1wIjoiMjAy", + Signature: "MEQCIGxsrI742xKL6OSI", + }, + } + m, err := action.MarshalMap() + if err != nil { + t.Fatal(err) + } + + diff := cmp.Diff(m, map[string]interface{}{ + "id": "164a6819-5c58-40f7-a33c-821c98ab0a8c", + "type": "UPGRADE", + "signed": map[string]interface{}{ + "data": "eyJAdGltZXN0YW1wIjoiMjAy", + "signature": "MEQCIGxsrI742xKL6OSI", + }, + }) + + if diff != "" { + t.Fatal(diff) + } +} diff --git a/pkg/component/component.go b/pkg/component/component.go index 22a34bb714c..aaddec8b399 100644 --- a/pkg/component/component.go +++ b/pkg/component/component.go @@ -5,6 +5,7 @@ package component import ( + "errors" "fmt" "sort" "strings" @@ -80,6 +81,63 @@ type Unit struct { Err error `yaml:"error,omitempty"` } +// Signed Strongly typed configuration for the signed data +type Signed struct { + Data string `yaml:"data"` // Signed base64 encoded json bytes + Signature string `yaml:"signature"` // Signature +} + +// IsSigned Checks if the signature exists, safe to call on nil +func (s *Signed) IsSigned() bool { + return (s != nil && (len(s.Signature) > 0)) +} + +// ErrNotFound is returned if the expected "signed" property itself or it's expected properties are missing or not a valid data type +var ErrNotFound = errors.New("not found") + +// SignedFromPolicy Returns Signed instance from the nested map representation of the agent configuration +func SignedFromPolicy(policy map[string]interface{}) (*Signed, error) { + v, ok := policy["signed"] + if !ok { + return nil, fmt.Errorf("policy is not signed: %w", ErrNotFound) + } + + signed, ok := v.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("policy \"signed\" is not map: %w", ErrNotFound) + } + + data, err := getStringValue(signed, "data") + if err != nil { + return nil, err + } + + signature, err := getStringValue(signed, "signature") + if err != nil { + return nil, err + } + + res := &Signed{ + Data: data, + Signature: signature, + } + return res, nil +} + +func getStringValue(m map[string]interface{}, key string) (string, error) { + v, ok := m[key] + if !ok { + return "", fmt.Errorf("missing signed \"%s\": %w", key, ErrNotFound) + } + + s, ok := v.(string) + if !ok { + return "", fmt.Errorf("signed \"%s\" is not string: %w", key, ErrNotFound) + } + + return s, nil +} + // Component is a set of units that needs to run. type Component struct { // ID is the unique ID of the component. @@ -125,6 +183,55 @@ func (c *Component) Type() string { return "" } +// Model is the components model with signed policy data +// This replaces former top level []Components with the top Model that captures signed policy data. +// The signed data is a part of the policy since 8.8.0 release and contains the signed policy fragments and the signature that can be validated. +// The signed data is created and signed by kibana which provides protection from tampering for certain parts of the policy. +// +// The initial idea was that the Agent would validate the signed data if it's present, +// merge the signed data with the policy and dispatch configuration updates to the components. +// The latest Endpoint requirement of not trusting the Agent requires the full signed data with the signature to be passed to Endpoint for validation. +// Endpoint validates the signature and applies the configuration as needed. +// +// The Agent validation of the signature was disabled for 8.8.0 in order to minimize the scope of the change. +// Presently (as of June, 27, 2023) the signature is only validated by Endpoint. +// +// Example of the signed policy property: +// signed: +// +// data: >- +// eyJpZCI6IjBlNjA2OTUwLTE0NTEtMTFlZS04OTI2LTlkZjY4ZjdjMzhlZSIsImFnZW50Ijp7ImZlYXR1cmVzIjp7fSwicHJvdGVjdGlvbiI6eyJlbmFibGVkIjp0cnVlLCJ1bmluc3RhbGxfdG9rZW5faGFzaCI6IjB4MXJ1REo0NVBUYlNuV0V6Yi9xc3VnZHRMNFhKUVRHazU5QitxVEF1OVE9Iiwic2lnbmluZ19rZXkiOiJNRmt3RXdZSEtvWkl6ajBDQVFZSUtvWkl6ajBEQVFjRFFnQUVMRHd4Rk1WTjJvSTFmZW9USGJIWmkrUFJuSjZ5TzVzdUw4MktvRXl1M3FTMDB2OGNGVDNlb2JnZG5oT0MxUG9ka0MwVTFmWjhpN1k1TUlzc2szQ2Rzdz09In19LCJpbnB1dHMiOlt7ImlkIjoiZTgyZmQ1ZDEtOTBkOC00NWJjLWE5MTEtOTU1OTBjNDRjYTc1IiwibmFtZSI6IkVQIiwicmV2aXNpb24iOjEsInR5cGUiOiJlbmRwb2ludCJ9XX0= +// signature: >- +// MEUCIQCpQR8WES3X4gjptjIWtLdqJT0QLRVz5bUnTlG3xt4LfQIgW5ioOoaAUII4G0b74vWGSLSD7sQ6uAdqgZoNF33vSbM= +// +// Example of decoded signed.data from above: +// +// { +// "id": "0e606950-1451-11ee-8926-9df68f7c38ee", +// "agent": { +// "features": {}, +// "protection": { +// "enabled": true, +// "uninstall_token_hash": "0x1ruDJ45PTbSnWEzb/qsugdtL4XJQTGk59B+qTAu9Q=", +// "signing_key": "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAELDwxFMVN2oI1feoTHbHZi+PRnJ6yO5suL82KoEyu3qS00v8cFT3eobgdnhOC1PodkC0U1fZ8i7Y5MIssk3Cdsw==" +// } +// }, +// "inputs": [ +// { +// "id": "e82fd5d1-90d8-45bc-a911-95590c44ca75", +// "name": "EP", +// "revision": 1, +// "type": "endpoint" +// } +// ] +// } +// +// The signed.data JSON has exact same shape/schema as the policy. +type Model struct { + Components []Component `yaml:"components,omitempty"` + Signed *Signed `yaml:"signed,omitempty"` +} + // ToComponents returns the components that should be running based on the policy and // the current runtime specification. func (r *RuntimeSpecs) ToComponents( diff --git a/pkg/component/component_test.go b/pkg/component/component_test.go index 72982c4e3fa..000297dc59f 100644 --- a/pkg/component/component_test.go +++ b/pkg/component/component_test.go @@ -17,23 +17,23 @@ import ( "testing" "time" - "gopkg.in/yaml.v2" - "github.com/elastic/go-ucfg" + "github.com/elastic/elastic-agent-client/v7/pkg/client" + "github.com/elastic/elastic-agent-client/v7/pkg/proto" "github.com/elastic/elastic-agent-libs/logp" + "github.com/elastic/elastic-agent/internal/pkg/agent/transpiler" "github.com/elastic/elastic-agent/internal/pkg/eql" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/known/structpb" + "gopkg.in/yaml.v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/elastic/elastic-agent-client/v7/pkg/client" - "github.com/elastic/elastic-agent-client/v7/pkg/proto" ) func TestToComponents(t *testing.T) { @@ -2152,6 +2152,152 @@ func (h *testHeadersProvider) Headers() map[string]string { return h.headers } +// TestSignedMarshalUnmarshal will catch if the yaml library will get updated to v3 for example +func TestSignedMarshalUnmarshal(t *testing.T) { + const data = "eyJAdGltZXN0YW1wIjoiMjAyMy0wNS0yMlQxNzoxOToyOC40NjNaIiwiZXhwaXJhdGlvbiI6IjIwMjMtMDYtMjFUMTc6MTk6MjguNDYzWiIsImFnZW50cyI6WyI3ZjY0YWI2NC1hNmM0LTQ2ZTMtODIyYS0zODUxZGVkYTJmY2UiXSwiYWN0aW9uX2lkIjoiNGYwODQ2MGYtMDE0Yy00ZDllLWJmOGEtY2FhNjQyNzRhZGU0IiwidHlwZSI6IlVORU5ST0xMIiwidHJhY2VwYXJlbnQiOiIwMC1iOTBkYTlmOGNjNzdhODk0OTc0ZWIxZTIzMGNmNjc2Yy1lOTNlNzk4YTU4ODg2MDVhLTAxIn0=" + const signature = "MEUCIAxxsi9ff1zyV0+4fsJLqbP8Qb83tedU5iIFldtxEzEfAiEA0KUsrL7q+Fv7z6Boux3dY2P4emGi71jsMGanIZ552bM=" + + signed := Signed{ + Data: data, + Signature: signature, + } + + b, err := yaml.Marshal(signed) + if err != nil { + t.Fatal(err) + } + + var newSigned Signed + err = yaml.Unmarshal(b, &newSigned) + if err != nil { + t.Fatal(err) + } + + diff := cmp.Diff(signed, newSigned) + if diff != "" { + t.Fatal(diff) + } + + diff = cmp.Diff(true, signed.IsSigned()) + if diff != "" { + t.Fatal(diff) + } + + var nilSigned *Signed + diff = cmp.Diff(false, nilSigned.IsSigned()) + if diff != "" { + t.Fatal(diff) + } + + unsigned := Signed{} + diff = cmp.Diff(false, unsigned.IsSigned()) + if diff != "" { + t.Fatal(diff) + } +} + +func TestSignedFromPolicy(t *testing.T) { + const data = "eyJAdGltZXN0YW1wIjoiMjAyMy0wNS0yMlQxNzoxOToyOC40NjNaIiwiZXhwaXJhdGlvbiI6IjIwMjMtMDYtMjFUMTc6MTk6MjguNDYzWiIsImFnZW50cyI6WyI3ZjY0YWI2NC1hNmM0LTQ2ZTMtODIyYS0zODUxZGVkYTJmY2UiXSwiYWN0aW9uX2lkIjoiNGYwODQ2MGYtMDE0Yy00ZDllLWJmOGEtY2FhNjQyNzRhZGU0IiwidHlwZSI6IlVORU5ST0xMIiwidHJhY2VwYXJlbnQiOiIwMC1iOTBkYTlmOGNjNzdhODk0OTc0ZWIxZTIzMGNmNjc2Yy1lOTNlNzk4YTU4ODg2MDVhLTAxIn0=" + const signature = "MEUCIAxxsi9ff1zyV0+4fsJLqbP8Qb83tedU5iIFldtxEzEfAiEA0KUsrL7q+Fv7z6Boux3dY2P4emGi71jsMGanIZ552bM=" + + tests := []struct { + name string + policy map[string]interface{} + wantSigned *Signed + wantErr error + }{ + { + name: "not signed", + wantErr: ErrNotFound, + }, + { + name: "signed nil", + policy: map[string]interface{}{ + "signed": nil, + }, + wantErr: ErrNotFound, + }, + { + name: "signed not map", + policy: map[string]interface{}{ + "signed": "", + }, + wantErr: ErrNotFound, + }, + { + name: "signed empty", + policy: map[string]interface{}{ + "signed": map[string]interface{}{}, + }, + wantErr: ErrNotFound, + }, + { + name: "signed missing signature", + policy: map[string]interface{}{ + "signed": map[string]interface{}{ + "data": data, + }, + }, + wantErr: ErrNotFound, + }, + { + name: "signed missing data", + policy: map[string]interface{}{ + "signed": map[string]interface{}{ + "signaure": signature, + }, + }, + wantErr: ErrNotFound, + }, + { + name: "signed data invalid data type", + policy: map[string]interface{}{ + "signed": map[string]interface{}{ + "data": 1, + }, + }, + wantErr: ErrNotFound, + }, + { + name: "signed signature invalid data type", + policy: map[string]interface{}{ + "signed": map[string]interface{}{ + "signature": 1, + }, + }, + wantErr: ErrNotFound, + }, + { + name: "signed correct", + policy: map[string]interface{}{ + "signed": map[string]interface{}{ + "data": data, + "signature": signature, + }, + }, + wantSigned: &Signed{ + Data: data, + Signature: signature, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + signed, err := SignedFromPolicy(tc.policy) + diff := cmp.Diff(tc.wantSigned, signed) + if diff != "" { + t.Fatal(diff) + } + + diff = cmp.Diff(tc.wantErr, err, cmpopts.EquateErrors()) + if diff != "" { + t.Fatal(diff) + } + }) + } +} + func gatherDurationFieldPaths(s interface{}, pathSoFar string) []string { var gatheredPaths []string diff --git a/pkg/component/input_spec.go b/pkg/component/input_spec.go index e4ef6230678..1f19d53a5de 100644 --- a/pkg/component/input_spec.go +++ b/pkg/component/input_spec.go @@ -12,13 +12,14 @@ import ( // InputSpec is the specification for an input type. type InputSpec struct { - Name string `config:"name" yaml:"name" validate:"required"` - Aliases []string `config:"aliases,omitempty" yaml:"aliases,omitempty"` - Description string `config:"description" yaml:"description" validate:"required"` - Platforms []string `config:"platforms" yaml:"platforms" validate:"required,min=1"` - Outputs []string `config:"outputs,omitempty" yaml:"outputs,omitempty"` - Shippers []string `config:"shippers,omitempty" yaml:"shippers,omitempty"` - Runtime RuntimeSpec `config:"runtime,omitempty" yaml:"runtime,omitempty"` + Name string `config:"name" yaml:"name" validate:"required"` + Aliases []string `config:"aliases,omitempty" yaml:"aliases,omitempty"` + Description string `config:"description" yaml:"description" validate:"required"` + Platforms []string `config:"platforms" yaml:"platforms" validate:"required,min=1"` + Outputs []string `config:"outputs,omitempty" yaml:"outputs,omitempty"` + ProxiedActions []string `config:"proxied_actions,omitempty" yaml:"proxied_actions,omitempty"` + Shippers []string `config:"shippers,omitempty" yaml:"shippers,omitempty"` + Runtime RuntimeSpec `config:"runtime,omitempty" yaml:"runtime,omitempty"` Command *CommandSpec `config:"command,omitempty" yaml:"command,omitempty"` Service *ServiceSpec `config:"service,omitempty" yaml:"service,omitempty"` diff --git a/pkg/component/runtime/command.go b/pkg/component/runtime/command.go index ff8532a761c..b68faa685d2 100644 --- a/pkg/component/runtime/command.go +++ b/pkg/component/runtime/command.go @@ -42,6 +42,18 @@ const ( stateUnknownMessage = "Unknown" ) +func (m actionMode) String() string { + switch m { + case actionTeardown: + return "teardown" + case actionStop: + return "stop" + case actionStart: + return "start" + } + return "" +} + type MonitoringManager interface { EnrichArgs(string, string, []string) []string Prepare(string) error @@ -269,7 +281,7 @@ func (c *commandRuntime) Stop() error { // Teardown tears down the component. // // Non-blocking and never returns an error. -func (c *commandRuntime) Teardown() error { +func (c *commandRuntime) Teardown(_ *component.Signed) error { // clear channel so it's the latest action select { case <-c.actionCh: diff --git a/pkg/component/runtime/failed.go b/pkg/component/runtime/failed.go index 5d39c09d862..ae9517347d5 100644 --- a/pkg/component/runtime/failed.go +++ b/pkg/component/runtime/failed.go @@ -75,7 +75,7 @@ func (c *failedRuntime) Stop() error { } // Teardown marks it stopped. -func (c *failedRuntime) Teardown() error { +func (c *failedRuntime) Teardown(_ *component.Signed) error { return c.Stop() } diff --git a/pkg/component/runtime/manager.go b/pkg/component/runtime/manager.go index b0adff4e007..1e7d3e4b163 100644 --- a/pkg/component/runtime/manager.go +++ b/pkg/component/runtime/manager.go @@ -308,7 +308,7 @@ func (m *Manager) Errors() <-chan error { // Called from the main Coordinator goroutine. // // This returns as soon as possible, the work is performed in the background. -func (m *Manager) Update(components []component.Component) error { +func (m *Manager) Update(model component.Model) error { shuttingDown := m.shuttingDown.Load() if shuttingDown { // ignore any updates once shutdown started @@ -316,7 +316,7 @@ func (m *Manager) Update(components []component.Component) error { } // teardown is true because the public `Update` method would be coming directly from // policy so if a component was removed it needs to be torn down. - return m.update(components, true) + return m.update(model, true) } // State returns the current component states. @@ -667,21 +667,21 @@ func (m *Manager) Actions(server proto.ElasticAgent_ActionsServer) error { // update updates the current state of the running components. // // This returns as soon as possible, work is performed in the background. -func (m *Manager) update(components []component.Component, teardown bool) error { +func (m *Manager) update(model component.Model, teardown bool) error { // ensure that only one `update` can occur at the same time m.updateMx.Lock() defer m.updateMx.Unlock() // prepare the components to add consistent shipper connection information between // the connected components in the model - err := m.connectShippers(components) + err := m.connectShippers(model.Components) if err != nil { return err } touched := make(map[string]bool) - newComponents := make([]component.Component, 0, len(components)) - for _, comp := range components { + newComponents := make([]component.Component, 0, len(model.Components)) + for _, comp := range model.Components { touched[comp.ID] = true m.currentMx.RLock() existing, ok := m.current[comp.ID] @@ -712,7 +712,7 @@ func (m *Manager) update(components []component.Component, teardown bool) error var stoppedWg sync.WaitGroup stoppedWg.Add(len(stop)) for _, existing := range stop { - _ = existing.stop(teardown) + _ = existing.stop(teardown, model.Signed) // stop is async, wait for operation to finish, // otherwise new instance may be started and components // may fight for resources (e.g ports, files, locks) @@ -786,7 +786,7 @@ func (m *Manager) waitForStopped(comp *componentRuntimeState) { func (m *Manager) shutdown() { // don't tear down as this is just a shutdown, so components most likely will come back // on next start of the manager - _ = m.update([]component.Component{}, false) + _ = m.update(component.Model{Components: []component.Component{}}, false) // wait until all components are removed for { diff --git a/pkg/component/runtime/manager_test.go b/pkg/component/runtime/manager_test.go index e961734181c..f503be69b61 100644 --- a/pkg/component/runtime/manager_test.go +++ b/pkg/component/runtime/manager_test.go @@ -141,7 +141,7 @@ func TestManager_SimpleComponentErr(t *testing.T) { defer drainErrChan(errCh) defer drainErrChan(subErrCh) - err = m.Update([]component.Component{comp}) + err = m.Update(component.Model{Components: []component.Component{comp}}) require.NoError(t, err) endTimer := time.NewTimer(30 * time.Second) @@ -233,7 +233,7 @@ func TestManager_FakeInput_StartStop(t *testing.T) { subErrCh <- fmt.Errorf("unit failed: %s", unit.Message) } else if unit.State == client.UnitStateHealthy { // remove the component which will stop it - err := m.Update([]component.Component{}) + err := m.Update(component.Model{Components: []component.Component{}}) if err != nil { subErrCh <- err } @@ -256,7 +256,7 @@ func TestManager_FakeInput_StartStop(t *testing.T) { defer drainErrChan(errCh) defer drainErrChan(subErrCh) - err = m.Update([]component.Component{comp}) + err = m.Update(component.Model{Components: []component.Component{comp}}) require.NoError(t, err) endTimer := time.NewTimer(30 * time.Second) @@ -380,7 +380,7 @@ func TestManager_FakeInput_Features(t *testing.T) { Fqdn: &proto.FQDNFeature{Enabled: true}, } - err := m.Update([]component.Component{comp}) + err := m.Update(component.Model{Components: []component.Component{comp}}) if err != nil { subscriptionErrCh <- fmt.Errorf("[case %d]: failed to update component: %w", healthIteration, err) @@ -435,7 +435,7 @@ func TestManager_FakeInput_Features(t *testing.T) { "message": "Fake Healthy", }) - err := m.Update([]component.Component{comp}) + err := m.Update(component.Model{Components: []component.Component{comp}}) if err != nil { t.Logf("error updating component state to health: %v", err) @@ -455,7 +455,7 @@ func TestManager_FakeInput_Features(t *testing.T) { defer drainErrChan(managerErrCh) defer drainErrChan(subscriptionErrCh) - err = m.Update([]component.Component{comp}) + err = m.Update(component.Model{Components: []component.Component{comp}}) require.NoError(t, err) timeout := 30 * time.Second @@ -571,7 +571,7 @@ func TestManager_FakeInput_BadUnitToGood(t *testing.T) { } unitBad = false - err := m.Update([]component.Component{updatedComp}) + err := m.Update(component.Model{Components: []component.Component{updatedComp}}) if err != nil { subErrCh <- err } @@ -599,7 +599,7 @@ func TestManager_FakeInput_BadUnitToGood(t *testing.T) { } } else if unit.State == client.UnitStateHealthy { // bad unit is now healthy; stop the component - err := m.Update([]component.Component{}) + err := m.Update(component.Model{Components: []component.Component{}}) if err != nil { subErrCh <- err } @@ -623,7 +623,7 @@ func TestManager_FakeInput_BadUnitToGood(t *testing.T) { defer drainErrChan(errCh) defer drainErrChan(subErrCh) - err = m.Update([]component.Component{comp}) + err = m.Update(component.Model{Components: []component.Component{comp}}) require.NoError(t, err) endTimer := time.NewTimer(30 * time.Second) @@ -738,7 +738,7 @@ func TestManager_FakeInput_GoodUnitToBad(t *testing.T) { Err: errors.New("hard-error for config"), } unitGood = false - err := m.Update([]component.Component{updatedComp}) + err := m.Update(component.Model{Components: []component.Component{updatedComp}}) if err != nil { subErrCh <- err } @@ -751,7 +751,7 @@ func TestManager_FakeInput_GoodUnitToBad(t *testing.T) { } else { if unit.State == client.UnitStateFailed { // went to failed; stop whole component - err := m.Update([]component.Component{}) + err := m.Update(component.Model{Components: []component.Component{}}) if err != nil { subErrCh <- err } @@ -773,7 +773,7 @@ func TestManager_FakeInput_GoodUnitToBad(t *testing.T) { defer drainErrChan(errCh) defer drainErrChan(subErrCh) - err = m.Update([]component.Component{comp}) + err = m.Update(component.Model{Components: []component.Component{comp}}) require.NoError(t, err) endTimer := time.NewTimer(30 * time.Second) @@ -877,7 +877,7 @@ func TestManager_FakeInput_NoDeadlock(t *testing.T) { } i += 1 comp = updatedComp - err := m.Update([]component.Component{updatedComp}) + err := m.Update(component.Model{Components: []component.Component{updatedComp}}) if err != nil { updatedErr <- err return @@ -918,7 +918,7 @@ LOOP: case <-endTimer.C: // no deadlock after timeout (all good stop the component) updatedCancel() - _ = m.Update([]component.Component{}) + _ = m.Update(component.Model{Components: []component.Component{}}) break LOOP case err := <-errCh: require.NoError(t, err) @@ -1010,7 +1010,7 @@ func TestManager_FakeInput_Configure(t *testing.T) { "state": int(client.UnitStateDegraded), "message": "Fake Degraded", }) - err := m.Update([]component.Component{comp}) + err := m.Update(component.Model{Components: []component.Component{comp}}) if err != nil { subErrCh <- err } @@ -1033,7 +1033,7 @@ func TestManager_FakeInput_Configure(t *testing.T) { defer drainErrChan(errCh) defer drainErrChan(subErrCh) - err = m.Update([]component.Component{comp}) + err = m.Update(component.Model{Components: []component.Component{comp}}) require.NoError(t, err) endTimer := time.NewTimer(30 * time.Second) @@ -1151,7 +1151,7 @@ func TestManager_FakeInput_RemoveUnit(t *testing.T) { } else if unit1.State == client.UnitStateHealthy { // unit1 is healthy lets remove it from the component comp.Units = comp.Units[0:1] - err := m.Update([]component.Component{comp}) + err := m.Update(component.Model{Components: []component.Component{comp}}) if err != nil { subErrCh <- err } @@ -1186,7 +1186,7 @@ func TestManager_FakeInput_RemoveUnit(t *testing.T) { defer drainErrChan(errCh) defer drainErrChan(subErrCh) - err = m.Update([]component.Component{comp}) + err = m.Update(component.Model{Components: []component.Component{comp}}) require.NoError(t, err) endTimer := time.NewTimer(30 * time.Second) @@ -1310,7 +1310,7 @@ func TestManager_FakeInput_ActionState(t *testing.T) { defer drainErrChan(errCh) defer drainErrChan(subErrCh) - err = m.Update([]component.Component{comp}) + err = m.Update(component.Model{Components: []component.Component{comp}}) require.NoError(t, err) endTimer := time.NewTimer(30 * time.Second) @@ -1445,7 +1445,7 @@ func TestManager_FakeInput_Restarts(t *testing.T) { defer drainErrChan(errCh) defer drainErrChan(subErrCh) - err = m.Update([]component.Component{comp}) + err = m.Update(component.Model{Components: []component.Component{comp}}) require.NoError(t, err) endTimer := time.NewTimer(30 * time.Second) @@ -1562,7 +1562,7 @@ func TestManager_FakeInput_Restarts_ConfigKill(t *testing.T) { "message": "Fake Healthy", "kill": rp[1], }) - err := m.Update([]component.Component{comp}) + err := m.Update(component.Model{Components: []component.Component{comp}}) if err != nil { subErrCh <- err } @@ -1587,7 +1587,7 @@ func TestManager_FakeInput_Restarts_ConfigKill(t *testing.T) { defer drainErrChan(errCh) defer drainErrChan(subErrCh) - err = m.Update([]component.Component{comp}) + err = m.Update(component.Model{Components: []component.Component{comp}}) require.NoError(t, err) endTimer := time.NewTimer(1 * time.Minute) @@ -1701,7 +1701,7 @@ func TestManager_FakeInput_KeepsRestarting(t *testing.T) { "message": fmt.Sprintf("Fake Healthy %d", lastStoppedCount), "kill_on_interval": true, }) - err := m.Update([]component.Component{comp}) + err := m.Update(component.Model{Components: []component.Component{comp}}) if err != nil { subErrCh <- err } @@ -1729,7 +1729,7 @@ func TestManager_FakeInput_KeepsRestarting(t *testing.T) { defer drainErrChan(errCh) defer drainErrChan(subErrCh) - err = m.Update([]component.Component{comp}) + err = m.Update(component.Model{Components: []component.Component{comp}}) require.NoError(t, err) endTimer := time.NewTimer(1 * time.Minute) @@ -1844,7 +1844,7 @@ func TestManager_FakeInput_RestartsOnMissedCheckins(t *testing.T) { defer drainErrChan(errCh) defer drainErrChan(subErrCh) - err = m.Update([]component.Component{comp}) + err = m.Update(component.Model{Components: []component.Component{comp}}) require.NoError(t, err) endTimer := time.NewTimer(30 * time.Second) @@ -1961,7 +1961,7 @@ func TestManager_FakeInput_InvalidAction(t *testing.T) { defer drainErrChan(errCh) defer drainErrChan(subErrCh) - err = m.Update([]component.Component{comp}) + err = m.Update(component.Model{Components: []component.Component{comp}}) require.NoError(t, err) endTimer := time.NewTimer(30 * time.Second) @@ -2158,7 +2158,7 @@ func TestManager_FakeInput_MultiComponent(t *testing.T) { defer drainErrChan(subErrCh1) defer drainErrChan(subErrCh2) - err = m.Update(components) + err = m.Update(component.Model{Components: components}) require.NoError(t, err) count := 0 @@ -2314,7 +2314,7 @@ func TestManager_FakeInput_LogLevel(t *testing.T) { defer drainErrChan(errCh) defer drainErrChan(subErrCh) - err = m.Update([]component.Component{comp}) + err = m.Update(component.Model{Components: []component.Component{comp}}) require.NoError(t, err) endTimer := time.NewTimer(30 * time.Second) @@ -2528,7 +2528,7 @@ func TestManager_FakeShipper(t *testing.T) { subErrCh <- err } else { // successful; turn it all off - err := m.Update([]component.Component{}) + err := m.Update(component.Model{Components: []component.Component{}}) if err != nil { subErrCh <- err } @@ -2557,7 +2557,7 @@ func TestManager_FakeShipper(t *testing.T) { subErrCh <- err } else { // successful; turn it all off - err := m.Update([]component.Component{}) + err := m.Update(component.Model{Components: []component.Component{}}) if err != nil { subErrCh <- err } @@ -2592,7 +2592,7 @@ func TestManager_FakeShipper(t *testing.T) { subErrCh <- err } else { // successful; turn it all off - err := m.Update([]component.Component{}) + err := m.Update(component.Model{Components: []component.Component{}}) if err != nil { subErrCh <- err } @@ -2617,7 +2617,7 @@ func TestManager_FakeShipper(t *testing.T) { defer drainErrChan(errCh) defer drainErrChan(subErrCh) - err = m.Update(comps) + err = m.Update(component.Model{Components: comps}) require.NoError(t, err) timeout := 2 * time.Minute @@ -2822,7 +2822,7 @@ func TestManager_FakeInput_OutputChange(t *testing.T) { } time.Sleep(100 * time.Millisecond) - err = m.Update(components) + err = m.Update(component.Model{Components: components}) require.NoError(t, err) updateSleep := 300 * time.Millisecond @@ -2831,7 +2831,7 @@ func TestManager_FakeInput_OutputChange(t *testing.T) { updateSleep = time.Second } time.Sleep(updateSleep) - err = m.Update(components2) + err = m.Update(component.Model{Components: components2}) require.NoError(t, err) count := 0 diff --git a/pkg/component/runtime/runtime.go b/pkg/component/runtime/runtime.go index 2c86a92bf04..a69aeea8d14 100644 --- a/pkg/component/runtime/runtime.go +++ b/pkg/component/runtime/runtime.go @@ -50,7 +50,7 @@ type componentRuntime interface { // // Used to tell control the difference between stopping a component to restart it or upgrade it, versus // the component being completely removed. - Teardown() error + Teardown(signed *component.Signed) error } // newComponentRuntime creates the proper runtime based on the input specification for the component. @@ -180,10 +180,10 @@ func (s *componentRuntimeState) start() error { return s.runtime.Start() } -func (s *componentRuntimeState) stop(teardown bool) error { +func (s *componentRuntimeState) stop(teardown bool, signed *component.Signed) error { s.shuttingDown.Store(true) if teardown { - return s.runtime.Teardown() + return s.runtime.Teardown(signed) } return s.runtime.Stop() } diff --git a/pkg/component/runtime/service.go b/pkg/component/runtime/service.go index 71d71890052..1e172a78ed1 100644 --- a/pkg/component/runtime/service.go +++ b/pkg/component/runtime/service.go @@ -16,8 +16,14 @@ import ( "github.com/elastic/elastic-agent-client/v7/pkg/proto" "github.com/elastic/elastic-agent/pkg/component" "github.com/elastic/elastic-agent/pkg/core/logger" + "github.com/elastic/elastic-agent/pkg/features" ) +type actionModeSigned struct { + actionMode + signed *component.Signed +} + const ( defaultCheckServiceStatusInterval = 30 * time.Second // 30 seconds default for now, consistent with the command check-in interval ) @@ -37,7 +43,7 @@ type serviceRuntime struct { log *logger.Logger ch chan ComponentState - actionCh chan actionMode + actionCh chan actionModeSigned compCh chan component.Component statusCh chan service.Status @@ -64,7 +70,7 @@ func newServiceRuntime(comp component.Component, logger *logger.Logger) (*servic comp: comp, log: logger.Named("service_runtime"), ch: make(chan ComponentState), - actionCh: make(chan actionMode, 1), + actionCh: make(chan actionModeSigned, 1), compCh: make(chan component.Component, 1), statusCh: make(chan service.Status), state: state, @@ -81,7 +87,32 @@ func newServiceRuntime(comp component.Component, logger *logger.Logger) (*servic // Called by Manager inside a goroutine. Run does not return until the passed in context is done. Run is always // called before any of the other methods in the interface and once the context is done none of those methods should // ever be called again. +// +// ================================================================================================== +// +// Updated teardown sequence: +// +// 1. if tearing down already (tearingDown == true), continue with stop/uninstall +// +// 2. if not tearing down already (tearingDown == false) +// a. inject new signed payload for component +// b. reset check-in timer +// c. set teardown timeout timer +// d. set tearingDown=true +// c. send component update (with new signed payload) +// d. await for check-in after update or teardown timeout +// e. upon receiving check-in if (tearingDown == true), send teardown action to itself +// +// ================================================================================================== func (s *serviceRuntime) Run(ctx context.Context, comm Communicator) (err error) { + // The teardownCheckingTimeout is set to the same amount as the checkin timeout for now + teardownCheckinTimeout := s.checkinPeriod() + teardownCheckinTimer := time.NewTimer(teardownCheckinTimeout) + defer teardownCheckinTimer.Stop() + + // Stop teardown checkin timeout timer initially + teardownCheckinTimer.Stop() + checkinTimer := time.NewTimer(s.checkinPeriod()) defer checkinTimer.Stop() @@ -92,6 +123,7 @@ func (s *serviceRuntime) Run(ctx context.Context, comm Communicator) (err error) cis *connInfoServer lastCheckin time.Time missedCheckins int + tearingDown bool ) cisStop := func() { @@ -102,6 +134,55 @@ func (s *serviceRuntime) Run(ctx context.Context, comm Communicator) (err error) } defer cisStop() + onStop := func(am actionMode) { + // Stop check-in timer + s.log.Debugf("stop check-in timer for %s service", s.name()) + checkinTimer.Stop() + + // Stop connection info + s.log.Debugf("stop connection info for %s service", s.name()) + cisStop() + + // Stop service + s.stop(ctx, comm, lastCheckin, am == actionTeardown) + } + + processTeardown := func(am actionMode, signed *component.Signed) error { + s.log.Debugf("start teardown for %s service", s.name()) + // Inject new signed + newComp, err := injectSigned(s.comp, signed) + if err != nil { + s.log.Errorf("failed to inject signed configuration for %s service, err: %v", s.name(), err) + return err + } + + // Set teardown timeout timer + teardownCheckinTimer.Reset(teardownCheckinTimeout) + + // Process newComp update + // This should send component update that should cause service checkin + s.log.Debugf("process new comp config for %s service", s.name()) + s.processNewComp(newComp, comm) + return nil + } + + onTeardown := func(as actionModeSigned) { + tamperProtection := features.TamperProtection() + s.log.Debugf("got teardown for %s service, tearingDown==%v, tamperProtectoin=%v", s.name(), tearingDown, tamperProtection) + + // If tamper protection is disabled do the old behavior + if !tamperProtection { + onStop(as.actionMode) + return + } + + if !tearingDown { + tearingDown = true + err = processTeardown(as.actionMode, as.signed) + } + + } + for { var err error select { @@ -109,7 +190,8 @@ func (s *serviceRuntime) Run(ctx context.Context, comm Communicator) (err error) s.log.Debug("context is done. exiting.") return ctx.Err() case as := <-s.actionCh: - switch as { + s.log.Debugf("got action %v for %s service", as.actionMode, s.name()) + switch as.actionMode { case actionStart: // Initial state on start lastCheckin = time.Time{} @@ -135,17 +217,10 @@ func (s *serviceRuntime) Run(ctx context.Context, comm Communicator) (err error) // Start check-in timer checkinTimer.Reset(s.checkinPeriod()) - case actionStop, actionTeardown: - // Stop check-in timer - s.log.Debugf("stop check-in timer for %s service", s.name()) - checkinTimer.Stop() - - // Stop connection info - s.log.Debugf("stop connection info for %s service", s.name()) - cisStop() - - // Stop service - s.stop(ctx, comm, lastCheckin, as == actionTeardown) + case actionStop: + onStop(as.actionMode) + case actionTeardown: + onTeardown(as) } if err != nil { s.forceCompState(client.UnitStateFailed, err.Error()) @@ -153,14 +228,58 @@ func (s *serviceRuntime) Run(ctx context.Context, comm Communicator) (err error) case newComp := <-s.compCh: s.processNewComp(newComp, comm) case checkin := <-comm.CheckinObserved(): + s.log.Debugf("got check-in for %s service, tearingDown=%v", s.name(), tearingDown) s.processCheckin(checkin, comm, &lastCheckin) + // Got check-in upon teardown update + // tearingDown can be set to true only if tamper protection feature is enabled + if tearingDown { + tearingDown = false + teardownCheckinTimer.Stop() + onStop(actionTeardown) + } case <-checkinTimer.C: s.checkStatus(s.checkinPeriod(), &lastCheckin, &missedCheckins) checkinTimer.Reset(s.checkinPeriod()) + case <-teardownCheckinTimer.C: + s.log.Debugf("got tearing down timeout for %s service", s.name()) + // Teardown timed out + // tearingDown can be set to true only if tamper protection feature is enabled + if tearingDown { + tearingDown = false + onStop(actionTeardown) + } } } } +func injectSigned(comp component.Component, signed *component.Signed) (component.Component, error) { + if signed == nil { + return comp, nil + } + + const signedKey = "signed" + for i, unit := range comp.Units { + if unit.Type == client.UnitTypeInput { + unitCfgMap := unit.Config.Source.AsMap() + + unitCfgMap[signedKey] = map[string]interface{}{ + "data": signed.Data, + "signature": signed.Signature, + } + + unitCfg, err := component.ExpectedConfig(unitCfgMap) + if err != nil { + return comp, err + } + + unit.Config = unitCfg + comp.Units[i] = unit + } + } + + return comp, nil +} + func (s *serviceRuntime) start(ctx context.Context) (err error) { name := s.name() @@ -352,7 +471,7 @@ func (s *serviceRuntime) Start() error { case <-s.actionCh: default: } - s.actionCh <- actionStart + s.actionCh <- actionModeSigned{actionStart, nil} return nil } @@ -378,20 +497,20 @@ func (s *serviceRuntime) Stop() error { case <-s.actionCh: default: } - s.actionCh <- actionStop + s.actionCh <- actionModeSigned{actionStop, nil} return nil } // Teardown stop and uninstall the service. // // Non-blocking and never returns an error. -func (s *serviceRuntime) Teardown() error { +func (s *serviceRuntime) Teardown(signed *component.Signed) error { // clear channel so it's the latest action select { case <-s.actionCh: default: } - s.actionCh <- actionTeardown + s.actionCh <- actionModeSigned{actionTeardown, signed} return nil } @@ -450,19 +569,58 @@ func (s *serviceRuntime) install(ctx context.Context) error { func (s *serviceRuntime) uninstall(ctx context.Context) error { // Always retry for internal attempts to uninstall, because they are an attempt to converge the agent's current state // with its desired state based on the agent policy. - return uninstallService(ctx, s.log, s.comp, s.executeServiceCommandImpl) + return uninstallService(ctx, s.log, s.comp, "", s.executeServiceCommandImpl) } // UninstallService uninstalls the service. When shouldRetry is true the uninstall command will be retried until it succeeds. -func UninstallService(ctx context.Context, log *logger.Logger, comp component.Component) error { - return uninstallService(ctx, log, comp, executeServiceCommand) +func UninstallService(ctx context.Context, log *logger.Logger, comp component.Component, uninstallToken string) error { + return uninstallService(ctx, log, comp, uninstallToken, executeServiceCommand) +} + +//nolint:gosec // was false flagged as hardcoded credentials by linter. it is not. +const uninstallTokenArg = "--uninstall-token" + +// resolveUninstallTokenArg Resolves the uninstall token parameter. +// If the uninstall spec arguments contains the --uninstall-token then +// 1. Remove the argument if the value of uninstallToken is empty +// or +// 2. Inject the value of uninstallToken after the --uninstall-token argument +// +// If args do not contain "--uninstall-token", older endpoint spec, do nothing +func resolveUninstallTokenArg(uninstallSpec *component.ServiceOperationsCommandSpec, uninstallToken string) *component.ServiceOperationsCommandSpec { + if uninstallSpec == nil { + return nil + } + + spec := *uninstallSpec + for i, arg := range spec.Args { + if arg == uninstallTokenArg { + if uninstallToken == "" { // Remove --uninstall-token argument if the token is empty + spec.Args = append(spec.Args[:i], spec.Args[i+1:]...) + } else { // Inject token value after --uninstall-token argument + args := append(spec.Args[:i+1], uninstallToken) + spec.Args = append(args, spec.Args[i+1:]...) + } + break + } + } + return &spec } -func uninstallService(ctx context.Context, log *logger.Logger, comp component.Component, executeServiceCommandImpl executeServiceCommandFunc) error { +func uninstallService(ctx context.Context, log *logger.Logger, comp component.Component, uninstallToken string, executeServiceCommandImpl executeServiceCommandFunc) error { if comp.InputSpec.Spec.Service.Operations.Uninstall == nil { log.Errorf("missing uninstall spec for %s service", comp.InputSpec.BinaryName) return ErrOperationSpecUndefined } + + // If tamper protection feature flag is disabled, force uninstallToken value to empty, + // this will remove the --uninstall-token command arg + if !features.TamperProtection() { + uninstallToken = "" + } + + uninstallSpec := resolveUninstallTokenArg(comp.InputSpec.Spec.Service.Operations.Uninstall, uninstallToken) + log.Debugf("uninstall %s service", comp.InputSpec.BinaryName) - return executeServiceCommandImpl(ctx, log, comp.InputSpec.BinaryPath, comp.InputSpec.Spec.Service.Operations.Uninstall) + return executeServiceCommandImpl(ctx, log, comp.InputSpec.BinaryPath, uninstallSpec) } diff --git a/pkg/component/runtime/service_test.go b/pkg/component/runtime/service_test.go new file mode 100644 index 00000000000..e8ce6e682b9 --- /dev/null +++ b/pkg/component/runtime/service_test.go @@ -0,0 +1,177 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package runtime + +import ( + "testing" + + "github.com/elastic/elastic-agent-client/v7/pkg/client" + "github.com/elastic/elastic-agent-client/v7/pkg/proto" + "github.com/elastic/elastic-agent/pkg/component" + + "github.com/google/go-cmp/cmp" +) + +func makeComponent(name string, config map[string]interface{}) (component.Component, error) { + c := component.Component{ + Units: []component.Unit{ + { + Type: client.UnitTypeInput, + Config: &proto.UnitExpectedConfig{Type: name}, + }, + }, + InputSpec: &component.InputRuntimeSpec{ + Spec: component.InputSpec{ + Name: name, + }, + }, + } + unitCfg, err := component.ExpectedConfig(config) + if err != nil { + return c, err + } + c.Units[0].Config = unitCfg + return c, nil +} + +func makeEndpointComponent(t *testing.T, config map[string]interface{}) component.Component { + comp, err := makeComponent("endpoint", config) + if err != nil { + t.Fatal(err) + } + return comp +} + +func compareCompsConfigs(t *testing.T, comp component.Component, cfg map[string]interface{}) { + for _, unit := range comp.Units { + if unit.Type == client.UnitTypeInput { + unitCfgMap := unit.Config.Source.AsMap() + diff := cmp.Diff(cfg, unitCfgMap) + if diff != "" { + t.Fatal(diff) + } + } + } +} + +func TestInjectSigned(t *testing.T) { + signed := &component.Signed{ + Data: "eyJAdGltZXN0YW1wIjoiMjAyMy0wNS0yMlQxNzoxOToyOC40NjNaIiwiZXhwaXJhdGlvbiI6IjIwMjMtMDYtMjFUMTc6MTk6MjguNDYzWiIsImFnZW50cyI6WyI3ZjY0YWI2NC1hNmM0LTQ2ZTMtODIyYS0zODUxZGVkYTJmY2UiXSwiYWN0aW9uX2lkIjoiNGYwODQ2MGYtMDE0Yy00ZDllLWJmOGEtY2FhNjQyNzRhZGU0IiwidHlwZSI6IlVORU5ST0xMIiwidHJhY2VwYXJlbnQiOiIwMC1iOTBkYTlmOGNjNzdhODk0OTc0ZWIxZTIzMGNmNjc2Yy1lOTNlNzk4YTU4ODg2MDVhLTAxIn0=", + Signature: "MEUCIAxxsi9ff1zyV0+4fsJLqbP8Qb83tedU5iIFldtxEzEfAiEA0KUsrL7q+Fv7z6Boux3dY2P4emGi71jsMGanIZ552bM=", + } + + tests := []struct { + name string + cfg map[string]interface{} + signed *component.Signed + wantCfg map[string]interface{} + }{ + { + name: "nil signed", + cfg: map[string]interface{}{}, + wantCfg: map[string]interface{}{}, + }, + { + name: "signed", + cfg: map[string]interface{}{}, + signed: signed, + wantCfg: map[string]interface{}{ + "signed": map[string]interface{}{ + "data": "eyJAdGltZXN0YW1wIjoiMjAyMy0wNS0yMlQxNzoxOToyOC40NjNaIiwiZXhwaXJhdGlvbiI6IjIwMjMtMDYtMjFUMTc6MTk6MjguNDYzWiIsImFnZW50cyI6WyI3ZjY0YWI2NC1hNmM0LTQ2ZTMtODIyYS0zODUxZGVkYTJmY2UiXSwiYWN0aW9uX2lkIjoiNGYwODQ2MGYtMDE0Yy00ZDllLWJmOGEtY2FhNjQyNzRhZGU0IiwidHlwZSI6IlVORU5ST0xMIiwidHJhY2VwYXJlbnQiOiIwMC1iOTBkYTlmOGNjNzdhODk0OTc0ZWIxZTIzMGNmNjc2Yy1lOTNlNzk4YTU4ODg2MDVhLTAxIn0=", + "signature": "MEUCIAxxsi9ff1zyV0+4fsJLqbP8Qb83tedU5iIFldtxEzEfAiEA0KUsrL7q+Fv7z6Boux3dY2P4emGi71jsMGanIZ552bM=", + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + newComp, err := injectSigned(makeEndpointComponent(t, tc.cfg), tc.signed) + if err != nil { + t.Fatal(err) + } + + compareCompsConfigs(t, newComp, tc.wantCfg) + }) + } + +} + +func TestResolveUninstallTokenArg(t *testing.T) { + tests := []struct { + name string + uninstallSpec *component.ServiceOperationsCommandSpec + uninstallToken string + wantUninstallSpec *component.ServiceOperationsCommandSpec + }{ + { + name: "nil uninstall spec", + }, + { + name: "no uninstall token", + uninstallSpec: &component.ServiceOperationsCommandSpec{ + Args: []string{"uninstall", "--log", "stderr"}, + }, + wantUninstallSpec: &component.ServiceOperationsCommandSpec{ + Args: []string{"uninstall", "--log", "stderr"}, + }, + }, + { + name: "with uninstall token arg and empty token value", + uninstallSpec: &component.ServiceOperationsCommandSpec{ + Args: []string{"uninstall", "--log", "stderr", "--uninstall-token"}, + }, + wantUninstallSpec: &component.ServiceOperationsCommandSpec{ + Args: []string{"uninstall", "--log", "stderr"}, + }, + }, + { + name: "with uninstall token arg and non-empty token value", + uninstallSpec: &component.ServiceOperationsCommandSpec{ + Args: []string{"uninstall", "--log", "stderr", "--uninstall-token"}, + }, + uninstallToken: "EQo1ML2T95pdcH", + wantUninstallSpec: &component.ServiceOperationsCommandSpec{ + Args: []string{"uninstall", "--log", "stderr", "--uninstall-token", "EQo1ML2T95pdcH"}, + }, + }, + { + name: "with uninstall token args cap gt len", + uninstallSpec: &component.ServiceOperationsCommandSpec{ + Args: func() []string { + args := make([]string, 0, 8) + args = append(args, "uninstall", "--log", "stderr", "--uninstall-token") + return args + }(), + }, + uninstallToken: "EQo1ML2T95pdcH", + wantUninstallSpec: &component.ServiceOperationsCommandSpec{ + Args: []string{"uninstall", "--log", "stderr", "--uninstall-token", "EQo1ML2T95pdcH"}, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var originalUninstallSpec component.ServiceOperationsCommandSpec + if tc.uninstallSpec != nil { + originalUninstallSpec = *tc.uninstallSpec + } + spec := resolveUninstallTokenArg(tc.uninstallSpec, tc.uninstallToken) + diff := cmp.Diff(tc.wantUninstallSpec, spec) + if diff != "" { + t.Fatal(diff) + } + + // Test that the original spec was not changed + if tc.uninstallSpec != nil { + diff = cmp.Diff(originalUninstallSpec, *tc.uninstallSpec) + if diff != "" { + t.Fatal(diff) + } + } + }) + } +} diff --git a/pkg/features/features.go b/pkg/features/features.go index 7613b055fd2..89a52c84d5b 100644 --- a/pkg/features/features.go +++ b/pkg/features/features.go @@ -15,8 +15,16 @@ import ( "google.golang.org/protobuf/types/known/structpb" ) +// The default value of tamper protection flag if the flag is missing +// The following was agreed upon for upcoming releases +// 8.10 - default is disabled +// 8.11+ - default is enabled +const defaultTamperProtection = false + var ( - current = Flags{} + current = Flags{ + tamperProtection: defaultTamperProtection, + } ) type BoolValueOnChangeCallback func(new, old bool) @@ -27,6 +35,8 @@ type Flags struct { fqdn bool fqdnCallbacks map[string]BoolValueOnChangeCallback + + tamperProtection bool } type cfg struct { @@ -35,6 +45,9 @@ type cfg struct { FQDN struct { Enabled bool `json:"enabled" yaml:"enabled" config:"enabled"` } `json:"fqdn" yaml:"fqdn" config:"fqdn"` + TamperProtection *struct { + Enabled bool `json:"enabled" yaml:"enabled" config:"enabled"` + } `json:"tamper_protection,omitempty" yaml:"tamper_protection,omitempty" config:"tamper_protection,omitempty"` } `json:"features" yaml:"features" config:"features"` } `json:"agent" yaml:"agent" config:"agent"` } @@ -46,6 +59,13 @@ func (f *Flags) FQDN() bool { return f.fqdn } +func (f *Flags) TamperProtection() bool { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.tamperProtection +} + func (f *Flags) AsProto() *proto.Features { return &proto.Features{ Fqdn: &proto.FQDNFeature{ @@ -93,6 +113,14 @@ func (f *Flags) setFQDN(newValue bool) { } } +// setTamperProtection sets the value of the TamperProtection flag in Flags. +func (f *Flags) setTamperProtection(newValue bool) { + f.mu.Lock() + defer f.mu.Unlock() + + f.tamperProtection = newValue +} + // setSource sets the source from he given cfg. func (f *Flags) setSource(c cfg) error { // Use JSON marshalling-unmarshalling to convert cfg to mapstr @@ -150,6 +178,12 @@ func Parse(policy any) (*Flags, error) { flags := new(Flags) flags.setFQDN(parsedFlags.Agent.Features.FQDN.Enabled) + + // Tamper protection flag is optional, fallback on default value if missing + if parsedFlags.Agent.Features.TamperProtection != nil { + flags.setTamperProtection(parsedFlags.Agent.Features.TamperProtection.Enabled) + } + if err := flags.setSource(parsedFlags); err != nil { return nil, fmt.Errorf("error creating feature flags source: %w", err) } @@ -171,6 +205,7 @@ func Apply(c *config.Config) error { } current.setFQDN(parsed.FQDN()) + current.setTamperProtection(parsed.TamperProtection()) return err } @@ -178,3 +213,8 @@ func Apply(c *config.Config) error { func FQDN() bool { return current.FQDN() } + +// TamperProtection reports if tamper protection feature is enabled +func TamperProtection() bool { + return current.TamperProtection() +} diff --git a/specs/endpoint-security.spec.yml b/specs/endpoint-security.spec.yml index 5c51869953d..341f037de7f 100644 --- a/specs/endpoint-security.spec.yml +++ b/specs/endpoint-security.spec.yml @@ -10,6 +10,9 @@ inputs: outputs: - elasticsearch - logstash + proxied_actions: + - UNENROLL + - UPGRADE runtime: preventions: - condition: ${runtime.arch} == 'arm64' and ${runtime.family} == 'redhat' and ${runtime.major} == '7' @@ -41,6 +44,7 @@ inputs: - "uninstall" - "--log" - "stderr" + - "--uninstall-token" timeout: 600s - name: endpoint description: "Endpoint Security" @@ -50,6 +54,9 @@ inputs: outputs: - elasticsearch - logstash + proxied_actions: + - UNENROLL + - UPGRADE service: cport: 6788 log: @@ -62,6 +69,9 @@ inputs: outputs: - elasticsearch - logstash + proxied_actions: + - UNENROLL + - UPGRADE runtime: preventions: - condition: ${user.root} == false