diff --git a/app/common/billing.go b/app/common/billing.go index 6d65beb97..d1e8f5bf2 100644 --- a/app/common/billing.go +++ b/app/common/billing.go @@ -13,6 +13,7 @@ import ( billingservice "github.com/openmeterio/openmeter/openmeter/billing/service" billingsubscription "github.com/openmeterio/openmeter/openmeter/billing/subscription" billingworkerautoadvance "github.com/openmeterio/openmeter/openmeter/billing/worker/advance" + billingworkercollect "github.com/openmeterio/openmeter/openmeter/billing/worker/collect" "github.com/openmeterio/openmeter/openmeter/customer" entdb "github.com/openmeterio/openmeter/openmeter/ent/db" "github.com/openmeterio/openmeter/openmeter/meter" @@ -82,3 +83,10 @@ func NewBillingAutoAdvancer(logger *slog.Logger, service billing.Service) (*bill Logger: logger, }) } + +func NewBillingCollector(logger *slog.Logger, service billing.Service) (*billingworkercollect.InvoiceCollector, error) { + return billingworkercollect.NewInvoiceCollector(billingworkercollect.Config{ + BillingService: service, + Logger: logger, + }) +} diff --git a/cmd/jobs/billing/billing.go b/cmd/jobs/billing/billing.go index 268bc39e1..d24505953 100644 --- a/cmd/jobs/billing/billing.go +++ b/cmd/jobs/billing/billing.go @@ -4,6 +4,7 @@ import ( "github.com/spf13/cobra" "github.com/openmeterio/openmeter/cmd/jobs/billing/advance" + "github.com/openmeterio/openmeter/cmd/jobs/billing/collect" ) var Cmd = &cobra.Command{ @@ -13,4 +14,5 @@ var Cmd = &cobra.Command{ func init() { Cmd.AddCommand(advance.Cmd) + Cmd.AddCommand(collect.Cmd) } diff --git a/cmd/jobs/billing/collect/collect.go b/cmd/jobs/billing/collect/collect.go new file mode 100644 index 000000000..ed9ccdd48 --- /dev/null +++ b/cmd/jobs/billing/collect/collect.go @@ -0,0 +1,102 @@ +package collect + +import ( + "fmt" + "time" + + "github.com/spf13/cobra" + + "github.com/openmeterio/openmeter/cmd/jobs/internal" + billingworkercollect "github.com/openmeterio/openmeter/openmeter/billing/worker/collect" +) + +var ( + namespaces []string + customerIDs []string + invoiceIDs []string +) + +var Cmd = &cobra.Command{ + Use: "collect", + Short: "Invoice collection operations", +} + +func init() { + Cmd.AddCommand(ListCmd()) + Cmd.AddCommand(InvoiceCmd()) + Cmd.AddCommand(AllCmd()) +} + +var ListCmd = func() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "List gathering invoices which can be collected", + RunE: func(cmd *cobra.Command, args []string) error { + invoices, err := internal.App.BillingCollector.ListCollectableInvoices(cmd.Context(), + billingworkercollect.ListCollectableInvoicesInput{ + Namespaces: namespaces, + InvoiceIDs: invoiceIDs, + Customers: customerIDs, + CollectionAt: time.Now(), + }) + if err != nil { + return err + } + + for _, invoice := range invoices { + fmt.Printf("Namespace: %s ID: %s CollectAt: %s\n", invoice.Namespace, invoice.ID, invoice.CollectionAt) + } + + return nil + }, + } + + cmd.PersistentFlags().StringSliceVar(&namespaces, "n", nil, "filter by namespaces") + cmd.PersistentFlags().StringSliceVar(&customerIDs, "c", nil, "filter by customer ids") + cmd.PersistentFlags().StringSliceVar(&invoiceIDs, "i", nil, "filter by invoice ids") + + return cmd +} + +var InvoiceCmd = func() *cobra.Command { + cmd := &cobra.Command{ + Use: "invoice [CUSTOMER_ID]", + Short: "Create new invoice(s) for customer(s)", + Args: cobra.MinimumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + for _, customerID := range args { + _, err := internal.App.BillingCollector.CollectCustomerInvoice(cmd.Context(), + billingworkercollect.CollectCustomerInvoiceInput{ + CustomerID: customerID, + AsOf: nil, + }, + ) + if err != nil { + return fmt.Errorf("failed to advance invoice %s: %w", customerID, err) + } + } + + return nil + }, + } + + return cmd +} + +var batchSize int + +var AllCmd = func() *cobra.Command { + cmd := &cobra.Command{ + Use: "all", + Short: "Advance all eligible invoices", + RunE: func(cmd *cobra.Command, args []string) error { + return internal.App.BillingCollector.All(cmd.Context(), namespaces, customerIDs, batchSize) + }, + } + + cmd.PersistentFlags().StringSliceVar(&namespaces, "n", nil, "filter by namespaces") + cmd.PersistentFlags().StringSliceVar(&customerIDs, "c", nil, "filter by customer ids") + cmd.PersistentFlags().IntVar(&batchSize, "batch", 0, "operation batch size") + + return cmd +} diff --git a/cmd/jobs/internal/wire.go b/cmd/jobs/internal/wire.go index 4f63fbcbb..97bad9535 100644 --- a/cmd/jobs/internal/wire.go +++ b/cmd/jobs/internal/wire.go @@ -17,6 +17,7 @@ import ( appstripe "github.com/openmeterio/openmeter/openmeter/app/stripe" "github.com/openmeterio/openmeter/openmeter/billing" billingworkerautoadvance "github.com/openmeterio/openmeter/openmeter/billing/worker/advance" + billingworkercollect "github.com/openmeterio/openmeter/openmeter/billing/worker/collect" "github.com/openmeterio/openmeter/openmeter/customer" "github.com/openmeterio/openmeter/openmeter/ent/db" "github.com/openmeterio/openmeter/openmeter/meter" @@ -40,6 +41,7 @@ type Application struct { Customer customer.Service Billing billing.Service BillingAutoAdvancer *billingworkerautoadvance.AutoAdvancer + BillingCollector *billingworkercollect.InvoiceCollector EntClient *db.Client EventPublisher eventbus.Publisher EntitlementRegistry *registry.Entitlement @@ -73,6 +75,7 @@ func initializeApplication(ctx context.Context, conf config.Configuration) (Appl common.MeterInMemory, common.Namespace, common.NewBillingAutoAdvancer, + common.NewBillingCollector, common.NewDefaultTextMapPropagator, common.NewServerPublisher, common.Streaming, diff --git a/cmd/jobs/internal/wire_gen.go b/cmd/jobs/internal/wire_gen.go index c023349c9..19113ab4a 100644 --- a/cmd/jobs/internal/wire_gen.go +++ b/cmd/jobs/internal/wire_gen.go @@ -15,6 +15,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/app/stripe" "github.com/openmeterio/openmeter/openmeter/billing" "github.com/openmeterio/openmeter/openmeter/billing/worker/advance" + "github.com/openmeterio/openmeter/openmeter/billing/worker/collect" "github.com/openmeterio/openmeter/openmeter/customer" "github.com/openmeterio/openmeter/openmeter/ent/db" "github.com/openmeterio/openmeter/openmeter/meter" @@ -268,6 +269,16 @@ func initializeApplication(ctx context.Context, conf config.Configuration) (Appl cleanup() return Application{}, nil, err } + invoiceCollector, err := common.NewBillingCollector(logger, billingService) + if err != nil { + cleanup6() + cleanup5() + cleanup4() + cleanup3() + cleanup2() + cleanup() + return Application{}, nil, err + } producer, err := common.NewKafkaProducer(kafkaIngestConfiguration, logger) if err != nil { cleanup6() @@ -328,6 +339,7 @@ func initializeApplication(ctx context.Context, conf config.Configuration) (Appl Customer: customerService, Billing: billingService, BillingAutoAdvancer: autoAdvancer, + BillingCollector: invoiceCollector, EntClient: client, EventPublisher: eventbusPublisher, EntitlementRegistry: entitlement, @@ -366,6 +378,7 @@ type Application struct { Customer customer.Service Billing billing.Service BillingAutoAdvancer *billingworkeradvance.AutoAdvancer + BillingCollector *billingworkercollect.InvoiceCollector EntClient *db.Client EventPublisher eventbus.Publisher EntitlementRegistry *registry.Entitlement diff --git a/openmeter/billing/service/collectionat.go b/openmeter/billing/service/collectionat.go index 6f3ca45c4..34286febd 100644 --- a/openmeter/billing/service/collectionat.go +++ b/openmeter/billing/service/collectionat.go @@ -17,10 +17,33 @@ func UpdateInvoiceCollectionAt(invoice *billing.Invoice, collection billing.Coll return false } + invoiceAt := GetEarliestValidInvoiceAt(invoice.Lines) + + if invoiceAt.IsZero() { + return false + } + + interval, ok := collection.Interval.Duration() + if !ok { + return false + } + + collectionAt := invoiceAt.Add(interval) + + if lo.FromPtr(invoice.CollectionAt).Equal(collectionAt) { + return false + } + + invoice.CollectionAt = &collectionAt + + return true +} + +func GetEarliestValidInvoiceAt(lines billing.LineChildren) time.Time { var invoiceAt time.Time // Find the invoice lint with the earliest invoiceAt attribute - invoice.Lines.ForEach(func(v []*billing.Line) { + lines.ForEach(func(v []*billing.Line) { for _, line := range v { if line == nil || line.Status != billing.InvoiceLineStatusValid { continue @@ -41,22 +64,19 @@ func UpdateInvoiceCollectionAt(invoice *billing.Invoice, collection billing.Coll } }) - if invoiceAt.IsZero() { - return false - } + return invoiceAt +} - interval, ok := collection.Interval.Duration() - if !ok { - return false - } +func GetInvoiceWithEarliestCollectionAt(invoices []billing.Invoice) billing.Invoice { + var idx int - collectionAt := invoiceAt.Add(interval) + collectAt := time.Now() - if lo.FromPtr(invoice.CollectionAt).Equal(collectionAt) { - return false + for i, invoice := range invoices { + if invoice.CollectionAt.Before(collectAt) { + idx = i + } } - invoice.CollectionAt = &collectionAt - - return true + return invoices[idx] } diff --git a/openmeter/billing/service/invoice.go b/openmeter/billing/service/invoice.go index 9084d72fe..c7736be38 100644 --- a/openmeter/billing/service/invoice.go +++ b/openmeter/billing/service/invoice.go @@ -334,23 +334,78 @@ func (s *Service) InvoicePendingLines(ctx context.Context, input billing.Invoice return nil, fmt.Errorf("cleanup: line counts check: %w", err) } - invoicesWithoutLines := lo.Filter(sourceInvoiceIDs, func(id string, _ int) bool { - return invoiceLineCounts.Counts[billing.InvoiceID{ + // Collect gathering invoices which can be deleted and which needs to have their collectionAt updated + // due to still having live items. + liveGatheringInvoiceIDs := make([]string, 0, len(sourceInvoiceIDs)) + emptyGatheringInvoiceIDs := make([]string, 0, len(sourceInvoiceIDs)) + + for _, invoiceID := range sourceInvoiceIDs { + invoiceNamespacedID := billing.InvoiceID{ Namespace: input.Customer.Namespace, - ID: id, - }] == 0 - }) + ID: invoiceID, + } + + if invoiceLineCounts.Counts[invoiceNamespacedID] == 0 { + emptyGatheringInvoiceIDs = append(emptyGatheringInvoiceIDs, invoiceID) + } else { + liveGatheringInvoiceIDs = append(liveGatheringInvoiceIDs, invoiceID) + } + } - if len(invoicesWithoutLines) > 0 { + // Delete empty gathering invoices + if len(emptyGatheringInvoiceIDs) > 0 { err = s.adapter.DeleteInvoices(ctx, billing.DeleteInvoicesAdapterInput{ Namespace: input.Customer.Namespace, - InvoiceIDs: invoicesWithoutLines, + InvoiceIDs: emptyGatheringInvoiceIDs, }) if err != nil { return nil, fmt.Errorf("cleanup invoices: %w", err) } } + // Update collectionAt for live gathering invoices + if len(liveGatheringInvoiceIDs) > 0 { + resp, err := s.ListInvoices(ctx, billing.ListInvoicesInput{ + Customers: []string{input.Customer.ID}, + IDs: liveGatheringInvoiceIDs, + ExtendedStatuses: []billing.InvoiceStatus{billing.InvoiceStatusGathering}, + Expand: billing.InvoiceExpand{ + Lines: true, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to get gathering invoice(s) for customer [customer=%s]: %w", + input.Customer.ID, err, + ) + } + + for _, invoice := range resp.Items { + collectionAt := invoice.CollectionAt + if ok := UpdateInvoiceCollectionAt(&invoice, customerProfile.Profile.WorkflowConfig.Collection); ok { + s.logger.DebugContext(ctx, "collection time updated for invoice", + "invoiceID", invoice.ID, + "collectionAt", map[string]interface{}{ + "from": lo.FromPtr(collectionAt), + "to": lo.FromPtr(invoice.CollectionAt), + "collectionInterval": customerProfile.Profile.WorkflowConfig.Collection.Interval.String(), + }, + ) + } + + if err = invoice.Validate(); err != nil { + return nil, billing.ValidationError{ + Err: err, + } + } + + if _, err = s.updateInvoice(ctx, invoice); err != nil { + return nil, fmt.Errorf("failed to update gathering invoice [namespace=%s invoice=%s, customer=%s]: %w", + input.Customer.Namespace, invoice.ID, input.Customer.ID, err, + ) + } + } + } + // Assemble output: we need to refetch as the association call will have side-effects of updating // invoice objects (e.g. totals, period, etc.) out := make([]billing.Invoice, 0, len(createdInvoices)) diff --git a/openmeter/billing/worker/collect/collect.go b/openmeter/billing/worker/collect/collect.go new file mode 100644 index 000000000..9206ee2b2 --- /dev/null +++ b/openmeter/billing/worker/collect/collect.go @@ -0,0 +1,205 @@ +package billingworkercollect + +import ( + "context" + "errors" + "fmt" + "log/slog" + "sync" + "time" + + "github.com/samber/lo" + + "github.com/openmeterio/openmeter/openmeter/billing" + billingservice "github.com/openmeterio/openmeter/openmeter/billing/service" + customerentity "github.com/openmeterio/openmeter/openmeter/customer/entity" +) + +type InvoiceCollector struct { + billing billing.Service + + logger *slog.Logger +} + +type ListCollectableInvoicesInput struct { + Namespaces []string + InvoiceIDs []string + Customers []string + CollectionAt time.Time +} + +func (i ListCollectableInvoicesInput) Validate() error { + var errs []error + + if i.CollectionAt.IsZero() { + errs = append(errs, fmt.Errorf("collectionAt time must not be zero")) + } + + return errors.Join(errs...) +} + +func (a *InvoiceCollector) ListCollectableInvoices(ctx context.Context, params ListCollectableInvoicesInput) ([]billing.Invoice, error) { + if err := params.Validate(); err != nil { + return nil, fmt.Errorf("invalid input: %w", err) + } + + resp, err := a.billing.ListInvoices(ctx, billing.ListInvoicesInput{ + Namespaces: params.Namespaces, + IDs: params.InvoiceIDs, + Customers: params.Customers, + CollectionAt: lo.ToPtr(params.CollectionAt), + ExtendedStatuses: []billing.InvoiceStatus{billing.InvoiceStatusGathering}, + }) + if err != nil { + return nil, fmt.Errorf("failed to list collectable invoices: %w", err) + } + + return resp.Items, nil +} + +type CollectCustomerInvoiceInput struct { + CustomerID string + AsOf *time.Time +} + +func (i CollectCustomerInvoiceInput) Validate() error { + var errs []error + + if i.CustomerID == "" { + errs = append(errs, fmt.Errorf("customer id must not be empty")) + } + + if i.AsOf != nil && i.AsOf.IsZero() { + errs = append(errs, fmt.Errorf("asOf time must not be zero")) + } + + return errors.Join(errs...) +} + +func (a *InvoiceCollector) CollectCustomerInvoice(ctx context.Context, params CollectCustomerInvoiceInput) ([]billing.Invoice, error) { + if err := params.Validate(); err != nil { + return nil, fmt.Errorf("invalid input: %w", err) + } + + resp, err := a.billing.ListInvoices(ctx, billing.ListInvoicesInput{ + Customers: []string{params.CustomerID}, + ExtendedStatuses: []billing.InvoiceStatus{billing.InvoiceStatusGathering}, + Expand: billing.InvoiceExpand{ + Lines: true, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to get gathering invoice(s) for customer [customer=%s]: %w", params.CustomerID, err) + } + + if len(resp.Items) == 0 { + return nil, nil + } + + invoice := resp.Items[0] + if params.AsOf == nil || params.AsOf.IsZero() { + invoice = billingservice.GetInvoiceWithEarliestCollectionAt(resp.Items) + params.AsOf = lo.ToPtr(billingservice.GetEarliestValidInvoiceAt(invoice.Lines)) + } + + a.logger.DebugContext(ctx, "collecting customer invoices", "customer", params.CustomerID, "asOf", params.AsOf) + + invoices, err := a.billing.InvoicePendingLines(ctx, billing.InvoicePendingLinesInput{ + Customer: customerentity.CustomerID{ + Namespace: invoice.Namespace, + ID: invoice.Customer.CustomerID, + }, + AsOf: params.AsOf, + }) + if err != nil { + return nil, fmt.Errorf("failed to create invoice(s) for customer [customer=%s]: %w", params.CustomerID, err) + } + + return invoices, nil +} + +// All runs invoice collection for all customers +func (a *InvoiceCollector) All(ctx context.Context, namespaces []string, customerIDs []string, batchSize int) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + a.logger.InfoContext(ctx, "listing invoices waiting for collection") + + invoices, err := a.ListCollectableInvoices(ctx, ListCollectableInvoicesInput{ + Namespaces: namespaces, + Customers: customerIDs, + CollectionAt: time.Now(), + }) + if err != nil { + return fmt.Errorf("failed to list invoices to collect: %w", err) + } + + if len(invoices) == 0 { + return nil + } + + batches := [][]billing.Invoice{ + invoices, + } + if batchSize > 0 { + batches = lo.Chunk(invoices, batchSize) + } + + a.logger.DebugContext(ctx, "found invoices to collect", "count", len(invoices), "batchSize", batchSize) + + errChan := make(chan error, len(invoices)) + closeErrChan := sync.OnceFunc(func() { + close(errChan) + }) + defer closeErrChan() + + for _, batch := range batches { + var wg sync.WaitGroup + for _, invoice := range batch { + wg.Add(1) + + go func() { + defer wg.Done() + + _, err = a.CollectCustomerInvoice(ctx, CollectCustomerInvoiceInput{ + CustomerID: invoice.Customer.CustomerID, + }) + if err != nil { + err = fmt.Errorf("failed to collect invoice for customer [namespace=%s invoice=%s customer=%s]: %w", invoice.Namespace, invoice.ID, invoice.Customer.CustomerID, err) + } + + errChan <- err + }() + } + + wg.Wait() + } + closeErrChan() + + var errs []error + for err = range errChan { + errs = append(errs, err) + } + + return errors.Join(errs...) +} + +type Config struct { + BillingService billing.Service + Logger *slog.Logger +} + +func NewInvoiceCollector(config Config) (*InvoiceCollector, error) { + if config.BillingService == nil { + return nil, fmt.Errorf("billing service is required") + } + + if config.Logger == nil { + return nil, fmt.Errorf("logger is required") + } + + return &InvoiceCollector{ + billing: config.BillingService, + logger: config.Logger, + }, nil +}