From 1f9aced6b077d9386525a20a6a40549eba75670b Mon Sep 17 00:00:00 2001 From: Krisztian Gacsal Date: Tue, 27 Aug 2024 16:20:25 +0200 Subject: [PATCH 1/3] refactor: notification service --- cmd/notification-service/main.go | 4 +- cmd/server/main.go | 3 +- openmeter/notification/service.go | 582 ---------------------- openmeter/notification/service/channel.go | 222 +++++++++ openmeter/notification/service/event.go | 97 ++++ openmeter/notification/service/rule.go | 205 ++++++++ openmeter/notification/service/service.go | 96 ++++ test/notification/testenv.go | 3 +- 8 files changed, 626 insertions(+), 586 deletions(-) create mode 100644 openmeter/notification/service/channel.go create mode 100644 openmeter/notification/service/event.go create mode 100644 openmeter/notification/service/rule.go create mode 100644 openmeter/notification/service/service.go diff --git a/cmd/notification-service/main.go b/cmd/notification-service/main.go index 03470122e..8d05365a4 100644 --- a/cmd/notification-service/main.go +++ b/cmd/notification-service/main.go @@ -30,9 +30,9 @@ import ( "github.com/openmeterio/openmeter/config" "github.com/openmeterio/openmeter/openmeter/meter" - "github.com/openmeterio/openmeter/openmeter/notification" "github.com/openmeterio/openmeter/openmeter/notification/consumer" notificationrepository "github.com/openmeterio/openmeter/openmeter/notification/repository" + notificationservice "github.com/openmeterio/openmeter/openmeter/notification/service" notificationwebhook "github.com/openmeterio/openmeter/openmeter/notification/webhook" registrybuilder "github.com/openmeterio/openmeter/openmeter/registry/builder" "github.com/openmeterio/openmeter/openmeter/streaming/clickhouse_connector" @@ -311,7 +311,7 @@ func main() { os.Exit(1) } - notificationService, err := notification.New(notification.Config{ + notificationService, err := notificationservice.New(notificationservice.Config{ Repository: notificationRepo, Webhook: notificationWebhook, FeatureConnector: entitlementConnRegistry.Feature, diff --git a/cmd/server/main.go b/cmd/server/main.go index 6c16e55de..8120d78e6 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -45,6 +45,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/namespace/namespacedriver" "github.com/openmeterio/openmeter/openmeter/notification" notificationrepository "github.com/openmeterio/openmeter/openmeter/notification/repository" + notificationservice "github.com/openmeterio/openmeter/openmeter/notification/service" notificationwebhook "github.com/openmeterio/openmeter/openmeter/notification/webhook" "github.com/openmeterio/openmeter/openmeter/registry" registrybuilder "github.com/openmeterio/openmeter/openmeter/registry/builder" @@ -402,7 +403,7 @@ func main() { os.Exit(1) } - notificationService, err = notification.New(notification.Config{ + notificationService, err = notificationservice.New(notificationservice.Config{ Repository: notificationRepo, Webhook: notificationWebhook, FeatureConnector: entitlementConnRegistry.Feature, diff --git a/openmeter/notification/service.go b/openmeter/notification/service.go index 32a582638..83c8b938f 100644 --- a/openmeter/notification/service.go +++ b/openmeter/notification/service.go @@ -2,17 +2,8 @@ package notification import ( "context" - "errors" - "fmt" - "log/slog" - "github.com/samber/lo" - - "github.com/openmeterio/openmeter/openmeter/notification/webhook" "github.com/openmeterio/openmeter/openmeter/productcatalog" - "github.com/openmeterio/openmeter/pkg/convert" - "github.com/openmeterio/openmeter/pkg/models" - "github.com/openmeterio/openmeter/pkg/pagination" ) type Service interface { @@ -53,576 +44,3 @@ type EventService interface { type FeatureService interface { ListFeature(ctx context.Context, namespace string, features ...string) ([]productcatalog.Feature, error) } - -var _ Service = (*service)(nil) - -const ( - ChannelIDMetadataKey = "om-channel-id" -) - -type service struct { - feature productcatalog.FeatureConnector - - repo Repository - webhook webhook.Handler - - eventHandler EventHandler - - logger *slog.Logger -} - -func (c service) Close() error { - return c.eventHandler.Close() -} - -type Config struct { - FeatureConnector productcatalog.FeatureConnector - - Repository Repository - Webhook webhook.Handler - - Logger *slog.Logger -} - -func New(config Config) (Service, error) { - if config.Repository == nil { - return nil, errors.New("missing repository") - } - - if config.FeatureConnector == nil { - return nil, errors.New("missing feature connector") - } - - if config.Webhook == nil { - return nil, errors.New("missing webhook handler") - } - - if config.Logger == nil { - return nil, errors.New("missing logger") - } - - eventHandler, err := NewEventHandler(EventHandlerConfig{ - Repository: config.Repository, - Webhook: config.Webhook, - Logger: config.Logger, - }) - if err != nil { - return nil, fmt.Errorf("failed to initialize notification event handler: %w", err) - } - - if err = eventHandler.Start(); err != nil { - return nil, fmt.Errorf("failed to initialize notification event handler: %w", err) - } - - return &service{ - repo: config.Repository, - feature: config.FeatureConnector, - webhook: config.Webhook, - eventHandler: eventHandler, - logger: config.Logger, - }, nil -} - -func (c service) ListFeature(ctx context.Context, namespace string, features ...string) ([]productcatalog.Feature, error) { - resp, err := c.feature.ListFeatures(ctx, productcatalog.ListFeaturesParams{ - IDsOrKeys: features, - Namespace: namespace, - MeterSlugs: nil, - IncludeArchived: false, - }) - if err != nil { - return nil, fmt.Errorf("failed to get features: %w", err) - } - - return resp.Items, nil -} - -func (c service) ListChannels(ctx context.Context, params ListChannelsInput) (ListChannelsResult, error) { - if err := params.Validate(ctx, c); err != nil { - return ListChannelsResult{}, fmt.Errorf("invalid params: %w", err) - } - - return c.repo.ListChannels(ctx, params) -} - -func (c service) CreateChannel(ctx context.Context, params CreateChannelInput) (*Channel, error) { - if err := params.Validate(ctx, c); err != nil { - return nil, fmt.Errorf("invalid params: %w", err) - } - - logger := c.logger.WithGroup("channel").With( - "operation", "create", - "namespace", params.Namespace, - ) - - logger.Debug("creating channel", "type", params.Type) - - txFunc := func(ctx context.Context, repo TxRepository) (*Channel, error) { - channel, err := repo.CreateChannel(ctx, params) - if err != nil { - return nil, fmt.Errorf("failed to create channel: %w", err) - } - - logger = logger.With("id", channel.ID) - - logger.Debug("channel stored in repository") - - switch params.Type { - case ChannelTypeWebhook: - var headers map[string]string - headers, err = StrictInterfaceMapToStringMap(channel.Config.WebHook.CustomHeaders) - if err != nil { - return nil, fmt.Errorf("failed to cast custom headers: %w", err) - } - - var wb *webhook.Webhook - wb, err = c.webhook.CreateWebhook(ctx, webhook.CreateWebhookInput{ - Namespace: params.Namespace, - ID: &channel.ID, - URL: channel.Config.WebHook.URL, - CustomHeaders: headers, - Disabled: channel.Disabled, - Secret: &channel.Config.WebHook.SigningSecret, - Metadata: map[string]string{ - ChannelIDMetadataKey: channel.ID, - }, - Description: convert.ToPointer("Notification Channel: " + channel.ID), - }) - if err != nil { - return nil, fmt.Errorf("failed to create webhook for channel: %w", err) - } - - logger.Debug("webhook is created") - - updateIn := UpdateChannelInput{ - NamespacedModel: channel.NamespacedModel, - Type: channel.Type, - Name: channel.Name, - Disabled: channel.Disabled, - Config: channel.Config, - ID: channel.ID, - } - updateIn.Config.WebHook.SigningSecret = wb.Secret - - channel, err = repo.UpdateChannel(ctx, updateIn) - if err != nil { - return nil, fmt.Errorf("failed to update channel: %w", err) - } - logger.Debug("channel is updated in database with webhook configuration") - default: - return nil, fmt.Errorf("invalid channel type: %s", channel.Type) - } - - return channel, nil - } - - return WithTx[*Channel](ctx, c.repo, txFunc) -} - -func (c service) DeleteChannel(ctx context.Context, params DeleteChannelInput) error { - if err := params.Validate(ctx, c); err != nil { - return fmt.Errorf("invalid params: %w", err) - } - - logger := c.logger.WithGroup("channel").With( - "operation", "delete", - "id", params.ID, - "namespace", params.Namespace, - ) - - logger.Debug("deleting channel") - - rules, err := c.repo.ListRules(ctx, ListRulesInput{ - Namespaces: []string{params.Namespace}, - IncludeDisabled: true, - Channels: []string{params.ID}, - }) - if err != nil { - return fmt.Errorf("failed to list rules for channel: %w", err) - } - - if rules.TotalCount > 0 { - ruleIDs := make([]string, 0, len(rules.Items)) - - for _, rule := range rules.Items { - ruleIDs = append(ruleIDs, rule.ID) - } - - return ValidationError{ - Err: fmt.Errorf("cannot delete channel as it is assigned to one or more rules: %v", ruleIDs), - } - } - - txFunc := func(ctx context.Context, repo TxRepository) error { - if err := c.webhook.DeleteWebhook(ctx, webhook.DeleteWebhookInput{ - Namespace: params.Namespace, - ID: params.ID, - }); err != nil { - return fmt.Errorf("failed to delete webhook: %w", err) - } - - logger.Debug("webhook associated with channel deleted") - - return repo.DeleteChannel(ctx, params) - } - - return WithTxNoValue(ctx, c.repo, txFunc) -} - -func (c service) GetChannel(ctx context.Context, params GetChannelInput) (*Channel, error) { - if err := params.Validate(ctx, c); err != nil { - return nil, fmt.Errorf("invalid params: %w", err) - } - - return c.repo.GetChannel(ctx, params) -} - -func (c service) UpdateChannel(ctx context.Context, params UpdateChannelInput) (*Channel, error) { - if err := params.Validate(ctx, c); err != nil { - return nil, fmt.Errorf("invalid params: %w", err) - } - - logger := c.logger.WithGroup("channel").With( - "operation", "update", - "id", params.ID, - "namespace", params.Namespace, - ) - - logger.Debug("updating channel") - - channel, err := c.repo.GetChannel(ctx, GetChannelInput{ - ID: params.ID, - Namespace: params.Namespace, - }) - if err != nil { - return nil, fmt.Errorf("failed to get channel: %w", err) - } - - if channel.DeletedAt != nil { - return nil, UpdateAfterDeleteError{ - Err: errors.New("not allowed to update deleted channel"), - } - } - - txFunc := func(ctx context.Context, repo TxRepository) (*Channel, error) { - channel, err = repo.UpdateChannel(ctx, params) - if err != nil { - return nil, fmt.Errorf("failed to create channel: %w", err) - } - - logger.Debug("channel updated in repository") - - switch params.Type { - case ChannelTypeWebhook: - var headers map[string]string - headers, err = StrictInterfaceMapToStringMap(channel.Config.WebHook.CustomHeaders) - if err != nil { - return nil, fmt.Errorf("failed to cast custom headers: %w", err) - } - - _, err = c.webhook.UpdateWebhook(ctx, webhook.UpdateWebhookInput{ - Namespace: params.Namespace, - ID: channel.ID, - URL: channel.Config.WebHook.URL, - CustomHeaders: headers, - Disabled: channel.Disabled, - Secret: &channel.Config.WebHook.SigningSecret, - Metadata: map[string]string{ - ChannelIDMetadataKey: channel.ID, - }, - Description: convert.ToPointer("Notification Channel: " + channel.ID), - }) - if err != nil { - return nil, fmt.Errorf("failed to update webhook for channel: %w", err) - } - - logger.Debug("webhook is updated") - - default: - return nil, fmt.Errorf("invalid channel type: %s", channel.Type) - } - - return channel, nil - } - - return WithTx[*Channel](ctx, c.repo, txFunc) -} - -func (c service) ListRules(ctx context.Context, params ListRulesInput) (ListRulesResult, error) { - if err := params.Validate(ctx, c); err != nil { - return ListRulesResult{}, fmt.Errorf("invalid params: %w", err) - } - - return c.repo.ListRules(ctx, params) -} - -func (c service) CreateRule(ctx context.Context, params CreateRuleInput) (*Rule, error) { - if err := params.Validate(ctx, c); err != nil { - return nil, fmt.Errorf("invalid params: %w", err) - } - - logger := c.logger.WithGroup("rule").With( - "operation", "create", - "namespace", params.Namespace, - ) - - logger.Debug("creating rule", "type", params.Type) - - txFunc := func(ctx context.Context, repo TxRepository) (*Rule, error) { - rule, err := repo.CreateRule(ctx, params) - if err != nil { - return nil, fmt.Errorf("failed to create rule: %w", err) - } - - for _, channel := range rule.Channels { - switch channel.Type { - case ChannelTypeWebhook: - _, err = c.webhook.UpdateWebhookChannels(ctx, webhook.UpdateWebhookChannelsInput{ - Namespace: params.Namespace, - ID: channel.ID, - AddChannels: []string{ - rule.ID, - }, - }) - if err != nil { - return nil, fmt.Errorf("failed to update webhook for channel: %w", err) - } - default: - return nil, fmt.Errorf("invalid channel type: %s", channel.Type) - } - } - - return rule, nil - } - - return WithTx[*Rule](ctx, c.repo, txFunc) -} - -func (c service) DeleteRule(ctx context.Context, params DeleteRuleInput) error { - if err := params.Validate(ctx, c); err != nil { - return fmt.Errorf("invalid params: %w", err) - } - - txFunc := func(ctx context.Context, repo TxRepository) error { - rule, err := c.repo.GetRule(ctx, GetRuleInput(params)) - if err != nil { - return fmt.Errorf("failed to get rule: %w", err) - } - - for _, channel := range rule.Channels { - switch channel.Type { - case ChannelTypeWebhook: - _, err = c.webhook.UpdateWebhookChannels(ctx, webhook.UpdateWebhookChannelsInput{ - Namespace: params.Namespace, - ID: channel.ID, - RemoveChannels: []string{ - rule.ID, - }, - }) - if err != nil { - return fmt.Errorf("failed to update webhook for channel: %w", err) - } - default: - return fmt.Errorf("invalid channel type: %s", channel.Type) - } - } - - return c.repo.DeleteRule(ctx, params) - } - - return WithTxNoValue(ctx, c.repo, txFunc) -} - -func (c service) GetRule(ctx context.Context, params GetRuleInput) (*Rule, error) { - if err := params.Validate(ctx, c); err != nil { - return nil, fmt.Errorf("invalid params: %w", err) - } - - return c.repo.GetRule(ctx, params) -} - -func (c service) UpdateRule(ctx context.Context, params UpdateRuleInput) (*Rule, error) { - if err := params.Validate(ctx, c); err != nil { - return nil, fmt.Errorf("invalid params: %w", err) - } - - logger := c.logger.WithGroup("rule").With( - "operation", "update", - "id", params.ID, - "namespace", params.Namespace, - ) - - logger.Debug("updating rule") - - rule, err := c.repo.GetRule(ctx, GetRuleInput{ - ID: params.ID, - Namespace: params.Namespace, - }) - if err != nil { - return nil, fmt.Errorf("failed to get rule: %w", err) - } - - if rule.DeletedAt != nil { - return nil, UpdateAfterDeleteError{ - Err: errors.New("not allowed to update deleted rule"), - } - } - - // Get list of channel IDs currently assigned to rule - oldChannelIDs := lo.Map(rule.Channels, func(channel Channel, _ int) string { - return channel.ID - }) - logger.Debug("currently assigned channels", "channels", oldChannelIDs) - - // Calculate channels diff for the update - channelIDsDiff := NewChannelIDsDifference(params.Channels, oldChannelIDs) - - logger.WithGroup("channels").Debug("difference in channels assignment", - "changed", channelIDsDiff.HasChanged(), - "additions", channelIDsDiff.Additions(), - "removals", channelIDsDiff.Removals(), - ) - - // We can return early ff there is no change in the list of channels assigned to rule. - if !channelIDsDiff.HasChanged() { - return c.repo.UpdateRule(ctx, params) - } - - txFunc := func(ctx context.Context, repo TxRepository) (*Rule, error) { - // Fetch all the channels from repo which are either added or removed from rule - channels, err := repo.ListChannels(ctx, ListChannelsInput{ - Page: pagination.Page{ - // In order to avoid under-fetching. There cannot be more affected channels than - // twice as the maximum number of allowed channels per rule. - PageSize: 2 * MaxChannelsPerRule, - PageNumber: 1, - }, - Namespaces: []string{params.Namespace}, - Channels: channelIDsDiff.All(), - IncludeDisabled: true, - }) - if err != nil { - return nil, fmt.Errorf("failed to list channels for rule: %w", err) - } - logger.Debug("fetched all affected channels", "channels", channels.Items) - - // Update affected channels - for _, channel := range channels.Items { - switch channel.Type { - case ChannelTypeWebhook: - input := webhook.UpdateWebhookChannelsInput{ - Namespace: params.Namespace, - ID: channel.ID, - } - - if channelIDsDiff.InAdditions(channel.ID) { - input.AddChannels = []string{rule.ID} - } - - if channelIDsDiff.InRemovals(channel.ID) { - input.RemoveChannels = []string{rule.ID} - } - - logger.Debug("updating webhook for channel", "id", channel.ID, "input", input) - - _, err = c.webhook.UpdateWebhookChannels(ctx, input) - if err != nil { - return nil, fmt.Errorf("failed to update webhook for channel: %w", err) - } - default: - return nil, fmt.Errorf("invalid channel type: %s", channel.Type) - } - } - - return c.repo.UpdateRule(ctx, params) - } - - return WithTx[*Rule](ctx, c.repo, txFunc) -} - -func (c service) ListEvents(ctx context.Context, params ListEventsInput) (ListEventsResult, error) { - if err := params.Validate(ctx, c); err != nil { - return ListEventsResult{}, fmt.Errorf("invalid params: %w", err) - } - - return c.repo.ListEvents(ctx, params) -} - -func (c service) GetEvent(ctx context.Context, params GetEventInput) (*Event, error) { - if err := params.Validate(ctx, c); err != nil { - return nil, fmt.Errorf("invalid params: %w", err) - } - - return c.repo.GetEvent(ctx, params) -} - -func (c service) CreateEvent(ctx context.Context, params CreateEventInput) (*Event, error) { - if err := params.Validate(ctx, c); err != nil { - return nil, fmt.Errorf("invalid params: %w", err) - } - - logger := c.logger.WithGroup("event").With( - "operation", "create", - "namespace", params.Namespace, - ) - - logger.Debug("creating event") - - rule, err := c.repo.GetRule(ctx, GetRuleInput{ - Namespace: params.Namespace, - ID: params.RuleID, - }) - if err != nil { - return nil, fmt.Errorf("failed to get rule: %w", err) - } - - if rule.DeletedAt != nil { - return nil, NotFoundError{ - NamespacedID: models.NamespacedID{ - Namespace: params.Namespace, - ID: params.RuleID, - }, - } - } - - if rule.Disabled { - return nil, ValidationError{ - Err: errors.New("failed to send event: rule is disabled"), - } - } - - event, err := c.repo.CreateEvent(ctx, params) - if err != nil { - return nil, fmt.Errorf("failed to create event: %w", err) - } - - if err = c.eventHandler.Dispatch(event); err != nil { - return nil, fmt.Errorf("failed to dispatch event: %w", err) - } - - return event, nil -} - -func (c service) UpdateEventDeliveryStatus(ctx context.Context, params UpdateEventDeliveryStatusInput) (*EventDeliveryStatus, error) { - if err := params.Validate(ctx, c); err != nil { - return nil, fmt.Errorf("invalid params: %w", err) - } - - return c.repo.UpdateEventDeliveryStatus(ctx, params) -} - -func (c service) ListEventsDeliveryStatus(ctx context.Context, params ListEventsDeliveryStatusInput) (ListEventsDeliveryStatusResult, error) { - if err := params.Validate(ctx, c); err != nil { - return ListEventsDeliveryStatusResult{}, fmt.Errorf("invalid params: %w", err) - } - - return c.repo.ListEventsDeliveryStatus(ctx, params) -} - -func (c service) GetEventDeliveryStatus(ctx context.Context, params GetEventDeliveryStatusInput) (*EventDeliveryStatus, error) { - if err := params.Validate(ctx, c); err != nil { - return nil, fmt.Errorf("invalid params: %w", err) - } - - return c.repo.GetEventDeliveryStatus(ctx, params) -} diff --git a/openmeter/notification/service/channel.go b/openmeter/notification/service/channel.go new file mode 100644 index 000000000..94684264b --- /dev/null +++ b/openmeter/notification/service/channel.go @@ -0,0 +1,222 @@ +package service + +import ( + "context" + "errors" + "fmt" + + "github.com/openmeterio/openmeter/openmeter/notification" + "github.com/openmeterio/openmeter/openmeter/notification/webhook" + "github.com/openmeterio/openmeter/pkg/convert" +) + +func (s Service) ListChannels(ctx context.Context, params notification.ListChannelsInput) (notification.ListChannelsResult, error) { + if err := params.Validate(ctx, s); err != nil { + return notification.ListChannelsResult{}, fmt.Errorf("invalid params: %w", err) + } + + return s.repo.ListChannels(ctx, params) +} + +func (s Service) CreateChannel(ctx context.Context, params notification.CreateChannelInput) (*notification.Channel, error) { + if err := params.Validate(ctx, s); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + logger := s.logger.WithGroup("channel").With( + "operation", "create", + "namespace", params.Namespace, + ) + + logger.Debug("creating channel", "type", params.Type) + + txFunc := func(ctx context.Context, repo notification.TxRepository) (*notification.Channel, error) { + channel, err := repo.CreateChannel(ctx, params) + if err != nil { + return nil, fmt.Errorf("failed to create channel: %w", err) + } + + logger = logger.With("id", channel.ID) + + logger.Debug("channel stored in repository") + + switch params.Type { + case notification.ChannelTypeWebhook: + var headers map[string]string + headers, err = notification.StrictInterfaceMapToStringMap(channel.Config.WebHook.CustomHeaders) + if err != nil { + return nil, fmt.Errorf("failed to cast custom headers: %w", err) + } + + var wb *webhook.Webhook + wb, err = s.webhook.CreateWebhook(ctx, webhook.CreateWebhookInput{ + Namespace: params.Namespace, + ID: &channel.ID, + URL: channel.Config.WebHook.URL, + CustomHeaders: headers, + Disabled: channel.Disabled, + Secret: &channel.Config.WebHook.SigningSecret, + Metadata: map[string]string{ + ChannelIDMetadataKey: channel.ID, + }, + Description: convert.ToPointer("Notification Channel: " + channel.ID), + }) + if err != nil { + return nil, fmt.Errorf("failed to create webhook for channel: %w", err) + } + + logger.Debug("webhook is created") + + updateIn := notification.UpdateChannelInput{ + NamespacedModel: channel.NamespacedModel, + Type: channel.Type, + Name: channel.Name, + Disabled: channel.Disabled, + Config: channel.Config, + ID: channel.ID, + } + updateIn.Config.WebHook.SigningSecret = wb.Secret + + channel, err = repo.UpdateChannel(ctx, updateIn) + if err != nil { + return nil, fmt.Errorf("failed to update channel: %w", err) + } + logger.Debug("channel is updated in database with webhook configuration") + default: + return nil, fmt.Errorf("invalid channel type: %s", channel.Type) + } + + return channel, nil + } + + return notification.WithTx[*notification.Channel](ctx, s.repo, txFunc) +} + +func (s Service) DeleteChannel(ctx context.Context, params notification.DeleteChannelInput) error { + if err := params.Validate(ctx, s); err != nil { + return fmt.Errorf("invalid params: %w", err) + } + + logger := s.logger.WithGroup("channel").With( + "operation", "delete", + "id", params.ID, + "namespace", params.Namespace, + ) + + logger.Debug("deleting channel") + + rules, err := s.repo.ListRules(ctx, notification.ListRulesInput{ + Namespaces: []string{params.Namespace}, + IncludeDisabled: true, + Channels: []string{params.ID}, + }) + if err != nil { + return fmt.Errorf("failed to list rules for channel: %w", err) + } + + if rules.TotalCount > 0 { + ruleIDs := make([]string, 0, len(rules.Items)) + + for _, rule := range rules.Items { + ruleIDs = append(ruleIDs, rule.ID) + } + + return notification.ValidationError{ + Err: fmt.Errorf("cannot delete channel as it is assigned to one or more rules: %v", ruleIDs), + } + } + + txFunc := func(ctx context.Context, repo notification.TxRepository) error { + if err := s.webhook.DeleteWebhook(ctx, webhook.DeleteWebhookInput{ + Namespace: params.Namespace, + ID: params.ID, + }); err != nil { + return fmt.Errorf("failed to delete webhook: %w", err) + } + + logger.Debug("webhook associated with channel deleted") + + return repo.DeleteChannel(ctx, params) + } + + return notification.WithTxNoValue(ctx, s.repo, txFunc) +} + +func (s Service) GetChannel(ctx context.Context, params notification.GetChannelInput) (*notification.Channel, error) { + if err := params.Validate(ctx, s); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + return s.repo.GetChannel(ctx, params) +} + +func (s Service) UpdateChannel(ctx context.Context, params notification.UpdateChannelInput) (*notification.Channel, error) { + if err := params.Validate(ctx, s); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + logger := s.logger.WithGroup("channel").With( + "operation", "update", + "id", params.ID, + "namespace", params.Namespace, + ) + + logger.Debug("updating channel") + + channel, err := s.repo.GetChannel(ctx, notification.GetChannelInput{ + ID: params.ID, + Namespace: params.Namespace, + }) + if err != nil { + return nil, fmt.Errorf("failed to get channel: %w", err) + } + + if channel.DeletedAt != nil { + return nil, notification.UpdateAfterDeleteError{ + Err: errors.New("not allowed to update deleted channel"), + } + } + + txFunc := func(ctx context.Context, repo notification.TxRepository) (*notification.Channel, error) { + channel, err = repo.UpdateChannel(ctx, params) + if err != nil { + return nil, fmt.Errorf("failed to create channel: %w", err) + } + + logger.Debug("channel updated in repository") + + switch params.Type { + case notification.ChannelTypeWebhook: + var headers map[string]string + headers, err = notification.StrictInterfaceMapToStringMap(channel.Config.WebHook.CustomHeaders) + if err != nil { + return nil, fmt.Errorf("failed to cast custom headers: %w", err) + } + + _, err = s.webhook.UpdateWebhook(ctx, webhook.UpdateWebhookInput{ + Namespace: params.Namespace, + ID: channel.ID, + URL: channel.Config.WebHook.URL, + CustomHeaders: headers, + Disabled: channel.Disabled, + Secret: &channel.Config.WebHook.SigningSecret, + Metadata: map[string]string{ + ChannelIDMetadataKey: channel.ID, + }, + Description: convert.ToPointer("Notification Channel: " + channel.ID), + }) + if err != nil { + return nil, fmt.Errorf("failed to update webhook for channel: %w", err) + } + + logger.Debug("webhook is updated") + + default: + return nil, fmt.Errorf("invalid channel type: %s", channel.Type) + } + + return channel, nil + } + + return notification.WithTx[*notification.Channel](ctx, s.repo, txFunc) +} diff --git a/openmeter/notification/service/event.go b/openmeter/notification/service/event.go new file mode 100644 index 000000000..bf0098b14 --- /dev/null +++ b/openmeter/notification/service/event.go @@ -0,0 +1,97 @@ +package service + +import ( + "context" + "errors" + "fmt" + + "github.com/openmeterio/openmeter/openmeter/notification" + "github.com/openmeterio/openmeter/pkg/models" +) + +func (s Service) ListEvents(ctx context.Context, params notification.ListEventsInput) (notification.ListEventsResult, error) { + if err := params.Validate(ctx, s); err != nil { + return notification.ListEventsResult{}, fmt.Errorf("invalid params: %w", err) + } + + return s.repo.ListEvents(ctx, params) +} + +func (s Service) GetEvent(ctx context.Context, params notification.GetEventInput) (*notification.Event, error) { + if err := params.Validate(ctx, s); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + return s.repo.GetEvent(ctx, params) +} + +func (s Service) CreateEvent(ctx context.Context, params notification.CreateEventInput) (*notification.Event, error) { + if err := params.Validate(ctx, s); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + logger := s.logger.WithGroup("event").With( + "operation", "create", + "namespace", params.Namespace, + ) + + logger.Debug("creating event") + + rule, err := s.repo.GetRule(ctx, notification.GetRuleInput{ + Namespace: params.Namespace, + ID: params.RuleID, + }) + if err != nil { + return nil, fmt.Errorf("failed to get rule: %w", err) + } + + if rule.DeletedAt != nil { + return nil, notification.NotFoundError{ + NamespacedID: models.NamespacedID{ + Namespace: params.Namespace, + ID: params.RuleID, + }, + } + } + + if rule.Disabled { + return nil, notification.ValidationError{ + Err: errors.New("failed to send event: rule is disabled"), + } + } + + event, err := s.repo.CreateEvent(ctx, params) + if err != nil { + return nil, fmt.Errorf("failed to create event: %w", err) + } + + if err = s.eventHandler.Dispatch(event); err != nil { + return nil, fmt.Errorf("failed to dispatch event: %w", err) + } + + return event, nil +} + +func (s Service) UpdateEventDeliveryStatus(ctx context.Context, params notification.UpdateEventDeliveryStatusInput) (*notification.EventDeliveryStatus, error) { + if err := params.Validate(ctx, s); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + return s.repo.UpdateEventDeliveryStatus(ctx, params) +} + +func (s Service) ListEventsDeliveryStatus(ctx context.Context, params notification.ListEventsDeliveryStatusInput) (notification.ListEventsDeliveryStatusResult, error) { + if err := params.Validate(ctx, s); err != nil { + return notification.ListEventsDeliveryStatusResult{}, fmt.Errorf("invalid params: %w", err) + } + + return s.repo.ListEventsDeliveryStatus(ctx, params) +} + +func (s Service) GetEventDeliveryStatus(ctx context.Context, params notification.GetEventDeliveryStatusInput) (*notification.EventDeliveryStatus, error) { + if err := params.Validate(ctx, s); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + return s.repo.GetEventDeliveryStatus(ctx, params) +} diff --git a/openmeter/notification/service/rule.go b/openmeter/notification/service/rule.go new file mode 100644 index 000000000..a136274ec --- /dev/null +++ b/openmeter/notification/service/rule.go @@ -0,0 +1,205 @@ +package service + +import ( + "context" + "errors" + "fmt" + + "github.com/samber/lo" + + "github.com/openmeterio/openmeter/openmeter/notification" + "github.com/openmeterio/openmeter/openmeter/notification/webhook" + "github.com/openmeterio/openmeter/pkg/pagination" +) + +func (s Service) ListRules(ctx context.Context, params notification.ListRulesInput) (notification.ListRulesResult, error) { + if err := params.Validate(ctx, s); err != nil { + return notification.ListRulesResult{}, fmt.Errorf("invalid params: %w", err) + } + + return s.repo.ListRules(ctx, params) +} + +func (s Service) CreateRule(ctx context.Context, params notification.CreateRuleInput) (*notification.Rule, error) { + if err := params.Validate(ctx, s); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + logger := s.logger.WithGroup("rule").With( + "operation", "create", + "namespace", params.Namespace, + ) + + logger.Debug("creating rule", "type", params.Type) + + txFunc := func(ctx context.Context, repo notification.TxRepository) (*notification.Rule, error) { + rule, err := repo.CreateRule(ctx, params) + if err != nil { + return nil, fmt.Errorf("failed to create rule: %w", err) + } + + for _, channel := range rule.Channels { + switch channel.Type { + case notification.ChannelTypeWebhook: + _, err = s.webhook.UpdateWebhookChannels(ctx, webhook.UpdateWebhookChannelsInput{ + Namespace: params.Namespace, + ID: channel.ID, + AddChannels: []string{ + rule.ID, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to update webhook for channel: %w", err) + } + default: + return nil, fmt.Errorf("invalid channel type: %s", channel.Type) + } + } + + return rule, nil + } + + return notification.WithTx[*notification.Rule](ctx, s.repo, txFunc) +} + +func (s Service) DeleteRule(ctx context.Context, params notification.DeleteRuleInput) error { + if err := params.Validate(ctx, s); err != nil { + return fmt.Errorf("invalid params: %w", err) + } + + txFunc := func(ctx context.Context, repo notification.TxRepository) error { + rule, err := s.repo.GetRule(ctx, notification.GetRuleInput(params)) + if err != nil { + return fmt.Errorf("failed to get rule: %w", err) + } + + for _, channel := range rule.Channels { + switch channel.Type { + case notification.ChannelTypeWebhook: + _, err = s.webhook.UpdateWebhookChannels(ctx, webhook.UpdateWebhookChannelsInput{ + Namespace: params.Namespace, + ID: channel.ID, + RemoveChannels: []string{ + rule.ID, + }, + }) + if err != nil { + return fmt.Errorf("failed to update webhook for channel: %w", err) + } + default: + return fmt.Errorf("invalid channel type: %s", channel.Type) + } + } + + return s.repo.DeleteRule(ctx, params) + } + + return notification.WithTxNoValue(ctx, s.repo, txFunc) +} + +func (s Service) GetRule(ctx context.Context, params notification.GetRuleInput) (*notification.Rule, error) { + if err := params.Validate(ctx, s); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + return s.repo.GetRule(ctx, params) +} + +func (s Service) UpdateRule(ctx context.Context, params notification.UpdateRuleInput) (*notification.Rule, error) { + if err := params.Validate(ctx, s); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + logger := s.logger.WithGroup("rule").With( + "operation", "update", + "id", params.ID, + "namespace", params.Namespace, + ) + + logger.Debug("updating rule") + + rule, err := s.repo.GetRule(ctx, notification.GetRuleInput{ + ID: params.ID, + Namespace: params.Namespace, + }) + if err != nil { + return nil, fmt.Errorf("failed to get rule: %w", err) + } + + if rule.DeletedAt != nil { + return nil, notification.UpdateAfterDeleteError{ + Err: errors.New("not allowed to update deleted rule"), + } + } + + // Get list of channel IDs currently assigned to rule + oldChannelIDs := lo.Map(rule.Channels, func(channel notification.Channel, _ int) string { + return channel.ID + }) + logger.Debug("currently assigned channels", "channels", oldChannelIDs) + + // Calculate channels diff for the update + channelIDsDiff := notification.NewChannelIDsDifference(params.Channels, oldChannelIDs) + + logger.WithGroup("channels").Debug("difference in channels assignment", + "changed", channelIDsDiff.HasChanged(), + "additions", channelIDsDiff.Additions(), + "removals", channelIDsDiff.Removals(), + ) + + // We can return early ff there is no change in the list of channels assigned to rule. + if !channelIDsDiff.HasChanged() { + return s.repo.UpdateRule(ctx, params) + } + + txFunc := func(ctx context.Context, repo notification.TxRepository) (*notification.Rule, error) { + // Fetch all the channels from repo which are either added or removed from rule + channels, err := repo.ListChannels(ctx, notification.ListChannelsInput{ + Page: pagination.Page{ + // In order to avoid under-fetching. There cannot be more affected channels than + // twice as the maximum number of allowed channels per rule. + PageSize: 2 * notification.MaxChannelsPerRule, + PageNumber: 1, + }, + Namespaces: []string{params.Namespace}, + Channels: channelIDsDiff.All(), + IncludeDisabled: true, + }) + if err != nil { + return nil, fmt.Errorf("failed to list channels for rule: %w", err) + } + logger.Debug("fetched all affected channels", "channels", channels.Items) + + // Update affected channels + for _, channel := range channels.Items { + switch channel.Type { + case notification.ChannelTypeWebhook: + input := webhook.UpdateWebhookChannelsInput{ + Namespace: params.Namespace, + ID: channel.ID, + } + + if channelIDsDiff.InAdditions(channel.ID) { + input.AddChannels = []string{rule.ID} + } + + if channelIDsDiff.InRemovals(channel.ID) { + input.RemoveChannels = []string{rule.ID} + } + + logger.Debug("updating webhook for channel", "id", channel.ID, "input", input) + + _, err = s.webhook.UpdateWebhookChannels(ctx, input) + if err != nil { + return nil, fmt.Errorf("failed to update webhook for channel: %w", err) + } + default: + return nil, fmt.Errorf("invalid channel type: %s", channel.Type) + } + } + + return s.repo.UpdateRule(ctx, params) + } + + return notification.WithTx[*notification.Rule](ctx, s.repo, txFunc) +} diff --git a/openmeter/notification/service/service.go b/openmeter/notification/service/service.go new file mode 100644 index 000000000..0df200c28 --- /dev/null +++ b/openmeter/notification/service/service.go @@ -0,0 +1,96 @@ +package service + +import ( + "context" + "errors" + "fmt" + "log/slog" + + "github.com/openmeterio/openmeter/openmeter/notification" + "github.com/openmeterio/openmeter/openmeter/notification/webhook" + "github.com/openmeterio/openmeter/openmeter/productcatalog" +) + +const ( + ChannelIDMetadataKey = "om-channel-id" +) + +var _ notification.Service = (*Service)(nil) + +type Service struct { + feature productcatalog.FeatureConnector + + repo notification.Repository + webhook webhook.Handler + + eventHandler notification.EventHandler + + logger *slog.Logger +} + +func (s Service) Close() error { + return s.eventHandler.Close() +} + +type Config struct { + FeatureConnector productcatalog.FeatureConnector + + Repository notification.Repository + Webhook webhook.Handler + + Logger *slog.Logger +} + +func New(config Config) (*Service, error) { + if config.Repository == nil { + return nil, errors.New("missing repository") + } + + if config.FeatureConnector == nil { + return nil, errors.New("missing feature connector") + } + + if config.Webhook == nil { + return nil, errors.New("missing webhook handler") + } + + if config.Logger == nil { + return nil, errors.New("missing logger") + } + config.Logger = config.Logger.WithGroup("notification") + + eventHandler, err := notification.NewEventHandler(notification.EventHandlerConfig{ + Repository: config.Repository, + Webhook: config.Webhook, + Logger: config.Logger, + }) + if err != nil { + return nil, fmt.Errorf("failed to initialize notification event handler: %w", err) + } + + if err = eventHandler.Start(); err != nil { + return nil, fmt.Errorf("failed to initialize notification event handler: %w", err) + } + + return &Service{ + repo: config.Repository, + feature: config.FeatureConnector, + webhook: config.Webhook, + eventHandler: eventHandler, + logger: config.Logger, + }, nil +} + +func (s Service) ListFeature(ctx context.Context, namespace string, features ...string) ([]productcatalog.Feature, error) { + resp, err := s.feature.ListFeatures(ctx, productcatalog.ListFeaturesParams{ + IDsOrKeys: features, + Namespace: namespace, + MeterSlugs: nil, + IncludeArchived: false, + }) + if err != nil { + return nil, fmt.Errorf("failed to get features: %w", err) + } + + return resp.Items, nil +} diff --git a/test/notification/testenv.go b/test/notification/testenv.go index 5cbc6c28e..7c9cc3102 100644 --- a/test/notification/testenv.go +++ b/test/notification/testenv.go @@ -11,6 +11,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/meter" "github.com/openmeterio/openmeter/openmeter/notification" notificationrepository "github.com/openmeterio/openmeter/openmeter/notification/repository" + notificationservice "github.com/openmeterio/openmeter/openmeter/notification/service" notificationwebhook "github.com/openmeterio/openmeter/openmeter/notification/webhook" "github.com/openmeterio/openmeter/openmeter/productcatalog" productcatalogadapter "github.com/openmeterio/openmeter/openmeter/productcatalog/adapter" @@ -148,7 +149,7 @@ func NewTestEnv(ctx context.Context) (TestEnv, error) { return nil, fmt.Errorf("failed to create webhook handler: %w", err) } - service, err := notification.New(notification.Config{ + service, err := notificationservice.New(notificationservice.Config{ Repository: repo, FeatureConnector: featureConnector, Webhook: webhook, From 301db34cfc06fc1a9b89c61e139ada50c67db612 Mon Sep 17 00:00:00 2001 From: Krisztian Gacsal Date: Tue, 27 Aug 2024 17:09:34 +0200 Subject: [PATCH 2/3] refactor: notification repository --- openmeter/notification/repository/channel.go | 180 +++++ openmeter/notification/repository/event.go | 385 +++++++++ .../notification/repository/repository.go | 744 ------------------ openmeter/notification/repository/rule.go | 219 ++++++ 4 files changed, 784 insertions(+), 744 deletions(-) create mode 100644 openmeter/notification/repository/channel.go create mode 100644 openmeter/notification/repository/event.go create mode 100644 openmeter/notification/repository/rule.go diff --git a/openmeter/notification/repository/channel.go b/openmeter/notification/repository/channel.go new file mode 100644 index 000000000..08665a75c --- /dev/null +++ b/openmeter/notification/repository/channel.go @@ -0,0 +1,180 @@ +package repository + +import ( + "context" + "fmt" + + entdb "github.com/openmeterio/openmeter/openmeter/ent/db" + channeldb "github.com/openmeterio/openmeter/openmeter/ent/db/notificationchannel" + "github.com/openmeterio/openmeter/openmeter/notification" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/framework/entutils" + "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/pagination" + "github.com/openmeterio/openmeter/pkg/sortx" +) + +func (r repository) ListChannels(ctx context.Context, params notification.ListChannelsInput) (pagination.PagedResponse[notification.Channel], error) { + db := r.client() + + query := db.NotificationChannel.Query(). + Where(channeldb.DeletedAtIsNil()) // Do not return deleted channels + + if len(params.Namespaces) > 0 { + query = query.Where(channeldb.NamespaceIn(params.Namespaces...)) + } + + if len(params.Channels) > 0 { + query = query.Where(channeldb.IDIn(params.Channels...)) + } + + if !params.IncludeDisabled { + query = query.Where(channeldb.Disabled(false)) + } + + order := entutils.GetOrdering(sortx.OrderDefault) + if !params.Order.IsDefaultValue() { + order = entutils.GetOrdering(params.Order) + } + + switch params.OrderBy { + case notification.ChannelOrderByCreatedAt: + query = query.Order(channeldb.ByCreatedAt(order...)) + case notification.ChannelOrderByUpdatedAt: + query = query.Order(channeldb.ByUpdatedAt(order...)) + case notification.ChannelOrderByType: + query = query.Order(channeldb.ByType(order...)) + case notification.ChannelOrderByID: + fallthrough + default: + query = query.Order(channeldb.ByID(order...)) + } + + response := pagination.PagedResponse[notification.Channel]{ + Page: params.Page, + } + + paged, err := query.Paginate(ctx, params.Page) + if err != nil { + return response, err + } + + result := make([]notification.Channel, 0, len(paged.Items)) + for _, item := range paged.Items { + if item == nil { + r.logger.Warn("invalid query result: nil notification channel received") + continue + } + + result = append(result, *ChannelFromDBEntity(*item)) + } + + response.TotalCount = paged.TotalCount + response.Items = result + + return response, nil +} + +func (r repository) CreateChannel(ctx context.Context, params notification.CreateChannelInput) (*notification.Channel, error) { + db := r.client() + + query := db.NotificationChannel.Create(). + SetType(params.Type). + SetName(params.Name). + SetNamespace(params.Namespace). + SetDisabled(params.Disabled). + SetConfig(params.Config) + + channel, err := query.Save(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create notification channel: %w", err) + } + + if channel == nil { + return nil, fmt.Errorf("invalid query result: nil notification channel received") + } + + return ChannelFromDBEntity(*channel), nil +} + +func (r repository) DeleteChannel(ctx context.Context, params notification.DeleteChannelInput) error { + db := r.client() + + query := db.NotificationChannel.UpdateOneID(params.ID). + SetDeletedAt(clock.Now().UTC()). + SetDisabled(true) + + _, err := query.Save(ctx) + if err != nil { + if entdb.IsNotFound(err) { + return notification.NotFoundError{ + NamespacedID: models.NamespacedID{ + Namespace: params.Namespace, + ID: params.ID, + }, + } + } + + return fmt.Errorf("failed to delete notification channel: %w", err) + } + + return nil +} + +func (r repository) GetChannel(ctx context.Context, params notification.GetChannelInput) (*notification.Channel, error) { + db := r.client() + + query := db.NotificationChannel.Query(). + Where(channeldb.ID(params.ID)). + Where(channeldb.Namespace(params.Namespace)) + + queryRow, err := query.First(ctx) + if err != nil { + if entdb.IsNotFound(err) { + return nil, notification.NotFoundError{ + NamespacedID: models.NamespacedID{ + Namespace: params.Namespace, + ID: params.ID, + }, + } + } + + return nil, fmt.Errorf("failed to fetch notification channel: %w", err) + } + + if queryRow == nil { + return nil, fmt.Errorf("invalid query result: nil notification channel received") + } + + return ChannelFromDBEntity(*queryRow), nil +} + +func (r repository) UpdateChannel(ctx context.Context, params notification.UpdateChannelInput) (*notification.Channel, error) { + db := r.client() + + query := db.NotificationChannel.UpdateOneID(params.ID). + SetUpdatedAt(clock.Now().UTC()). + SetDisabled(params.Disabled). + SetConfig(params.Config). + SetName(params.Name) + + queryRow, err := query.Save(ctx) + if err != nil { + if entdb.IsNotFound(err) { + return nil, notification.NotFoundError{ + NamespacedID: models.NamespacedID{ + Namespace: params.Namespace, + ID: params.ID, + }, + } + } + + return nil, fmt.Errorf("failed to update notification channel: %w", err) + } + + if queryRow == nil { + return nil, fmt.Errorf("invalid query result: nil notification channel received") + } + + return ChannelFromDBEntity(*queryRow), nil +} diff --git a/openmeter/notification/repository/event.go b/openmeter/notification/repository/event.go new file mode 100644 index 000000000..6b53c8dd1 --- /dev/null +++ b/openmeter/notification/repository/event.go @@ -0,0 +1,385 @@ +package repository + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + entdb "github.com/openmeterio/openmeter/openmeter/ent/db" + channeldb "github.com/openmeterio/openmeter/openmeter/ent/db/notificationchannel" + eventdb "github.com/openmeterio/openmeter/openmeter/ent/db/notificationevent" + statusdb "github.com/openmeterio/openmeter/openmeter/ent/db/notificationeventdeliverystatus" + ruledb "github.com/openmeterio/openmeter/openmeter/ent/db/notificationrule" + "github.com/openmeterio/openmeter/openmeter/notification" + "github.com/openmeterio/openmeter/pkg/framework/entutils" + "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/pagination" + "github.com/openmeterio/openmeter/pkg/sortx" +) + +func (r repository) ListEvents(ctx context.Context, params notification.ListEventsInput) (pagination.PagedResponse[notification.Event], error) { + db := r.client() + + query := db.NotificationEvent.Query() + + if len(params.Namespaces) > 0 { + query = query.Where(eventdb.NamespaceIn(params.Namespaces...)) + } + + if len(params.Events) > 0 { + query = query.Where(eventdb.IDIn(params.Events...)) + } + + if !params.From.IsZero() { + query = query.Where(eventdb.CreatedAtGTE(params.From.UTC())) + } + + if !params.To.IsZero() { + query = query.Where(eventdb.CreatedAtLTE(params.To.UTC())) + } + + if len(params.DeduplicationHashes) > 0 { + query = query.Where( + entutils.JSONBIn(eventdb.FieldAnnotations, notification.AnnotationEventDedupeHash, params.DeduplicationHashes), + ) + } + + // Eager load DeliveryStatus, Rules (including Channels) + if len(params.DeliveryStatusStates) > 0 { + query = query.WithDeliveryStatuses(func(query *entdb.NotificationEventDeliveryStatusQuery) { + query.Where(statusdb.StateIn(params.DeliveryStatusStates...)) + }) + } else { + query = query.WithDeliveryStatuses() + } + + if len(params.Features) > 0 { + query = query.Where( + eventdb.Or( + entutils.JSONBIn(eventdb.FieldAnnotations, notification.AnnotationEventFeatureKey, params.Features), + entutils.JSONBIn(eventdb.FieldAnnotations, notification.AnnotationEventFeatureID, params.Features), + ), + ) + } + + if len(params.Subjects) > 0 { + query = query.Where( + eventdb.Or( + entutils.JSONBIn(eventdb.FieldAnnotations, notification.AnnotationEventSubjectKey, params.Subjects), + entutils.JSONBIn(eventdb.FieldAnnotations, notification.AnnotationEventSubjectID, params.Subjects), + ), + ) + } + + if len(params.Rules) > 0 { + query = query.Where(eventdb.RuleIDIn(params.Rules...)) + } + + if len(params.Channels) > 0 { + query = query.Where(eventdb.HasRulesWith(ruledb.HasChannelsWith(channeldb.IDIn(params.Channels...)))) + } + + query = query.WithRules(func(query *entdb.NotificationRuleQuery) { + query.WithChannels() + }) + + order := entutils.GetOrdering(sortx.OrderDesc) + if !params.Order.IsDefaultValue() { + order = entutils.GetOrdering(params.Order) + } + + switch params.OrderBy { + case notification.EventOrderByID: + query = query.Order(eventdb.ByID(order...)) + case notification.EventOrderByCreatedAt: + fallthrough + default: + query = query.Order(eventdb.ByCreatedAt(order...)) + } + + response := pagination.PagedResponse[notification.Event]{ + Page: params.Page, + } + + paged, err := query.Paginate(ctx, params.Page) + if err != nil { + return response, err + } + + result := make([]notification.Event, 0, len(paged.Items)) + for _, eventRow := range paged.Items { + if eventRow == nil { + r.logger.Warn("invalid query result: nil notification event received") + continue + } + + event, err := EventFromDBEntity(*eventRow) + if err != nil { + return response, fmt.Errorf("failed to get notification events: %w", err) + } + + result = append(result, *event) + } + + response.TotalCount = paged.TotalCount + response.Items = result + + return response, nil +} + +func (r repository) GetEvent(ctx context.Context, params notification.GetEventInput) (*notification.Event, error) { + db := r.client() + + query := db.NotificationEvent.Query(). + Where(eventdb.Namespace(params.Namespace)). + Where(eventdb.ID(params.ID)). + WithDeliveryStatuses(). + WithRules() + + eventRow, err := query.First(ctx) + if err != nil { + if entdb.IsNotFound(err) { + return nil, notification.NotFoundError{ + NamespacedID: models.NamespacedID{ + Namespace: params.Namespace, + ID: params.ID, + }, + } + } + + return nil, fmt.Errorf("failed to get notification event: %w", err) + } + + if eventRow == nil { + return nil, errors.New("invalid query response: nil notification event received") + } + + event, err := EventFromDBEntity(*eventRow) + if err != nil { + return nil, fmt.Errorf("failed to get notification event: %w", err) + } + + return event, nil +} + +func (r repository) CreateEvent(ctx context.Context, params notification.CreateEventInput) (*notification.Event, error) { + payloadJSON, err := json.Marshal(params.Payload) + if err != nil { + return nil, fmt.Errorf("failed to serialize notification event payload: %w", err) + } + + db := r.client() + + query := db.NotificationEvent.Create(). + SetType(params.Type). + SetNamespace(params.Namespace). + SetRuleID(params.RuleID). + SetPayload(string(payloadJSON)). + SetAnnotations(params.Annotations) + + eventRow, err := query.Save(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create notification event: %w", err) + } + + if eventRow == nil { + return nil, errors.New("invalid query response: nil notification event received") + } + + ruleQuery := db.NotificationRule.Query(). + Where(ruledb.Namespace(params.Namespace)). + Where(ruledb.ID(params.RuleID)). + Where(ruledb.DeletedAtIsNil()). + WithChannels() + + ruleRow, err := ruleQuery.First(ctx) + if err != nil { + if entdb.IsNotFound(err) { + return nil, notification.NotFoundError{ + NamespacedID: models.NamespacedID{ + Namespace: params.Namespace, + ID: params.RuleID, + }, + } + } + + return nil, fmt.Errorf("failed to fetch notification rule: %w", err) + } + if ruleRow == nil { + return nil, errors.New("invalid query result: nil notification rule received") + } + + if _, err = ruleRow.Edges.ChannelsOrErr(); err != nil { + return nil, fmt.Errorf("invalid query result: failed to load notification chnnaels for rule: %w", err) + } + + eventRow.Edges.Rules = ruleRow + + statusBulkQuery := make([]*entdb.NotificationEventDeliveryStatusCreate, 0, len(ruleRow.Edges.Channels)) + for _, channel := range ruleRow.Edges.Channels { + if channel == nil { + r.logger.Warn("invalid query result: nil channel received") + continue + } + + q := db.NotificationEventDeliveryStatus.Create(). + SetNamespace(params.Namespace). + SetEventID(eventRow.ID). + SetChannelID(channel.ID). + SetState(notification.EventDeliveryStatusStatePending). + AddEvents(eventRow) + + statusBulkQuery = append(statusBulkQuery, q) + } + + statusQuery := db.NotificationEventDeliveryStatus.CreateBulk(statusBulkQuery...) + + statusRows, err := statusQuery.Save(ctx) + if err != nil { + return nil, fmt.Errorf("failed to save notification event: %w", err) + } + + eventRow.Edges.DeliveryStatuses = statusRows + + event, err := EventFromDBEntity(*eventRow) + if err != nil { + return nil, fmt.Errorf("failed to cast notification event: %w", err) + } + + return event, nil +} + +func (r repository) ListEventsDeliveryStatus(ctx context.Context, params notification.ListEventsDeliveryStatusInput) (pagination.PagedResponse[notification.EventDeliveryStatus], error) { + db := r.client() + + query := db.NotificationEventDeliveryStatus.Query() + + if len(params.Namespaces) > 0 { + query = query.Where(statusdb.NamespaceIn(params.Namespaces...)) + } + + if len(params.Events) > 0 { + query = query.Where(statusdb.EventIDIn(params.Events...)) + } + + if len(params.Channels) > 0 { + query = query.Where(statusdb.ChannelIDIn(params.Channels...)) + } + + if len(params.States) > 0 { + query = query.Where(statusdb.StateIn(params.States...)) + } + + if !params.From.IsZero() { + query = query.Where(statusdb.UpdatedAtGTE(params.From.UTC())) + } + + if !params.To.IsZero() { + query = query.Where(statusdb.UpdatedAtLTE(params.To.UTC())) + } + + response := pagination.PagedResponse[notification.EventDeliveryStatus]{ + Page: params.Page, + } + + paged, err := query.Paginate(ctx, params.Page) + if err != nil { + return response, err + } + + result := make([]notification.EventDeliveryStatus, 0, len(paged.Items)) + for _, statusRow := range paged.Items { + if statusRow == nil { + r.logger.Warn("invalid query response: nil notification event delivery status received") + continue + } + + result = append(result, *EventDeliveryStatusFromDBEntity(*statusRow)) + } + + response.TotalCount = paged.TotalCount + response.Items = result + + return response, nil +} + +func (r repository) GetEventDeliveryStatus(ctx context.Context, params notification.GetEventDeliveryStatusInput) (*notification.EventDeliveryStatus, error) { + db := r.client() + + query := db.NotificationEventDeliveryStatus.Query(). + Where(statusdb.Namespace(params.Namespace)). + Where(statusdb.ID(params.ID)) + + queryRow, err := query.First(ctx) + if err != nil { + if entdb.IsNotFound(err) { + return nil, notification.NotFoundError{ + NamespacedID: models.NamespacedID{ + Namespace: params.Namespace, + ID: params.ID, + }, + } + } + + return nil, fmt.Errorf("failed to get notification event delivery status: %w", err) + } + if queryRow == nil { + return nil, errors.New("invalid query response: no delivery status received") + } + + return EventDeliveryStatusFromDBEntity(*queryRow), nil +} + +func (r repository) UpdateEventDeliveryStatus(ctx context.Context, params notification.UpdateEventDeliveryStatusInput) (*notification.EventDeliveryStatus, error) { + var updateQuery *entdb.NotificationEventDeliveryStatusUpdateOne + + db := r.client() + + if params.ID != "" { + updateQuery = db.NotificationEventDeliveryStatus.UpdateOneID(params.ID).SetState(params.State) + } else { + getQuery := db.NotificationEventDeliveryStatus.Query(). + Where(statusdb.Namespace(params.Namespace)). + Where(statusdb.EventID(params.EventID)). + Where(statusdb.ChannelID(params.ChannelID)) + + statusRow, err := getQuery.First(ctx) + if err != nil { + if entdb.IsNotFound(err) { + return nil, notification.NotFoundError{ + NamespacedID: models.NamespacedID{ + Namespace: params.Namespace, + ID: params.EventID, + }, + } + } + + return nil, fmt.Errorf("failed to udpate notification event delivery status: %w", err) + } + + updateQuery = db.NotificationEventDeliveryStatus.UpdateOne(statusRow). + SetState(params.State). + SetReason(params.Reason) + } + + updateRow, err := updateQuery.Save(ctx) + if err != nil { + if entdb.IsNotFound(err) { + return nil, notification.NotFoundError{ + NamespacedID: models.NamespacedID{ + Namespace: params.Namespace, + ID: params.EventID, + }, + } + } + + return nil, fmt.Errorf("failed to create notification event delivery status: %w", err) + } + + if updateRow == nil { + return nil, fmt.Errorf("invalid query response: no delivery status received") + } + + return EventDeliveryStatusFromDBEntity(*updateRow), nil +} diff --git a/openmeter/notification/repository/repository.go b/openmeter/notification/repository/repository.go index 9df0f6b99..f27a32d19 100644 --- a/openmeter/notification/repository/repository.go +++ b/openmeter/notification/repository/repository.go @@ -2,22 +2,12 @@ package repository import ( "context" - "encoding/json" "errors" "fmt" "log/slog" entdb "github.com/openmeterio/openmeter/openmeter/ent/db" - channeldb "github.com/openmeterio/openmeter/openmeter/ent/db/notificationchannel" - eventdb "github.com/openmeterio/openmeter/openmeter/ent/db/notificationevent" - statusdb "github.com/openmeterio/openmeter/openmeter/ent/db/notificationeventdeliverystatus" - ruledb "github.com/openmeterio/openmeter/openmeter/ent/db/notificationrule" "github.com/openmeterio/openmeter/openmeter/notification" - "github.com/openmeterio/openmeter/pkg/clock" - "github.com/openmeterio/openmeter/pkg/framework/entutils" - "github.com/openmeterio/openmeter/pkg/models" - "github.com/openmeterio/openmeter/pkg/pagination" - "github.com/openmeterio/openmeter/pkg/sortx" ) type Config struct { @@ -97,737 +87,3 @@ func (r repository) WithTx(ctx context.Context) (notification.TxRepository, erro logger: r.logger, }, nil } - -func (r repository) ListChannels(ctx context.Context, params notification.ListChannelsInput) (pagination.PagedResponse[notification.Channel], error) { - db := r.client() - - query := db.NotificationChannel.Query(). - Where(channeldb.DeletedAtIsNil()) // Do not return deleted channels - - if len(params.Namespaces) > 0 { - query = query.Where(channeldb.NamespaceIn(params.Namespaces...)) - } - - if len(params.Channels) > 0 { - query = query.Where(channeldb.IDIn(params.Channels...)) - } - - if !params.IncludeDisabled { - query = query.Where(channeldb.Disabled(false)) - } - - order := entutils.GetOrdering(sortx.OrderDefault) - if !params.Order.IsDefaultValue() { - order = entutils.GetOrdering(params.Order) - } - - switch params.OrderBy { - case notification.ChannelOrderByCreatedAt: - query = query.Order(channeldb.ByCreatedAt(order...)) - case notification.ChannelOrderByUpdatedAt: - query = query.Order(channeldb.ByUpdatedAt(order...)) - case notification.ChannelOrderByType: - query = query.Order(channeldb.ByType(order...)) - case notification.ChannelOrderByID: - fallthrough - default: - query = query.Order(channeldb.ByID(order...)) - } - - response := pagination.PagedResponse[notification.Channel]{ - Page: params.Page, - } - - paged, err := query.Paginate(ctx, params.Page) - if err != nil { - return response, err - } - - result := make([]notification.Channel, 0, len(paged.Items)) - for _, item := range paged.Items { - if item == nil { - r.logger.Warn("invalid query result: nil notification channel received") - continue - } - - result = append(result, *ChannelFromDBEntity(*item)) - } - - response.TotalCount = paged.TotalCount - response.Items = result - - return response, nil -} - -func (r repository) CreateChannel(ctx context.Context, params notification.CreateChannelInput) (*notification.Channel, error) { - db := r.client() - - query := db.NotificationChannel.Create(). - SetType(params.Type). - SetName(params.Name). - SetNamespace(params.Namespace). - SetDisabled(params.Disabled). - SetConfig(params.Config) - - channel, err := query.Save(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create notification channel: %w", err) - } - - if channel == nil { - return nil, fmt.Errorf("invalid query result: nil notification channel received") - } - - return ChannelFromDBEntity(*channel), nil -} - -func (r repository) DeleteChannel(ctx context.Context, params notification.DeleteChannelInput) error { - db := r.client() - - query := db.NotificationChannel.UpdateOneID(params.ID). - SetDeletedAt(clock.Now().UTC()). - SetDisabled(true) - - _, err := query.Save(ctx) - if err != nil { - if entdb.IsNotFound(err) { - return notification.NotFoundError{ - NamespacedID: models.NamespacedID{ - Namespace: params.Namespace, - ID: params.ID, - }, - } - } - - return fmt.Errorf("failed to delete notification channel: %w", err) - } - - return nil -} - -func (r repository) GetChannel(ctx context.Context, params notification.GetChannelInput) (*notification.Channel, error) { - db := r.client() - - query := db.NotificationChannel.Query(). - Where(channeldb.ID(params.ID)). - Where(channeldb.Namespace(params.Namespace)) - - queryRow, err := query.First(ctx) - if err != nil { - if entdb.IsNotFound(err) { - return nil, notification.NotFoundError{ - NamespacedID: models.NamespacedID{ - Namespace: params.Namespace, - ID: params.ID, - }, - } - } - - return nil, fmt.Errorf("failed to fetch notification channel: %w", err) - } - - if queryRow == nil { - return nil, fmt.Errorf("invalid query result: nil notification channel received") - } - - return ChannelFromDBEntity(*queryRow), nil -} - -func (r repository) UpdateChannel(ctx context.Context, params notification.UpdateChannelInput) (*notification.Channel, error) { - db := r.client() - - query := db.NotificationChannel.UpdateOneID(params.ID). - SetUpdatedAt(clock.Now().UTC()). - SetDisabled(params.Disabled). - SetConfig(params.Config). - SetName(params.Name) - - queryRow, err := query.Save(ctx) - if err != nil { - if entdb.IsNotFound(err) { - return nil, notification.NotFoundError{ - NamespacedID: models.NamespacedID{ - Namespace: params.Namespace, - ID: params.ID, - }, - } - } - - return nil, fmt.Errorf("failed to update notification channel: %w", err) - } - - if queryRow == nil { - return nil, fmt.Errorf("invalid query result: nil notification channel received") - } - - return ChannelFromDBEntity(*queryRow), nil -} - -func (r repository) ListRules(ctx context.Context, params notification.ListRulesInput) (pagination.PagedResponse[notification.Rule], error) { - db := r.client() - - query := db.NotificationRule.Query(). - Where(ruledb.DeletedAtIsNil()) // Do not return deleted Rules - - if len(params.Namespaces) > 0 { - query = query.Where(ruledb.NamespaceIn(params.Namespaces...)) - } - - if len(params.Rules) > 0 { - query = query.Where(ruledb.IDIn(params.Rules...)) - } - - if !params.IncludeDisabled { - query = query.Where(ruledb.Disabled(false)) - } - - if len(params.Types) > 0 { - query = query.Where(ruledb.TypeIn(params.Types...)) - } - - if len(params.Channels) > 0 { - query = query.Where(ruledb.HasChannelsWith(channeldb.IDIn(params.Channels...))) - } - - // Eager load Channels - query = query.WithChannels() - - order := entutils.GetOrdering(sortx.OrderDefault) - if !params.Order.IsDefaultValue() { - order = entutils.GetOrdering(params.Order) - } - - switch params.OrderBy { - case notification.RuleOrderByCreatedAt: - query = query.Order(ruledb.ByCreatedAt(order...)) - case notification.RuleOrderByUpdatedAt: - query = query.Order(ruledb.ByUpdatedAt(order...)) - case notification.RuleOrderByType: - query = query.Order(ruledb.ByType(order...)) - case notification.RuleOrderByID: - fallthrough - default: - query = query.Order(ruledb.ByID(order...)) - } - - response := pagination.PagedResponse[notification.Rule]{ - Page: params.Page, - } - - paged, err := query.Paginate(ctx, params.Page) - if err != nil { - return response, err - } - - result := make([]notification.Rule, 0, len(paged.Items)) - for _, ruleRow := range paged.Items { - if ruleRow == nil { - r.logger.Warn("invalid query result: nil notification rule received") - continue - } - - result = append(result, *RuleFromDBEntity(*ruleRow)) - } - - response.TotalCount = paged.TotalCount - response.Items = result - - return response, nil -} - -func (r repository) CreateRule(ctx context.Context, params notification.CreateRuleInput) (*notification.Rule, error) { - db := r.client() - - query := db.NotificationRule.Create(). - SetType(params.Type). - SetName(params.Name). - SetNamespace(params.Namespace). - SetDisabled(params.Disabled). - SetConfig(params.Config). - AddChannelIDs(params.Channels...) - - queryRow, err := query.Save(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create notification rule: %w", err) - } - - if queryRow == nil { - return nil, fmt.Errorf("invalid query result: nil notification rule received") - } - - channelsQuery := db.NotificationChannel.Query(). - Where(channeldb.Namespace(params.Namespace)). - Where(channeldb.IDIn(params.Channels...)) - - channelRows, err := channelsQuery.All(ctx) - if err != nil { - return nil, fmt.Errorf("failed to query notification channels: %w", err) - } - - queryRow.Edges.Channels = channelRows - - return RuleFromDBEntity(*queryRow), nil -} - -func (r repository) DeleteRule(ctx context.Context, params notification.DeleteRuleInput) error { - db := r.client() - - query := db.NotificationRule.UpdateOneID(params.ID). - Where(ruledb.Namespace(params.Namespace)). - SetDeletedAt(clock.Now().UTC()). - SetDisabled(true) - - _, err := query.Save(ctx) - if err != nil { - if entdb.IsNotFound(err) { - return notification.NotFoundError{ - NamespacedID: models.NamespacedID{ - Namespace: params.Namespace, - ID: params.ID, - }, - } - } - - return fmt.Errorf("failed top delete notification rule: %w", err) - } - - return nil -} - -func (r repository) GetRule(ctx context.Context, params notification.GetRuleInput) (*notification.Rule, error) { - db := r.client() - - query := db.NotificationRule.Query(). - Where(ruledb.ID(params.ID)). - Where(ruledb.Namespace(params.Namespace)). - WithChannels() - - ruleRow, err := query.First(ctx) - if err != nil { - if entdb.IsNotFound(err) { - return nil, notification.NotFoundError{ - NamespacedID: models.NamespacedID{ - Namespace: params.Namespace, - ID: params.ID, - }, - } - } - - return nil, fmt.Errorf("failed to fetch notification rule: %w", err) - } - - if ruleRow == nil { - return nil, fmt.Errorf("invalid query result: nil notification rule received") - } - - return RuleFromDBEntity(*ruleRow), nil -} - -func (r repository) UpdateRule(ctx context.Context, params notification.UpdateRuleInput) (*notification.Rule, error) { - db := r.client() - - query := db.NotificationRule.UpdateOneID(params.ID). - SetUpdatedAt(clock.Now().UTC()). - SetDisabled(params.Disabled). - SetConfig(params.Config). - SetName(params.Name). - ClearChannels(). - AddChannelIDs(params.Channels...) - - queryRow, err := query.Save(ctx) - if err != nil { - if entdb.IsNotFound(err) { - return nil, notification.NotFoundError{ - NamespacedID: models.NamespacedID{ - Namespace: params.Namespace, - ID: params.ID, - }, - } - } - - return nil, fmt.Errorf("failed to update notification rule: %w", err) - } - - if queryRow == nil { - return nil, fmt.Errorf("invalid query result: nil notification rule received") - } - - channelsQuery := db.NotificationChannel.Query(). - Where(channeldb.Namespace(params.Namespace)). - Where(channeldb.IDIn(params.Channels...)) - - channelRows, err := channelsQuery.All(ctx) - if err != nil { - return nil, fmt.Errorf("failed to query notification channels: %w", err) - } - - queryRow.Edges.Channels = channelRows - - return RuleFromDBEntity(*queryRow), nil -} - -func (r repository) ListEvents(ctx context.Context, params notification.ListEventsInput) (pagination.PagedResponse[notification.Event], error) { - db := r.client() - - query := db.NotificationEvent.Query() - - if len(params.Namespaces) > 0 { - query = query.Where(eventdb.NamespaceIn(params.Namespaces...)) - } - - if len(params.Events) > 0 { - query = query.Where(eventdb.IDIn(params.Events...)) - } - - if !params.From.IsZero() { - query = query.Where(eventdb.CreatedAtGTE(params.From.UTC())) - } - - if !params.To.IsZero() { - query = query.Where(eventdb.CreatedAtLTE(params.To.UTC())) - } - - if len(params.DeduplicationHashes) > 0 { - query = query.Where( - entutils.JSONBIn(eventdb.FieldAnnotations, notification.AnnotationEventDedupeHash, params.DeduplicationHashes), - ) - } - - // Eager load DeliveryStatus, Rules (including Channels) - if len(params.DeliveryStatusStates) > 0 { - query = query.WithDeliveryStatuses(func(query *entdb.NotificationEventDeliveryStatusQuery) { - query.Where(statusdb.StateIn(params.DeliveryStatusStates...)) - }) - } else { - query = query.WithDeliveryStatuses() - } - - if len(params.Features) > 0 { - query = query.Where( - eventdb.Or( - entutils.JSONBIn(eventdb.FieldAnnotations, notification.AnnotationEventFeatureKey, params.Features), - entutils.JSONBIn(eventdb.FieldAnnotations, notification.AnnotationEventFeatureID, params.Features), - ), - ) - } - - if len(params.Subjects) > 0 { - query = query.Where( - eventdb.Or( - entutils.JSONBIn(eventdb.FieldAnnotations, notification.AnnotationEventSubjectKey, params.Subjects), - entutils.JSONBIn(eventdb.FieldAnnotations, notification.AnnotationEventSubjectID, params.Subjects), - ), - ) - } - - if len(params.Rules) > 0 { - query = query.Where(eventdb.RuleIDIn(params.Rules...)) - } - - if len(params.Channels) > 0 { - query = query.Where(eventdb.HasRulesWith(ruledb.HasChannelsWith(channeldb.IDIn(params.Channels...)))) - } - - query = query.WithRules(func(query *entdb.NotificationRuleQuery) { - query.WithChannels() - }) - - order := entutils.GetOrdering(sortx.OrderDesc) - if !params.Order.IsDefaultValue() { - order = entutils.GetOrdering(params.Order) - } - - switch params.OrderBy { - case notification.EventOrderByID: - query = query.Order(eventdb.ByID(order...)) - case notification.EventOrderByCreatedAt: - fallthrough - default: - query = query.Order(eventdb.ByCreatedAt(order...)) - } - - response := pagination.PagedResponse[notification.Event]{ - Page: params.Page, - } - - paged, err := query.Paginate(ctx, params.Page) - if err != nil { - return response, err - } - - result := make([]notification.Event, 0, len(paged.Items)) - for _, eventRow := range paged.Items { - if eventRow == nil { - r.logger.Warn("invalid query result: nil notification event received") - continue - } - - event, err := EventFromDBEntity(*eventRow) - if err != nil { - return response, fmt.Errorf("failed to get notification events: %w", err) - } - - result = append(result, *event) - } - - response.TotalCount = paged.TotalCount - response.Items = result - - return response, nil -} - -func (r repository) GetEvent(ctx context.Context, params notification.GetEventInput) (*notification.Event, error) { - db := r.client() - - query := db.NotificationEvent.Query(). - Where(eventdb.Namespace(params.Namespace)). - Where(eventdb.ID(params.ID)). - WithDeliveryStatuses(). - WithRules() - - eventRow, err := query.First(ctx) - if err != nil { - if entdb.IsNotFound(err) { - return nil, notification.NotFoundError{ - NamespacedID: models.NamespacedID{ - Namespace: params.Namespace, - ID: params.ID, - }, - } - } - - return nil, fmt.Errorf("failed to get notification event: %w", err) - } - - if eventRow == nil { - return nil, errors.New("invalid query response: nil notification event received") - } - - event, err := EventFromDBEntity(*eventRow) - if err != nil { - return nil, fmt.Errorf("failed to get notification event: %w", err) - } - - return event, nil -} - -func (r repository) CreateEvent(ctx context.Context, params notification.CreateEventInput) (*notification.Event, error) { - payloadJSON, err := json.Marshal(params.Payload) - if err != nil { - return nil, fmt.Errorf("failed to serialize notification event payload: %w", err) - } - - db := r.client() - - query := db.NotificationEvent.Create(). - SetType(params.Type). - SetNamespace(params.Namespace). - SetRuleID(params.RuleID). - SetPayload(string(payloadJSON)). - SetAnnotations(params.Annotations) - - eventRow, err := query.Save(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create notification event: %w", err) - } - - if eventRow == nil { - return nil, errors.New("invalid query response: nil notification event received") - } - - ruleQuery := db.NotificationRule.Query(). - Where(ruledb.Namespace(params.Namespace)). - Where(ruledb.ID(params.RuleID)). - Where(ruledb.DeletedAtIsNil()). - WithChannels() - - ruleRow, err := ruleQuery.First(ctx) - if err != nil { - if entdb.IsNotFound(err) { - return nil, notification.NotFoundError{ - NamespacedID: models.NamespacedID{ - Namespace: params.Namespace, - ID: params.RuleID, - }, - } - } - - return nil, fmt.Errorf("failed to fetch notification rule: %w", err) - } - if ruleRow == nil { - return nil, errors.New("invalid query result: nil notification rule received") - } - - if _, err = ruleRow.Edges.ChannelsOrErr(); err != nil { - return nil, fmt.Errorf("invalid query result: failed to load notification chnnaels for rule: %w", err) - } - - eventRow.Edges.Rules = ruleRow - - statusBulkQuery := make([]*entdb.NotificationEventDeliveryStatusCreate, 0, len(ruleRow.Edges.Channels)) - for _, channel := range ruleRow.Edges.Channels { - if channel == nil { - r.logger.Warn("invalid query result: nil channel received") - continue - } - - q := db.NotificationEventDeliveryStatus.Create(). - SetNamespace(params.Namespace). - SetEventID(eventRow.ID). - SetChannelID(channel.ID). - SetState(notification.EventDeliveryStatusStatePending). - AddEvents(eventRow) - - statusBulkQuery = append(statusBulkQuery, q) - } - - statusQuery := db.NotificationEventDeliveryStatus.CreateBulk(statusBulkQuery...) - - statusRows, err := statusQuery.Save(ctx) - if err != nil { - return nil, fmt.Errorf("failed to save notification event: %w", err) - } - - eventRow.Edges.DeliveryStatuses = statusRows - - event, err := EventFromDBEntity(*eventRow) - if err != nil { - return nil, fmt.Errorf("failed to cast notification event: %w", err) - } - - return event, nil -} - -func (r repository) ListEventsDeliveryStatus(ctx context.Context, params notification.ListEventsDeliveryStatusInput) (pagination.PagedResponse[notification.EventDeliveryStatus], error) { - db := r.client() - - query := db.NotificationEventDeliveryStatus.Query() - - if len(params.Namespaces) > 0 { - query = query.Where(statusdb.NamespaceIn(params.Namespaces...)) - } - - if len(params.Events) > 0 { - query = query.Where(statusdb.EventIDIn(params.Events...)) - } - - if len(params.Channels) > 0 { - query = query.Where(statusdb.ChannelIDIn(params.Channels...)) - } - - if len(params.States) > 0 { - query = query.Where(statusdb.StateIn(params.States...)) - } - - if !params.From.IsZero() { - query = query.Where(statusdb.UpdatedAtGTE(params.From.UTC())) - } - - if !params.To.IsZero() { - query = query.Where(statusdb.UpdatedAtLTE(params.To.UTC())) - } - - response := pagination.PagedResponse[notification.EventDeliveryStatus]{ - Page: params.Page, - } - - paged, err := query.Paginate(ctx, params.Page) - if err != nil { - return response, err - } - - result := make([]notification.EventDeliveryStatus, 0, len(paged.Items)) - for _, statusRow := range paged.Items { - if statusRow == nil { - r.logger.Warn("invalid query response: nil notification event delivery status received") - continue - } - - result = append(result, *EventDeliveryStatusFromDBEntity(*statusRow)) - } - - response.TotalCount = paged.TotalCount - response.Items = result - - return response, nil -} - -func (r repository) GetEventDeliveryStatus(ctx context.Context, params notification.GetEventDeliveryStatusInput) (*notification.EventDeliveryStatus, error) { - db := r.client() - - query := db.NotificationEventDeliveryStatus.Query(). - Where(statusdb.Namespace(params.Namespace)). - Where(statusdb.ID(params.ID)) - - queryRow, err := query.First(ctx) - if err != nil { - if entdb.IsNotFound(err) { - return nil, notification.NotFoundError{ - NamespacedID: models.NamespacedID{ - Namespace: params.Namespace, - ID: params.ID, - }, - } - } - - return nil, fmt.Errorf("failed to get notification event delivery status: %w", err) - } - if queryRow == nil { - return nil, errors.New("invalid query response: no delivery status received") - } - - return EventDeliveryStatusFromDBEntity(*queryRow), nil -} - -func (r repository) UpdateEventDeliveryStatus(ctx context.Context, params notification.UpdateEventDeliveryStatusInput) (*notification.EventDeliveryStatus, error) { - var updateQuery *entdb.NotificationEventDeliveryStatusUpdateOne - - db := r.client() - - if params.ID != "" { - updateQuery = db.NotificationEventDeliveryStatus.UpdateOneID(params.ID).SetState(params.State) - } else { - getQuery := db.NotificationEventDeliveryStatus.Query(). - Where(statusdb.Namespace(params.Namespace)). - Where(statusdb.EventID(params.EventID)). - Where(statusdb.ChannelID(params.ChannelID)) - - statusRow, err := getQuery.First(ctx) - if err != nil { - if entdb.IsNotFound(err) { - return nil, notification.NotFoundError{ - NamespacedID: models.NamespacedID{ - Namespace: params.Namespace, - ID: params.EventID, - }, - } - } - - return nil, fmt.Errorf("failed to udpate notification event delivery status: %w", err) - } - - updateQuery = db.NotificationEventDeliveryStatus.UpdateOne(statusRow). - SetState(params.State). - SetReason(params.Reason) - } - - updateRow, err := updateQuery.Save(ctx) - if err != nil { - if entdb.IsNotFound(err) { - return nil, notification.NotFoundError{ - NamespacedID: models.NamespacedID{ - Namespace: params.Namespace, - ID: params.EventID, - }, - } - } - - return nil, fmt.Errorf("failed to create notification event delivery status: %w", err) - } - - if updateRow == nil { - return nil, fmt.Errorf("invalid query response: no delivery status received") - } - - return EventDeliveryStatusFromDBEntity(*updateRow), nil -} diff --git a/openmeter/notification/repository/rule.go b/openmeter/notification/repository/rule.go new file mode 100644 index 000000000..16438c96a --- /dev/null +++ b/openmeter/notification/repository/rule.go @@ -0,0 +1,219 @@ +package repository + +import ( + "context" + "fmt" + + entdb "github.com/openmeterio/openmeter/openmeter/ent/db" + channeldb "github.com/openmeterio/openmeter/openmeter/ent/db/notificationchannel" + ruledb "github.com/openmeterio/openmeter/openmeter/ent/db/notificationrule" + "github.com/openmeterio/openmeter/openmeter/notification" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/framework/entutils" + "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/pagination" + "github.com/openmeterio/openmeter/pkg/sortx" +) + +func (r repository) ListRules(ctx context.Context, params notification.ListRulesInput) (pagination.PagedResponse[notification.Rule], error) { + db := r.client() + + query := db.NotificationRule.Query(). + Where(ruledb.DeletedAtIsNil()) // Do not return deleted Rules + + if len(params.Namespaces) > 0 { + query = query.Where(ruledb.NamespaceIn(params.Namespaces...)) + } + + if len(params.Rules) > 0 { + query = query.Where(ruledb.IDIn(params.Rules...)) + } + + if !params.IncludeDisabled { + query = query.Where(ruledb.Disabled(false)) + } + + if len(params.Types) > 0 { + query = query.Where(ruledb.TypeIn(params.Types...)) + } + + if len(params.Channels) > 0 { + query = query.Where(ruledb.HasChannelsWith(channeldb.IDIn(params.Channels...))) + } + + // Eager load Channels + query = query.WithChannels() + + order := entutils.GetOrdering(sortx.OrderDefault) + if !params.Order.IsDefaultValue() { + order = entutils.GetOrdering(params.Order) + } + + switch params.OrderBy { + case notification.RuleOrderByCreatedAt: + query = query.Order(ruledb.ByCreatedAt(order...)) + case notification.RuleOrderByUpdatedAt: + query = query.Order(ruledb.ByUpdatedAt(order...)) + case notification.RuleOrderByType: + query = query.Order(ruledb.ByType(order...)) + case notification.RuleOrderByID: + fallthrough + default: + query = query.Order(ruledb.ByID(order...)) + } + + response := pagination.PagedResponse[notification.Rule]{ + Page: params.Page, + } + + paged, err := query.Paginate(ctx, params.Page) + if err != nil { + return response, err + } + + result := make([]notification.Rule, 0, len(paged.Items)) + for _, ruleRow := range paged.Items { + if ruleRow == nil { + r.logger.Warn("invalid query result: nil notification rule received") + continue + } + + result = append(result, *RuleFromDBEntity(*ruleRow)) + } + + response.TotalCount = paged.TotalCount + response.Items = result + + return response, nil +} + +func (r repository) CreateRule(ctx context.Context, params notification.CreateRuleInput) (*notification.Rule, error) { + db := r.client() + + query := db.NotificationRule.Create(). + SetType(params.Type). + SetName(params.Name). + SetNamespace(params.Namespace). + SetDisabled(params.Disabled). + SetConfig(params.Config). + AddChannelIDs(params.Channels...) + + queryRow, err := query.Save(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create notification rule: %w", err) + } + + if queryRow == nil { + return nil, fmt.Errorf("invalid query result: nil notification rule received") + } + + channelsQuery := db.NotificationChannel.Query(). + Where(channeldb.Namespace(params.Namespace)). + Where(channeldb.IDIn(params.Channels...)) + + channelRows, err := channelsQuery.All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to query notification channels: %w", err) + } + + queryRow.Edges.Channels = channelRows + + return RuleFromDBEntity(*queryRow), nil +} + +func (r repository) DeleteRule(ctx context.Context, params notification.DeleteRuleInput) error { + db := r.client() + + query := db.NotificationRule.UpdateOneID(params.ID). + Where(ruledb.Namespace(params.Namespace)). + SetDeletedAt(clock.Now().UTC()). + SetDisabled(true) + + _, err := query.Save(ctx) + if err != nil { + if entdb.IsNotFound(err) { + return notification.NotFoundError{ + NamespacedID: models.NamespacedID{ + Namespace: params.Namespace, + ID: params.ID, + }, + } + } + + return fmt.Errorf("failed top delete notification rule: %w", err) + } + + return nil +} + +func (r repository) GetRule(ctx context.Context, params notification.GetRuleInput) (*notification.Rule, error) { + db := r.client() + + query := db.NotificationRule.Query(). + Where(ruledb.ID(params.ID)). + Where(ruledb.Namespace(params.Namespace)). + WithChannels() + + ruleRow, err := query.First(ctx) + if err != nil { + if entdb.IsNotFound(err) { + return nil, notification.NotFoundError{ + NamespacedID: models.NamespacedID{ + Namespace: params.Namespace, + ID: params.ID, + }, + } + } + + return nil, fmt.Errorf("failed to fetch notification rule: %w", err) + } + + if ruleRow == nil { + return nil, fmt.Errorf("invalid query result: nil notification rule received") + } + + return RuleFromDBEntity(*ruleRow), nil +} + +func (r repository) UpdateRule(ctx context.Context, params notification.UpdateRuleInput) (*notification.Rule, error) { + db := r.client() + + query := db.NotificationRule.UpdateOneID(params.ID). + SetUpdatedAt(clock.Now().UTC()). + SetDisabled(params.Disabled). + SetConfig(params.Config). + SetName(params.Name). + ClearChannels(). + AddChannelIDs(params.Channels...) + + queryRow, err := query.Save(ctx) + if err != nil { + if entdb.IsNotFound(err) { + return nil, notification.NotFoundError{ + NamespacedID: models.NamespacedID{ + Namespace: params.Namespace, + ID: params.ID, + }, + } + } + + return nil, fmt.Errorf("failed to update notification rule: %w", err) + } + + if queryRow == nil { + return nil, fmt.Errorf("invalid query result: nil notification rule received") + } + + channelsQuery := db.NotificationChannel.Query(). + Where(channeldb.Namespace(params.Namespace)). + Where(channeldb.IDIn(params.Channels...)) + + channelRows, err := channelsQuery.All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to query notification channels: %w", err) + } + + queryRow.Edges.Channels = channelRows + + return RuleFromDBEntity(*queryRow), nil +} From 6c3e0fb97d3a7235155692fdede338c215266c25 Mon Sep 17 00:00:00 2001 From: Krisztian Gacsal Date: Tue, 27 Aug 2024 17:17:03 +0200 Subject: [PATCH 3/3] refactor: notification eventhandler --- openmeter/notification/eventhandler.go | 229 ----------------- .../notification/eventhandler/handler.go | 236 ++++++++++++++++++ openmeter/notification/service/service.go | 3 +- 3 files changed, 238 insertions(+), 230 deletions(-) create mode 100644 openmeter/notification/eventhandler/handler.go diff --git a/openmeter/notification/eventhandler.go b/openmeter/notification/eventhandler.go index 7b13c614c..7dd34e6f9 100644 --- a/openmeter/notification/eventhandler.go +++ b/openmeter/notification/eventhandler.go @@ -2,14 +2,7 @@ package notification import ( "context" - "errors" - "fmt" - "log/slog" "time" - - "github.com/openmeterio/openmeter/openmeter/notification/webhook" - "github.com/openmeterio/openmeter/pkg/models" - "github.com/openmeterio/openmeter/pkg/pagination" ) const ( @@ -32,225 +25,3 @@ type EventReconciler interface { type EventDispatcher interface { Dispatch(*Event) error } - -type EventHandlerConfig struct { - Repository Repository - Webhook webhook.Handler - Logger *slog.Logger - ReconcileInterval time.Duration -} - -func (c *EventHandlerConfig) Validate() error { - if c.Repository == nil { - return fmt.Errorf("repository is required") - } - - if c.Webhook == nil { - return fmt.Errorf("webhook is required") - } - - return nil -} - -var _ EventHandler = (*handler)(nil) - -type handler struct { - repo Repository - webhook webhook.Handler - logger *slog.Logger - - reconcileInterval time.Duration - - stopCh chan struct{} -} - -func (h *handler) Start() error { - go func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ticker := time.NewTicker(h.reconcileInterval) - defer ticker.Stop() - - logger := h.logger.WithGroup("reconciler") - - for { - select { - case <-h.stopCh: - logger.Debug("close event received: stopping reconciler") - return - case <-ticker.C: - if err := h.Reconcile(ctx); err != nil { - logger.Error("failed to reconcile event(s)", "error", err) - } - } - } - }() - - return nil -} - -func (h *handler) Close() error { - close(h.stopCh) - - return nil -} - -func (h *handler) reconcilePending(ctx context.Context, event *Event) error { - return h.dispatch(ctx, event) -} - -func (h *handler) reconcileSending(_ context.Context, _ *Event) error { - // NOTE(chrisgacsal): implement when EventDeliveryStatusStateSending state is need to be handled - return nil -} - -func (h *handler) reconcileFailed(_ context.Context, _ *Event) error { - // NOTE(chrisgacsal): reconcile failed events when adding support for retry on event delivery failure - return nil -} - -func (h *handler) Reconcile(ctx context.Context) error { - events, err := h.repo.ListEvents(ctx, ListEventsInput{ - Page: pagination.Page{}, - DeliveryStatusStates: []EventDeliveryStatusState{ - EventDeliveryStatusStatePending, - EventDeliveryStatusStateSending, - }, - }) - if err != nil { - return fmt.Errorf("failed to fetch notification delivery statuses for reconciliation: %w", err) - } - - for _, event := range events.Items { - var errs error - for _, state := range DeliveryStatusStates(event.DeliveryStatus) { - switch state { - case EventDeliveryStatusStatePending: - if err = h.reconcilePending(ctx, &event); err != nil { - errs = errors.Join(errs, err) - } - case EventDeliveryStatusStateSending: - if err = h.reconcileSending(ctx, &event); err != nil { - errs = errors.Join(errs, err) - } - case EventDeliveryStatusStateFailed: - if err = h.reconcileFailed(ctx, &event); err != nil { - errs = errors.Join(errs, err) - } - } - } - - if errs != nil { - return fmt.Errorf("failed to reconcile notification event: %w", errs) - } - } - - return nil -} - -func (h *handler) dispatchWebhook(ctx context.Context, event *Event) error { - sendIn := webhook.SendMessageInput{ - Namespace: event.Namespace, - EventID: event.ID, - EventType: string(event.Type), - Channels: []string{event.Rule.ID}, - } - - switch event.Type { - case EventTypeBalanceThreshold: - payload := event.Payload.AsNotificationEventBalanceThresholdPayload(event.ID, event.CreatedAt) - payloadMap, err := PayloadToMapInterface(payload) - if err != nil { - return fmt.Errorf("failed to cast event payload: %w", err) - } - - sendIn.Payload = payloadMap - default: - return fmt.Errorf("unknown event type: %s", event.Type) - } - - logger := h.logger.With("eventID", event.ID, "eventType", event.Type) - - var stateReason string - state := EventDeliveryStatusStateSuccess - _, err := h.webhook.SendMessage(ctx, sendIn) - if err != nil { - logger.Error("failed to send webhook message: error returned by webhook service", "error", err) - stateReason = "failed to send webhook message: error returned by webhook service" - state = EventDeliveryStatusStateFailed - } - - for _, channelID := range ChannelIDsByType(event.Rule.Channels, ChannelTypeWebhook) { - _, err = h.repo.UpdateEventDeliveryStatus(ctx, UpdateEventDeliveryStatusInput{ - NamespacedModel: models.NamespacedModel{ - Namespace: event.Namespace, - }, - State: state, - Reason: stateReason, - EventID: event.ID, - ChannelID: channelID, - }) - if err != nil { - return fmt.Errorf("failed to update event delivery: %w", err) - } - } - - return nil -} - -func (h *handler) dispatch(ctx context.Context, event *Event) error { - var errs error - - for _, channelType := range ChannelTypes(event.Rule.Channels) { - var err error - - switch channelType { - case ChannelTypeWebhook: - err = h.dispatchWebhook(ctx, event) - default: - err = fmt.Errorf("unknown channel type: %s", channelType) - } - - if err != nil { - errs = errors.Join(errs, err) - } - } - - return errs -} - -func (h *handler) Dispatch(event *Event) error { - go func() { - ctx, cancel := context.WithTimeout(context.Background(), DefaultDispatchTimeout) - defer cancel() - - if err := h.dispatch(ctx, event); err != nil { - h.logger.Warn("failed to dispatch event", "eventID", event.ID, "error", err) - } - }() - - return nil -} - -func NewEventHandler(config EventHandlerConfig) (EventHandler, error) { - if err := config.Validate(); err != nil { - return nil, err - } - - if config.ReconcileInterval == 0 { - config.ReconcileInterval = DefaultReconcileInterval - } - - if config.Logger == nil { - config.Logger = slog.Default() - } - - return &handler{ - repo: config.Repository, - webhook: config.Webhook, - reconcileInterval: config.ReconcileInterval, - logger: config.Logger, - stopCh: make(chan struct{}), - }, nil -} diff --git a/openmeter/notification/eventhandler/handler.go b/openmeter/notification/eventhandler/handler.go new file mode 100644 index 000000000..3203659cf --- /dev/null +++ b/openmeter/notification/eventhandler/handler.go @@ -0,0 +1,236 @@ +package eventhandler + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "github.com/openmeterio/openmeter/openmeter/notification" + "github.com/openmeterio/openmeter/openmeter/notification/webhook" + "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/pagination" +) + +type Config struct { + Repository notification.Repository + Webhook webhook.Handler + Logger *slog.Logger + ReconcileInterval time.Duration +} + +func (c *Config) Validate() error { + if c.Repository == nil { + return fmt.Errorf("repository is required") + } + + if c.Webhook == nil { + return fmt.Errorf("webhook is required") + } + + return nil +} + +var _ notification.EventHandler = (*Handler)(nil) + +type Handler struct { + repo notification.Repository + webhook webhook.Handler + logger *slog.Logger + + reconcileInterval time.Duration + + stopCh chan struct{} +} + +func (h *Handler) Start() error { + go func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ticker := time.NewTicker(h.reconcileInterval) + defer ticker.Stop() + + logger := h.logger.WithGroup("reconciler") + + for { + select { + case <-h.stopCh: + logger.Debug("close event received: stopping reconciler") + return + case <-ticker.C: + if err := h.Reconcile(ctx); err != nil { + logger.Error("failed to reconcile event(s)", "error", err) + } + } + } + }() + + return nil +} + +func (h *Handler) Close() error { + close(h.stopCh) + + return nil +} + +func (h *Handler) reconcilePending(ctx context.Context, event *notification.Event) error { + return h.dispatch(ctx, event) +} + +func (h *Handler) reconcileSending(_ context.Context, _ *notification.Event) error { + // NOTE(chrisgacsal): implement when EventDeliveryStatusStateSending state is need to be handled + return nil +} + +func (h *Handler) reconcileFailed(_ context.Context, _ *notification.Event) error { + // NOTE(chrisgacsal): reconcile failed events when adding support for retry on event delivery failure + return nil +} + +func (h *Handler) Reconcile(ctx context.Context) error { + events, err := h.repo.ListEvents(ctx, notification.ListEventsInput{ + Page: pagination.Page{}, + DeliveryStatusStates: []notification.EventDeliveryStatusState{ + notification.EventDeliveryStatusStatePending, + notification.EventDeliveryStatusStateSending, + }, + }) + if err != nil { + return fmt.Errorf("failed to fetch notification delivery statuses for reconciliation: %w", err) + } + + for _, event := range events.Items { + var errs error + for _, state := range notification.DeliveryStatusStates(event.DeliveryStatus) { + switch state { + case notification.EventDeliveryStatusStatePending: + if err = h.reconcilePending(ctx, &event); err != nil { + errs = errors.Join(errs, err) + } + case notification.EventDeliveryStatusStateSending: + if err = h.reconcileSending(ctx, &event); err != nil { + errs = errors.Join(errs, err) + } + case notification.EventDeliveryStatusStateFailed: + if err = h.reconcileFailed(ctx, &event); err != nil { + errs = errors.Join(errs, err) + } + } + } + + if errs != nil { + return fmt.Errorf("failed to reconcile notification event: %w", errs) + } + } + + return nil +} + +func (h *Handler) dispatchWebhook(ctx context.Context, event *notification.Event) error { + sendIn := webhook.SendMessageInput{ + Namespace: event.Namespace, + EventID: event.ID, + EventType: string(event.Type), + Channels: []string{event.Rule.ID}, + } + + switch event.Type { + case notification.EventTypeBalanceThreshold: + payload := event.Payload.AsNotificationEventBalanceThresholdPayload(event.ID, event.CreatedAt) + payloadMap, err := notification.PayloadToMapInterface(payload) + if err != nil { + return fmt.Errorf("failed to cast event payload: %w", err) + } + + sendIn.Payload = payloadMap + default: + return fmt.Errorf("unknown event type: %s", event.Type) + } + + logger := h.logger.With("eventID", event.ID, "eventType", event.Type) + + var stateReason string + state := notification.EventDeliveryStatusStateSuccess + _, err := h.webhook.SendMessage(ctx, sendIn) + if err != nil { + logger.Error("failed to send webhook message: error returned by webhook service", "error", err) + stateReason = "failed to send webhook message: error returned by webhook service" + state = notification.EventDeliveryStatusStateFailed + } + + for _, channelID := range notification.ChannelIDsByType(event.Rule.Channels, notification.ChannelTypeWebhook) { + _, err = h.repo.UpdateEventDeliveryStatus(ctx, notification.UpdateEventDeliveryStatusInput{ + NamespacedModel: models.NamespacedModel{ + Namespace: event.Namespace, + }, + State: state, + Reason: stateReason, + EventID: event.ID, + ChannelID: channelID, + }) + if err != nil { + return fmt.Errorf("failed to update event delivery: %w", err) + } + } + + return nil +} + +func (h *Handler) dispatch(ctx context.Context, event *notification.Event) error { + var errs error + + for _, channelType := range notification.ChannelTypes(event.Rule.Channels) { + var err error + + switch channelType { + case notification.ChannelTypeWebhook: + err = h.dispatchWebhook(ctx, event) + default: + err = fmt.Errorf("unknown channel type: %s", channelType) + } + + if err != nil { + errs = errors.Join(errs, err) + } + } + + return errs +} + +func (h *Handler) Dispatch(event *notification.Event) error { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), notification.DefaultDispatchTimeout) + defer cancel() + + if err := h.dispatch(ctx, event); err != nil { + h.logger.Warn("failed to dispatch event", "eventID", event.ID, "error", err) + } + }() + + return nil +} + +func New(config Config) (*Handler, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + if config.ReconcileInterval == 0 { + config.ReconcileInterval = notification.DefaultReconcileInterval + } + + if config.Logger == nil { + config.Logger = slog.Default() + } + + return &Handler{ + repo: config.Repository, + webhook: config.Webhook, + reconcileInterval: config.ReconcileInterval, + logger: config.Logger, + stopCh: make(chan struct{}), + }, nil +} diff --git a/openmeter/notification/service/service.go b/openmeter/notification/service/service.go index 0df200c28..f1a96faf5 100644 --- a/openmeter/notification/service/service.go +++ b/openmeter/notification/service/service.go @@ -7,6 +7,7 @@ import ( "log/slog" "github.com/openmeterio/openmeter/openmeter/notification" + "github.com/openmeterio/openmeter/openmeter/notification/eventhandler" "github.com/openmeterio/openmeter/openmeter/notification/webhook" "github.com/openmeterio/openmeter/openmeter/productcatalog" ) @@ -59,7 +60,7 @@ func New(config Config) (*Service, error) { } config.Logger = config.Logger.WithGroup("notification") - eventHandler, err := notification.NewEventHandler(notification.EventHandlerConfig{ + eventHandler, err := eventhandler.New(eventhandler.Config{ Repository: config.Repository, Webhook: config.Webhook, Logger: config.Logger,