Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

access_monitoring_rule: Support plugin.spec.name condition variable #51816

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,11 @@ metadata:
name: your-plugin-name
spec:
subjects: ['access_request']
condition: 'access_request.spec.roles.contains("your_role_name")'
condition: >
plugin.spec.name == "slack" &&
access_request.spec.roles.contains("your_role_name")
notification:
# Deprecated: Use condition: 'plugin.spec.name == "slack"' instead.
name: 'slack'
recipients: ['your_slack_channel']
```
Expand All @@ -100,6 +103,7 @@ Fields of the Access Request that are currently supported are
| access_request.spec.request_reason | The request reason. |
| access_request.spec.creation_time | The creation time of the request. |
| access_request.spec.expiry | The expiry time of the request. |
| plugin.spec.name | The name of the plugin that this AMR applies to. |

Predicate expressions used in the condition of Access Monitoring Rules must evaluate to
either true or false.
Expand Down
53 changes: 46 additions & 7 deletions integrations/access/accessmonitoring/access_monitoring_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package accessmonitoring

import (
"context"
"fmt"
"maps"
"slices"
"sync"
Expand Down Expand Up @@ -130,6 +131,10 @@ func (amrh *RuleHandler) HandleAccessMonitoringRule(ctx context.Context, event t
delete(amrh.accessMonitoringRules.rules, event.Resource.GetName())
return nil
}

// Convert the notification.name requirement into a condition expression.
appendPluginNameCondition(req)

amrh.accessMonitoringRules.rules[req.Metadata.Name] = req
if amrh.onCacheUpdateCallback != nil {
amrh.onCacheUpdateCallback(types.OpPut, req.GetMetadata().Name, req)
Expand All @@ -149,9 +154,21 @@ func (amrh *RuleHandler) RecipientsFromAccessMonitoringRules(ctx context.Context
recipientSet := common.NewRecipientSet()

for _, rule := range amrh.getAccessMonitoringRules() {
match, err := MatchAccessRequest(rule.Spec.Condition, req)
env := AccessRequestExpressionEnv{
Roles: req.GetRoles(),
SuggestedReviewers: req.GetSuggestedReviewers(),
Annotations: req.GetSystemAnnotations(),
User: req.GetUser(),
RequestReason: req.GetRequestReason(),
CreationTime: req.GetCreationTime(),
Expiry: req.Expiry(),
Plugin: PluginExpressionEnv{
Name: amrh.pluginName,
},
}
match, err := IsConditionMatched(rule.GetSpec().GetCondition(), env)
if err != nil {
log.WarnContext(ctx, "Failed to parse access monitoring notification rule",
log.WarnContext(ctx, "Failed to parse/evaluate access monitoring rule",
"error", err,
"rule", rule.Metadata.Name,
)
Expand All @@ -176,9 +193,21 @@ func (amrh *RuleHandler) RawRecipientsFromAccessMonitoringRules(ctx context.Cont
log := logger.Get(ctx)
recipientSet := stringset.New()
for _, rule := range amrh.getAccessMonitoringRules() {
match, err := MatchAccessRequest(rule.Spec.Condition, req)
env := AccessRequestExpressionEnv{
Roles: req.GetRoles(),
SuggestedReviewers: req.GetSuggestedReviewers(),
Annotations: req.GetSystemAnnotations(),
User: req.GetUser(),
RequestReason: req.GetRequestReason(),
CreationTime: req.GetCreationTime(),
Expiry: req.Expiry(),
Plugin: PluginExpressionEnv{
Name: amrh.pluginName,
},
}
match, err := IsConditionMatched(rule.Spec.Condition, env)
if err != nil {
log.WarnContext(ctx, "Failed to parse access monitoring notification rule",
log.WarnContext(ctx, "Failed to parse/evaluate access monitoring rule",
"error", err,
"rule", rule.Metadata.Name,
)
Expand Down Expand Up @@ -225,10 +254,20 @@ func (amrh *RuleHandler) getAccessMonitoringRules() map[string]*accessmonitoring
}

func (amrh *RuleHandler) ruleApplies(amr *accessmonitoringrulesv1.AccessMonitoringRule) bool {
if amr.Spec.Notification.Name != amrh.pluginName {
return false
}
return slices.ContainsFunc(amr.Spec.Subjects, func(subject string) bool {
return subject == types.KindAccessRequest
})
}

// appendPluginNameCondition converts the notification.name requirement into a
// condition.
func appendPluginNameCondition(req *accessmonitoringrulesv1.AccessMonitoringRule) {
if req.GetSpec().GetNotification().GetName() == "" {
return
}

// The notification.name is deprecated. Use plugin.spec.name condition instead.
req.Spec.Condition = fmt.Sprintf("plugin.spec.name == %q && %s",
req.GetSpec().GetNotification().GetName(),
req.GetSpec().GetCondition())
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,32 +44,27 @@ func TestHandleAccessMonitoringRule(t *testing.T) {
rule1, err := services.NewAccessMonitoringRuleWithLabels("rule1", nil, &pb.AccessMonitoringRuleSpec{
Subjects: []string{types.KindAccessRequest},
Condition: "true",
Notification: &pb.Notification{
Name: "fakePluginName",
Recipients: []string{"a", "b"},
},
})
require.NoError(t, err)
amrh.HandleAccessMonitoringRule(context.Background(), types.Event{
Type: types.OpPut,
Resource: types.Resource153ToLegacy(rule1),
})
require.Len(t, amrh.getAccessMonitoringRules(), 1)
require.Len(t, amrh.getAccessMonitoringRules(), 1,
"cache AMRs with subject == 'access_request'")
require.Equal(t, `true`, amrh.getAccessMonitoringRules()["rule1"].GetSpec().GetCondition())

rule2, err := services.NewAccessMonitoringRuleWithLabels("rule2", nil, &pb.AccessMonitoringRuleSpec{
Subjects: []string{types.KindAccessRequest},
Subjects: []string{types.KindAccessList},
Condition: "true",
Notification: &pb.Notification{
Name: "aDifferentFakePlugin",
Recipients: []string{"a", "b"},
},
})
require.NoError(t, err)
amrh.HandleAccessMonitoringRule(context.Background(), types.Event{
Type: types.OpPut,
Resource: types.Resource153ToLegacy(rule2),
})
require.Len(t, amrh.getAccessMonitoringRules(), 1)
require.Len(t, amrh.getAccessMonitoringRules(), 1,
"do not cache AMRs with subject != 'access_request'")

amrh.HandleAccessMonitoringRule(context.Background(), types.Event{
Type: types.OpDelete,
Expand All @@ -78,17 +73,18 @@ func TestHandleAccessMonitoringRule(t *testing.T) {
require.Empty(t, amrh.getAccessMonitoringRules())
}

func TestHandleAccessMonitoringRulePluginNameMisMatch(t *testing.T) {
func TestNotificationRule(t *testing.T) {
amrh := NewRuleHandler(RuleHandlerConfig{
PluginType: "fakePluginType",
PluginName: "fakePluginName",
FetchRecipientCallback: mockFetchRecipient,
})

// Support empty notification name.
rule1, err := services.NewAccessMonitoringRuleWithLabels("rule1", nil, &pb.AccessMonitoringRuleSpec{
Subjects: []string{types.KindAccessRequest},
Condition: "true",
Notification: &pb.Notification{
Name: "notTheFakePluginName",
Recipients: []string{"a", "b"},
},
})
Expand All @@ -97,26 +93,35 @@ func TestHandleAccessMonitoringRulePluginNameMisMatch(t *testing.T) {
Type: types.OpPut,
Resource: types.Resource153ToLegacy(rule1),
})
require.Empty(t, amrh.getAccessMonitoringRules())
require.Equal(t, `true`, amrh.getAccessMonitoringRules()["rule1"].GetSpec().GetCondition())

// Support nil notification.
rule2, err := services.NewAccessMonitoringRuleWithLabels("rule2", nil, &pb.AccessMonitoringRuleSpec{
Subjects: []string{types.KindAccessRequest},
Condition: "true",
Notification: &pb.Notification{
Name: "fakePluginName",
Recipients: []string{"c", "d"},
},
})
require.NoError(t, err)
amrh.HandleAccessMonitoringRule(context.Background(), types.Event{
Type: types.OpPut,
Resource: types.Resource153ToLegacy(rule2),
})
require.Len(t, amrh.getAccessMonitoringRules(), 1)
require.Equal(t, `true`, amrh.getAccessMonitoringRules()["rule2"].GetSpec().GetCondition())

// Support notification name.
rule3, err := services.NewAccessMonitoringRuleWithLabels("rule3", nil, &pb.AccessMonitoringRuleSpec{
Subjects: []string{types.KindAccessRequest},
Condition: "true",
Notification: &pb.Notification{
Name: "fakePluginName",
Recipients: []string{"a", "b"},
},
})
require.NoError(t, err)
amrh.HandleAccessMonitoringRule(context.Background(), types.Event{
Type: types.OpDelete,
Resource: types.Resource153ToLegacy(rule2),
Type: types.OpPut,
Resource: types.Resource153ToLegacy(rule3),
})
require.Empty(t, amrh.getAccessMonitoringRules())
require.Equal(t, `plugin.spec.name == "fakePluginName" && true`,
amrh.getAccessMonitoringRules()["rule3"].GetSpec().GetCondition(),
"AMR condition should be modified to include plugin.spec.name validation")
}
65 changes: 34 additions & 31 deletions integrations/access/accessmonitoring/request_mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,31 @@ import (

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/expression"
"github.com/gravitational/teleport/lib/utils/typical"
)

// accessRequestExpressionEnv holds user details that can be mapped in an
// AccessRequestExpressionEnv holds user details that can be mapped in an
// access request condition assertion.
type accessRequestExpressionEnv struct {
type AccessRequestExpressionEnv struct {
Roles []string
SuggestedReviewers []string
Annotations map[string][]string
User string
RequestReason string
CreationTime time.Time
Expiry time.Time

Plugin PluginExpressionEnv
}

// PluginExpressionEnv holds plugin specific condition variables.
type PluginExpressionEnv struct {
// Name specifies the plugin name.
Name string
}

type accessRequestExpression typical.Expression[accessRequestExpressionEnv, any]
type accessRequestExpression typical.Expression[AccessRequestExpressionEnv, any]

func parseAccessRequestExpression(expr string) (accessRequestExpression, error) {
parser, err := newRequestConditionParser()
Expand All @@ -52,62 +59,58 @@ func parseAccessRequestExpression(expr string) (accessRequestExpression, error)
return parsedExpr, nil
}

func newRequestConditionParser() (*typical.Parser[accessRequestExpressionEnv, any], error) {
func newRequestConditionParser() (*typical.Parser[AccessRequestExpressionEnv, any], error) {
typicalEnvVar := map[string]typical.Variable{
"true": true,
"false": false,
"access_request.spec.roles": typical.DynamicVariable[accessRequestExpressionEnv](func(env accessRequestExpressionEnv) (expression.Set, error) {
"access_request.spec.roles": typical.DynamicVariable(func(env AccessRequestExpressionEnv) (expression.Set, error) {
return expression.NewSet(env.Roles...), nil
}),
"access_request.spec.suggested_reviewers": typical.DynamicVariable[accessRequestExpressionEnv](func(env accessRequestExpressionEnv) (expression.Set, error) {
"access_request.spec.suggested_reviewers": typical.DynamicVariable(func(env AccessRequestExpressionEnv) (expression.Set, error) {
return expression.NewSet(env.SuggestedReviewers...), nil
}),
"access_request.spec.system_annotations": typical.DynamicMap[accessRequestExpressionEnv, expression.Set](func(env accessRequestExpressionEnv) (expression.Dict, error) {
"access_request.spec.system_annotations": typical.DynamicMap(func(env AccessRequestExpressionEnv) (expression.Dict, error) {
return expression.DictFromStringSliceMap(env.Annotations), nil
}),
"access_request.spec.user": typical.DynamicVariable[accessRequestExpressionEnv](func(env accessRequestExpressionEnv) (string, error) {
"access_request.spec.user": typical.DynamicVariable(func(env AccessRequestExpressionEnv) (string, error) {
return env.User, nil
}),
"access_request.spec.request_reason": typical.DynamicVariable[accessRequestExpressionEnv](func(env accessRequestExpressionEnv) (string, error) {
"access_request.spec.request_reason": typical.DynamicVariable(func(env AccessRequestExpressionEnv) (string, error) {
return env.RequestReason, nil
}),
"access_request.spec.creation_time": typical.DynamicVariable[accessRequestExpressionEnv](func(env accessRequestExpressionEnv) (time.Time, error) {
"access_request.spec.creation_time": typical.DynamicVariable(func(env AccessRequestExpressionEnv) (time.Time, error) {
return env.CreationTime, nil
}),
"access_request.spec.expiry": typical.DynamicVariable[accessRequestExpressionEnv](func(env accessRequestExpressionEnv) (time.Time, error) {
"access_request.spec.expiry": typical.DynamicVariable(func(env AccessRequestExpressionEnv) (time.Time, error) {
return env.Expiry, nil
}),

// Plugin provided condition variables.
"plugin.spec.name": typical.DynamicVariable(func(env AccessRequestExpressionEnv) (string, error) {
return env.Plugin.Name, nil
}),
}
defParserSpec := expression.DefaultParserSpec[accessRequestExpressionEnv]()
defParserSpec := expression.DefaultParserSpec[AccessRequestExpressionEnv]()
defParserSpec.Variables = typicalEnvVar

requestConditionParser, err := typical.NewParser[accessRequestExpressionEnv, any](defParserSpec)
requestConditionParser, err := typical.NewParser[AccessRequestExpressionEnv, any](defParserSpec)
if err != nil {
return nil, trace.Wrap(err)
}
return requestConditionParser, nil
}

func MatchAccessRequest(expr string, req types.AccessRequest) (bool, error) {
parsedExpr, err := parseAccessRequestExpression(expr)
// IsConditionMatched returns the evaluated condition expression value.
// A true value indicates that the condition is a match for the access request env.
func IsConditionMatched(condition string, env AccessRequestExpressionEnv) (bool, error) {
parsedExpr, err := parseAccessRequestExpression(condition)
if err != nil {
return false, trace.Wrap(err)
}

match, err := parsedExpr.Evaluate(accessRequestExpressionEnv{
Roles: req.GetRoles(),
SuggestedReviewers: req.GetSuggestedReviewers(),
Annotations: req.GetSystemAnnotations(),
User: req.GetUser(),
RequestReason: req.GetRequestReason(),
CreationTime: req.GetCreationTime(),
Expiry: req.Expiry(),
})
match, err := parsedExpr.Evaluate(env)
if err != nil {
return false, trace.Wrap(err, "evaluating access monitoring rule condition expression %q", expr)
}
if matched, ok := match.(bool); ok && matched {
return true, nil
return false, trace.Wrap(err, "evaluating access monitoring rule condition expression %q", condition)
}
return false, nil
matched, ok := match.(bool)
return ok && matched, nil
}
Loading
Loading