diff --git a/cmd/root.go b/cmd/root.go
index 9169dfb..6e73409 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -7,6 +7,7 @@ import (
 	"os/signal"
 	"strings"
 	"syscall"
+	"time"
 
 	"github.com/rs/zerolog"
 	"github.com/rs/zerolog/log"
@@ -65,6 +66,12 @@ var (
 		Short: "Start the worker with webhook sink",
 		Run:   runWorker,
 	}
+
+	replayCmd = &cobra.Command{
+		Use:   "replay-worker",
+		Short: "Replay historical changes",
+		Run:   runReplay,
+	}
 )
 
 func Execute() error {
@@ -94,6 +101,9 @@ func init() {
 	replicatorCmd.Flags().Bool("copy-and-stream", false, "Enable copy and stream mode (env: PG_FLO_COPY_AND_STREAM)")
 	replicatorCmd.Flags().Int("max-copy-workers-per-table", 4, "Maximum number of copy workers per table (env: PG_FLO_MAX_COPY_WORKERS_PER_TABLE)")
 	replicatorCmd.Flags().Bool("track-ddl", false, "Enable tracking of DDL changes (env: PG_FLO_TRACK_DDL)")
+	replicatorCmd.Flags().Duration("max-age", 24*time.Hour, "Maximum age of messages to retain (env: PG_FLO_MAX_AGE)")
+	replicatorCmd.Flags().Bool("enable-replays", false, "Enable replay capability - requires more storage (env: PG_FLO_ENABLE_REPLAYS)")
+	replicatorCmd.Flags().Int("nats-replicas", 1, "Number of stream replicas (env: PG_FLO_NATS_REPLICAS)")
 
 	markFlagRequired(replicatorCmd, "host", "port", "dbname", "user", "password", "group", "nats-url")
 
@@ -133,7 +143,17 @@ func init() {
 	// Add subcommands to worker command
 	workerCmd.AddCommand(stdoutWorkerCmd, fileWorkerCmd, postgresWorkerCmd, webhookWorkerCmd)
 
-	rootCmd.AddCommand(replicatorCmd, workerCmd)
+	// Add replay command
+	replayCmd.Flags().String("start-time", "", "Start time for replay (RFC3339 format)")
+	replayCmd.Flags().String("end-time", "", "End time for replay (RFC3339 format)")
+	replayCmd.Flags().String("group", "", "Group name for worker (env: PG_FLO_GROUP)")
+	replayCmd.Flags().String("nats-url", "", "NATS server URL (env: PG_FLO_NATS_URL)")
+	replayCmd.Flags().String("rules-config", "", "Path to rules configuration file (env: PG_FLO_RULES_CONFIG)")
+	replayCmd.Flags().String("routing-config", "", "Path to routing configuration file (env: PG_FLO_ROUTING_CONFIG)")
+
+	markFlagRequired(replayCmd, "start-time", "end-time", "group", "nats-url")
+
+	rootCmd.AddCommand(replicatorCmd, workerCmd, replayCmd)
 }
 
 func initConfig() {
@@ -157,6 +177,7 @@ func initConfig() {
 	bindFlags(fileWorkerCmd)
 	bindFlags(postgresWorkerCmd)
 	bindFlags(webhookWorkerCmd)
+	bindFlags(replayCmd)
 
 	if err := viper.ReadInConfig(); err == nil {
 		fmt.Println("Using config file:", viper.ConfigFileUsed())
@@ -192,7 +213,13 @@ func runReplicator(_ *cobra.Command, _ []string) {
 		log.Fatal().Msg("NATS URL is required")
 	}
 
-	natsClient, err := pgflonats.NewNATSClient(natsURL, fmt.Sprintf("pgflo_%s_stream", config.Group), config.Group)
+	natsConfig := pgflonats.StreamConfig{
+		MaxAge:   viper.GetDuration("max-age"),
+		Replays:  viper.GetBool("enable-replays"),
+		Replicas: viper.GetInt("nats-replicas"),
+	}
+
+	natsClient, err := pgflonats.NewNATSClient(natsURL, fmt.Sprintf("pgflo_%s_stream", config.Group), config.Group, natsConfig)
 	if err != nil {
 		log.Fatal().Err(err).Msg("Failed to create NATS client")
 	}
@@ -222,7 +249,8 @@ func runWorker(cmd *cobra.Command, _ []string) {
 	sinkType := cmd.Use
 
 	// Create NATS client
-	natsClient, err := pgflonats.NewNATSClient(natsURL, fmt.Sprintf("pgflo_%s_stream", group), group)
+	natsConfig := pgflonats.DefaultStreamConfig()
+	natsClient, err := pgflonats.NewNATSClient(natsURL, fmt.Sprintf("pgflo_%s_stream", group), group, natsConfig)
 	if err != nil {
 		log.Fatal().Err(err).Msg("Failed to create NATS client")
 	}
@@ -352,3 +380,69 @@ func markPersistentFlagRequired(cmd *cobra.Command, flags ...string) {
 		}
 	}
 }
+
+func runReplay(cmd *cobra.Command, _ []string) {
+	startTimeStr := viper.GetString("start-time")
+	endTimeStr := viper.GetString("end-time")
+
+	startTime, err := time.Parse(time.RFC3339, startTimeStr)
+	if err != nil {
+		log.Fatal().Err(err).Msg("Invalid start time format")
+	}
+
+	endTime, err := time.Parse(time.RFC3339, endTimeStr)
+	if err != nil {
+		log.Fatal().Err(err).Msg("Invalid end time format")
+	}
+
+	// Create NATS client with replay config
+	natsConfig := pgflonats.StreamConfig{
+		MaxAge:  viper.GetDuration("max-age"),
+		Replays: viper.GetBool("enable-replays"),
+	}
+
+	natsClient, err := pgflonats.NewNATSClient(
+		viper.GetString("nats-url"),
+		fmt.Sprintf("pgflo_%s_stream", viper.GetString("group")),
+		viper.GetString("group"),
+		natsConfig,
+	)
+	if err != nil {
+		log.Fatal().Err(err).Msg("Failed to create NATS client")
+	}
+
+	ruleEngine := rules.NewRuleEngine()
+	if viper.GetString("rules-config") != "" {
+		rulesConfig, err := loadRulesConfig(viper.GetString("rules-config"))
+		if err != nil {
+			log.Fatal().Err(err).Msg("Failed to load rules configuration")
+		}
+		if err := ruleEngine.LoadRules(rulesConfig); err != nil {
+			log.Fatal().Err(err).Msg("Failed to load rules")
+		}
+	}
+
+	router := routing.NewRouter()
+	if viper.GetString("routing-config") != "" {
+		routingConfig, err := loadRoutingConfig(viper.GetString("routing-config"))
+		if err != nil {
+			log.Fatal().Err(err).Msg("Failed to load routing configuration")
+		}
+		if err := router.LoadRoutes(routingConfig); err != nil {
+			log.Fatal().Err(err).Msg("Failed to load routes")
+		}
+	}
+
+	sink, err := createSink(cmd.Use)
+	if err != nil {
+		log.Fatal().Err(err).Msg("Failed to create sink")
+	}
+
+	w := worker.NewWorker(natsClient, ruleEngine, router, sink, viper.GetString("group"))
+
+	replayWorker := worker.NewReplayWorker(w, startTime, endTime)
+
+	if err := replayWorker.Start(cmd.Context()); err != nil {
+		log.Fatal().Err(err).Msg("Failed to start replay worker")
+	}
+}
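For reference, the replay flags added above are parsed with `time.RFC3339` in `runReplay`. A small illustrative snippet (not part of this change) showing how a valid one-hour window could be produced:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// Replay the previous hour: both flags expect RFC3339 timestamps.
	end := time.Now().UTC()
	start := end.Add(-1 * time.Hour)
	fmt.Printf("--start-time %s --end-time %s\n",
		start.Format(time.RFC3339), end.Format(time.RFC3339))
	// Example output: --start-time 2024-01-15T09:00:00Z --end-time 2024-01-15T10:00:00Z
}
```
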
client") } @@ -352,3 +380,69 @@ func markPersistentFlagRequired(cmd *cobra.Command, flags ...string) { } } } + +func runReplay(cmd *cobra.Command, _ []string) { + startTimeStr := viper.GetString("start-time") + endTimeStr := viper.GetString("end-time") + + startTime, err := time.Parse(time.RFC3339, startTimeStr) + if err != nil { + log.Fatal().Err(err).Msg("Invalid start time format") + } + + endTime, err := time.Parse(time.RFC3339, endTimeStr) + if err != nil { + log.Fatal().Err(err).Msg("Invalid end time format") + } + + // Create NATS client with replay config + natsConfig := pgflonats.StreamConfig{ + MaxAge: viper.GetDuration("max-age"), + Replays: viper.GetBool("enable-replays"), + } + + natsClient, err := pgflonats.NewNATSClient( + viper.GetString("nats-url"), + fmt.Sprintf("pgflo_%s_stream", viper.GetString("group")), + viper.GetString("group"), + natsConfig, + ) + if err != nil { + log.Fatal().Err(err).Msg("Failed to create NATS client") + } + + ruleEngine := rules.NewRuleEngine() + if viper.GetString("rules-config") != "" { + rulesConfig, err := loadRulesConfig(viper.GetString("rules-config")) + if err != nil { + log.Fatal().Err(err).Msg("Failed to load rules configuration") + } + if err := ruleEngine.LoadRules(rulesConfig); err != nil { + log.Fatal().Err(err).Msg("Failed to load rules") + } + } + + router := routing.NewRouter() + if viper.GetString("routing-config") != "" { + routingConfig, err := loadRoutingConfig(viper.GetString("routing-config")) + if err != nil { + log.Fatal().Err(err).Msg("Failed to load routing configuration") + } + if err := router.LoadRoutes(routingConfig); err != nil { + log.Fatal().Err(err).Msg("Failed to load routes") + } + } + + sink, err := createSink(cmd.Use) + if err != nil { + log.Fatal().Err(err).Msg("Failed to create sink") + } + + w := worker.NewWorker(natsClient, ruleEngine, router, sink, viper.GetString("group")) + + replayWorker := worker.NewReplayWorker(w, startTime, endTime) + + if err := replayWorker.Start(cmd.Context()); err != nil { + log.Fatal().Err(err).Msg("Failed to start replay worker") + } +} diff --git a/pkg/pgflonats/pgflonats.go b/pkg/pgflonats/pgflonats.go index b0f2d1c..98e5eb1 100644 --- a/pkg/pgflonats/pgflonats.go +++ b/pkg/pgflonats/pgflonats.go @@ -14,14 +14,35 @@ import ( const ( defaultNATSURL = "nats://localhost:4222" envNATSURL = "PG_FLO_NATS_URL" + + defaultMaxAge = 24 * time.Hour + defaultReplicas = 1 + defaultReplays = false ) +// StreamConfig represents the configuration for a NATS stream +type StreamConfig struct { + MaxAge time.Duration + Replays bool + Replicas int +} + +// DefaultStreamConfig returns a StreamConfig with default values +func DefaultStreamConfig() StreamConfig { + return StreamConfig{ + MaxAge: defaultMaxAge, + Replays: defaultReplays, + Replicas: defaultReplicas, + } +} + // NATSClient represents a client for interacting with NATS type NATSClient struct { conn *nats.Conn js nats.JetStreamContext stream string stateBucket string + config StreamConfig } // State represents the current state of the replication process @@ -31,7 +52,15 @@ type State struct { } // NewNATSClient creates a new NATS client with the specified configuration, setting up the connection, main stream, and state bucket. 
@@ -31,7 +52,15 @@ type State struct {
 }
 
 // NewNATSClient creates a new NATS client with the specified configuration, setting up the connection, main stream, and state bucket.
-func NewNATSClient(url, stream, group string) (*NATSClient, error) {
+func NewNATSClient(url, stream, group string, config StreamConfig) (*NATSClient, error) {
+	if config.MaxAge == 0 {
+		config.MaxAge = defaultMaxAge
+	}
+
+	if config.Replicas == 0 {
+		config.Replicas = defaultReplicas
+	}
+
 	if url == "" {
 		url = os.Getenv(envNATSURL)
 		if url == "" {
@@ -66,13 +95,21 @@ func NewNATSClient(url, stream, group string) (*NATSClient, error) {
 		return nil, fmt.Errorf("failed to create JetStream context: %w", err)
 	}
 
-	// Create the main stream
+	// Create the main stream with configurable retention
 	streamConfig := &nats.StreamConfig{
-		Name:      stream,
-		Subjects:  []string{fmt.Sprintf("pgflo.%s", group)},
-		Storage:   nats.FileStorage,
-		Retention: nats.LimitsPolicy,
-		MaxAge:    24 * time.Hour,
+		Name:        stream,
+		Storage:     nats.FileStorage,
+		Retention:   nats.LimitsPolicy,
+		MaxAge:      config.MaxAge,
+		Replicas:    config.Replicas,
+		Discard:     nats.DiscardOld,
+		Subjects:    []string{fmt.Sprintf("pgflo.%s", group)},
+		Description: fmt.Sprintf("pg_flo stream for group %s", group),
 	}
+	if config.Replays {
+		// Keep LimitsPolicy so acknowledged messages stay in the stream (up to
+		// MaxAge) and remain readable by replay consumers; WorkQueue retention
+		// would delete them on ack and allows only one consumer per subject.
+		// Only lift the message-count cap here.
+		streamConfig.MaxMsgs = -1
+	}
 
 	_, err = js.AddStream(streamConfig)
 	if err != nil && !errors.Is(err, nats.ErrStreamNameAlreadyInUse) {
@@ -100,6 +137,7 @@ func NewNATSClient(url, stream, group string) (*NATSClient, error) {
 		js:          js,
 		stream:      stream,
 		stateBucket: stateBucket,
+		config:      config,
 	}, nil
 }
diff --git a/pkg/worker/replay_worker.go b/pkg/worker/replay_worker.go
new file mode 100644
index 0000000..95723c2
--- /dev/null
+++ b/pkg/worker/replay_worker.go
@@ -0,0 +1,143 @@
+package worker
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/nats-io/nats.go"
+)
+
+type ReplayWorker struct {
+	*Worker
+	startTime time.Time
+	endTime   time.Time
+}
+
+func NewReplayWorker(w *Worker, startTime, endTime time.Time) *ReplayWorker {
+	return &ReplayWorker{
+		Worker:    w,
+		startTime: startTime,
+		endTime:   endTime,
+	}
+}
+
+func (w *ReplayWorker) Start(ctx context.Context) error {
+	stream := fmt.Sprintf("pgflo_%s_stream", w.group)
+	subject := fmt.Sprintf("pgflo.%s", w.group)
+
+	w.logger.Info().
+		Str("stream", stream).
+		Str("subject", subject).
+		Str("group", w.group).
+		Time("start_time", w.startTime).
+		Time("end_time", w.endTime).
+		Msg("Starting replay worker")
+
+	js := w.natsClient.JetStream()
+
+	// Create a unique consumer name for this replay run.
+	// TODO: confirm the UnixNano suffix is unique enough if several replays start concurrently.
+	consumerName := fmt.Sprintf("pgflo_%s_replay_%d", w.group, time.Now().UnixNano())
+
+	consumerConfig := &nats.ConsumerConfig{
+		Durable:       consumerName,
+		FilterSubject: subject,
+		AckPolicy:     nats.AckExplicitPolicy,
+		DeliverPolicy: nats.DeliverByStartTimePolicy,
+		OptStartTime:  &w.startTime,
+		ReplayPolicy:  nats.ReplayOriginalPolicy,
+		MaxDeliver:    1, // Only deliver once
+		AckWait:       30 * time.Second,
+		MaxAckPending: w.batchSize,
+	}
+
+	_, err := js.AddConsumer(stream, consumerConfig)
+	if err != nil {
+		w.logger.Error().Err(err).Msg("Failed to add replay consumer")
+		return fmt.Errorf("failed to add replay consumer: %w", err)
+	}
+
+	// Cleanup consumer when done
+	defer func() {
+		if err := js.DeleteConsumer(stream, consumerName); err != nil {
+			w.logger.Error().Err(err).Msg("Failed to delete replay consumer")
+		}
+	}()
+
+	sub, err := js.PullSubscribe(subject, consumerName)
+	if err != nil {
+		w.logger.Error().Err(err).Msg("Failed to subscribe to subject")
+		return fmt.Errorf("failed to subscribe to subject: %w", err)
+	}
+
+	w.wg.Add(1)
+	go func() {
+		defer w.wg.Done()
+		if err := w.processReplayMessages(ctx, sub); err != nil && err != context.Canceled {
+			w.logger.Error().Err(err).Msg("Error processing replay messages")
+		}
+	}()
+
+	<-ctx.Done()
+	w.logger.Info().Msg("Received shutdown signal. Initiating graceful shutdown...")
+
+	w.wg.Wait()
+	w.logger.Debug().Msg("All goroutines finished")
+
+	return w.flushBuffer()
+}
+
+func (w *ReplayWorker) processReplayMessages(ctx context.Context, sub *nats.Subscription) error {
+	flushTicker := time.NewTicker(w.flushInterval)
+	defer flushTicker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			w.logger.Info().Msg("Flushing remaining messages")
+			return w.flushBuffer()
+		case <-flushTicker.C:
+			if err := w.flushBuffer(); err != nil {
+				w.logger.Error().Err(err).Msg("Failed to flush buffer on interval")
+			}
+		default:
+			msgs, err := sub.Fetch(w.batchSize, nats.MaxWait(500*time.Millisecond))
+			if err != nil {
+				if err == nats.ErrTimeout {
+					continue
+				}
+				w.logger.Error().Err(err).Msg("Error fetching messages")
+				continue
+			}
+
+			for _, msg := range msgs {
+				metadata, err := msg.Metadata()
+				if err != nil {
+					w.logger.Error().Err(err).Msg("Failed to get message metadata")
+					continue
+				}
+
+				// Stop once a message falls past the end of the replay window.
+				// TODO: consider bounding the window on the consumer side instead of checking per message.
+				if metadata.Timestamp.After(w.endTime) {
+					w.logger.Info().Msg("Reached end time, finishing replay")
+					return w.flushBuffer()
+				}
+
+				if err := w.processMessage(msg); err != nil {
+					w.logger.Error().Err(err).Msg("Failed to process message")
+				}
+				if err := msg.Ack(); err != nil {
+					w.logger.Error().Err(err).Msg("Failed to acknowledge message")
+				}
+			}
+
+			if len(w.buffer) >= w.batchSize {
+				if err := w.flushBuffer(); err != nil {
+					w.logger.Error().Err(err).Msg("Failed to flush buffer")
+				}
+			}
+		}
+	}
+}
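A sketch of how the new pieces compose outside the CLI. The import paths, the `sinks` package name, and its `Sink` interface are assumptions; `NewNATSClient`, `DefaultStreamConfig`, `NewWorker`, `NewReplayWorker`, and `Start` are taken from this change.

```go
package example

import (
	"context"
	"fmt"
	"time"

	"github.com/pgflo/pg_flo/pkg/pgflonats" // import paths assumed
	"github.com/pgflo/pg_flo/pkg/routing"
	"github.com/pgflo/pg_flo/pkg/rules"
	"github.com/pgflo/pg_flo/pkg/sinks" // package and Sink interface assumed
	"github.com/pgflo/pg_flo/pkg/worker"
)

// replayWindow is a hypothetical helper: it reuses the regular worker pipeline
// but drives it from the time-bounded JetStream consumer added in this change.
func replayWindow(ctx context.Context, sink sinks.Sink, group, natsURL string, start, end time.Time) error {
	nc, err := pgflonats.NewNATSClient(natsURL, fmt.Sprintf("pgflo_%s_stream", group), group, pgflonats.DefaultStreamConfig())
	if err != nil {
		return err
	}

	w := worker.NewWorker(nc, rules.NewRuleEngine(), routing.NewRouter(), sink, group)
	return worker.NewReplayWorker(w, start, end).Start(ctx)
}
```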