diff --git a/internal/notification/channel.go b/internal/notification/channel.go index 9db71a241..6d8e4a7ee 100644 --- a/internal/notification/channel.go +++ b/internal/notification/channel.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/openmeterio/openmeter/api" + "github.com/openmeterio/openmeter/internal/notification/webhook" "github.com/openmeterio/openmeter/pkg/convert" "github.com/openmeterio/openmeter/pkg/defaultx" "github.com/openmeterio/openmeter/pkg/models" @@ -143,7 +144,17 @@ type WebHookChannelConfig struct { // Validate returns an error if webhook channel configuration is invalid. func (w WebHookChannelConfig) Validate() error { if w.URL == "" { - return fmt.Errorf("invalid webhook channel configuration: missing URL") + return ValidationError{ + Err: errors.New("missing URL"), + } + } + + if w.SigningSecret != "" { + if err := webhook.ValidateSigningSecret(w.SigningSecret); err != nil { + return ValidationError{ + Err: fmt.Errorf("invalid signing secret: %w", err), + } + } } return nil diff --git a/internal/notification/repository/repository.go b/internal/notification/repository/repository.go index 8f467e954..4320676ff 100644 --- a/internal/notification/repository/repository.go +++ b/internal/notification/repository/repository.go @@ -285,6 +285,11 @@ func (r repository) ListRules(ctx context.Context, params notification.ListRules query = query.Where(ruledb.TypeIn(params.Types...)) } + if len(params.Channels) > 0 { + query = query.Where(ruledb.HasChannelsWith(channeldb.IDIn(params.Channels...))) + } + + // Eager load Channels query = query.WithChannels() order := entutils.GetOrdering(sortx.OrderDefault) diff --git a/internal/notification/rule.go b/internal/notification/rule.go index b5c1268fb..3e20e1fc5 100644 --- a/internal/notification/rule.go +++ b/internal/notification/rule.go @@ -204,19 +204,10 @@ func (b BalanceThresholdRuleConfig) Validate(ctx context.Context, service Servic for _, threshold := range b.Thresholds { switch threshold.Type { - case BalanceThresholdTypeNumber: + case BalanceThresholdTypeNumber, BalanceThresholdTypePercent: if threshold.Value <= 0 { return ValidationError{ - Err: fmt.Errorf("invalid threshold with type %s: value must be greater than 0: %f", - threshold.Type, - threshold.Value, - ), - } - } - case BalanceThresholdTypePercent: - if threshold.Value <= 0 || threshold.Value > 100 { - return ValidationError{ - Err: fmt.Errorf("invalid threshold with type %s: value must be between 0 anad 100: %f", + Err: fmt.Errorf("invalid threshold with type %s: value must be greater than 0: %.2f", threshold.Type, threshold.Value, ), @@ -274,6 +265,7 @@ type ListRulesInput struct { Rules []string IncludeDisabled bool Types []RuleType + Channels []string OrderBy api.ListNotificationRulesParamsOrderBy Order sortx.Order diff --git a/internal/notification/service.go b/internal/notification/service.go index 5b8c1a616..d20323a1d 100644 --- a/internal/notification/service.go +++ b/internal/notification/service.go @@ -191,6 +191,27 @@ func (c service) DeleteChannel(ctx context.Context, params DeleteChannelInput) e return fmt.Errorf("invalid params: %w", err) } + rules, err := c.repo.ListRules(ctx, ListRulesInput{ + Namespaces: []string{params.Namespace}, + IncludeDisabled: true, + Channels: []string{params.ID}, + }) + if err != nil { + return fmt.Errorf("failed to list rules for channel: %w", err) + } + + if rules.TotalCount > 0 { + ruleIDs := make([]string, 0, len(rules.Items)) + + for _, rule := range rules.Items { + ruleIDs = append(ruleIDs, rule.ID) + } + + return ValidationError{ + Err: fmt.Errorf("cannot delete channel as it is assigned to one or more rules: %v", ruleIDs), + } + } + txFunc := func(ctx context.Context, repo TxRepository) error { if err := c.webhook.DeleteWebhook(ctx, webhook.DeleteWebhookInput{ Namespace: params.Namespace, diff --git a/internal/notification/webhook/svix.go b/internal/notification/webhook/svix.go index 8a812c6ee..95f578362 100644 --- a/internal/notification/webhook/svix.go +++ b/internal/notification/webhook/svix.go @@ -268,7 +268,7 @@ func (h svixWebhookHandler) CreateWebhook(ctx context.Context, params CreateWebh // Set custom HTTP headers for webhook endpoint if provided if len(params.CustomHeaders) > 0 { - webhook.CustomHeaders, err = h.GetOrUpdateEndpointHeaders(ctx, app.Id, endpoint.Id, nil) + webhook.CustomHeaders, err = h.GetOrUpdateEndpointHeaders(ctx, app.Id, endpoint.Id, params.CustomHeaders) if err != nil { return nil, err } diff --git a/internal/notification/webhook/webhook.go b/internal/notification/webhook/webhook.go index 09693e62b..ac443f98e 100644 --- a/internal/notification/webhook/webhook.go +++ b/internal/notification/webhook/webhook.go @@ -77,9 +77,8 @@ func (i CreateWebhookInput) Validate() error { } if i.Secret != nil { - secret, _ := strings.CutPrefix(*i.Secret, SigningSecretPrefix) - if _, err := base64.StdEncoding.DecodeString(secret); err != nil { - return errors.New("invalid secret: must be base64 encoded") + if err := ValidateSigningSecret(*i.Secret); err != nil { + return fmt.Errorf("invalid secret: %w", err) } } @@ -262,3 +261,16 @@ func New(config Config) (Handler, error) { return handler, nil } + +func ValidateSigningSecret(secret string) error { + s, _ := strings.CutPrefix(secret, SigningSecretPrefix) + if len(s) < 32 || len(s) > 100 { + return errors.New("secret length must be between 32 to 100 chars without the optional prefix") + } + + if _, err := base64.StdEncoding.DecodeString(s); err != nil { + return errors.New("invalid base64 string") + } + + return nil +}