Ability to replay messages between time frame with new or existing rules + routes #22

Open · wants to merge 1 commit into main
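Based on the flags added in this PR, a replay run might be invoked roughly like this (binary name and all values are illustrative, not taken from the diff):

```sh
pg_flo replay-worker \
  --start-time 2024-11-20T00:00:00Z \
  --end-time 2024-11-20T06:00:00Z \
  --group mygroup \
  --nats-url nats://localhost:4222 \
  --rules-config rules.yaml \
  --routing-config routing.yaml
```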
100 changes: 97 additions & 3 deletions cmd/root.go
@@ -7,6 +7,7 @@ import (
"os/signal"
"strings"
"syscall"
"time"

"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@@ -65,6 +66,12 @@ var (
Short: "Start the worker with webhook sink",
Run: runWorker,
}

replayCmd = &cobra.Command{
Use: "replay-worker",
Short: "Replay historical changes",
Run: runReplay,
}
)

func Execute() error {
@@ -94,6 +101,9 @@ func init() {
replicatorCmd.Flags().Bool("copy-and-stream", false, "Enable copy and stream mode (env: PG_FLO_COPY_AND_STREAM)")
replicatorCmd.Flags().Int("max-copy-workers-per-table", 4, "Maximum number of copy workers per table (env: PG_FLO_MAX_COPY_WORKERS_PER_TABLE)")
replicatorCmd.Flags().Bool("track-ddl", false, "Enable tracking of DDL changes (env: PG_FLO_TRACK_DDL)")
replicatorCmd.Flags().Duration("max-age", 24*time.Hour, "Maximum age of messages to retain (env: PG_FLO_MAX_AGE)")
replicatorCmd.Flags().Bool("enable-replays", false, "Enable replay capability - requires more storage (env: PG_FLO_ENABLE_REPLAYS)")
replicatorCmd.Flags().Int("nats-replicas", 1, "Number of stream replicas (env: PG_FLO_NATS_REPLICAS)")

markFlagRequired(replicatorCmd, "host", "port", "dbname", "user", "password", "group", "nats-url")

@@ -133,7 +143,17 @@ func init() {
// Add subcommands to worker command
workerCmd.AddCommand(stdoutWorkerCmd, fileWorkerCmd, postgresWorkerCmd, webhookWorkerCmd)

rootCmd.AddCommand(replicatorCmd, workerCmd)
// Add replay command
replayCmd.Flags().String("start-time", "", "Start time for replay (RFC3339 format)")
replayCmd.Flags().String("end-time", "", "End time for replay (RFC3339 format)")
replayCmd.Flags().String("group", "", "Group name for worker (env: PG_FLO_GROUP)")
replayCmd.Flags().String("nats-url", "", "NATS server URL (env: PG_FLO_NATS_URL)")
replayCmd.Flags().String("rules-config", "", "Path to rules configuration file (env: PG_FLO_RULES_CONFIG)")
replayCmd.Flags().String("routing-config", "", "Path to routing configuration file (env: PG_FLO_ROUTING_CONFIG)")

markFlagRequired(replayCmd, "start-time", "end-time", "group", "nats-url")

rootCmd.AddCommand(replicatorCmd, workerCmd, replayCmd)
}

func initConfig() {
@@ -157,6 +177,7 @@ func initConfig() {
bindFlags(fileWorkerCmd)
bindFlags(postgresWorkerCmd)
bindFlags(webhookWorkerCmd)
bindFlags(replayCmd)

if err := viper.ReadInConfig(); err == nil {
fmt.Println("Using config file:", viper.ConfigFileUsed())
@@ -192,7 +213,13 @@ func runReplicator(_ *cobra.Command, _ []string) {
log.Fatal().Msg("NATS URL is required")
}

natsClient, err := pgflonats.NewNATSClient(natsURL, fmt.Sprintf("pgflo_%s_stream", config.Group), config.Group)
natsConfig := pgflonats.StreamConfig{
MaxAge: viper.GetDuration("max-age"),
Replays: viper.GetBool("enable-replays"),
Replicas: viper.GetInt("nats-replicas"),
}

natsClient, err := pgflonats.NewNATSClient(natsURL, fmt.Sprintf("pgflo_%s_stream", config.Group), config.Group, natsConfig)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create NATS client")
}
@@ -222,7 +249,8 @@ func runWorker(cmd *cobra.Command, _ []string) {
sinkType := cmd.Use

// Create NATS client
natsClient, err := pgflonats.NewNATSClient(natsURL, fmt.Sprintf("pgflo_%s_stream", group), group)
natsConfig := pgflonats.DefaultStreamConfig()
natsClient, err := pgflonats.NewNATSClient(natsURL, fmt.Sprintf("pgflo_%s_stream", group), group, natsConfig)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create NATS client")
}
@@ -352,3 +380,69 @@ func markPersistentFlagRequired(cmd *cobra.Command, flags ...string) {
}
}
}

func runReplay(cmd *cobra.Command, _ []string) {
startTimeStr := viper.GetString("start-time")
endTimeStr := viper.GetString("end-time")

startTime, err := time.Parse(time.RFC3339, startTimeStr)
if err != nil {
log.Fatal().Err(err).Msg("Invalid start time format")
}

endTime, err := time.Parse(time.RFC3339, endTimeStr)
if err != nil {
log.Fatal().Err(err).Msg("Invalid end time format")
}

// Create NATS client with replay config
natsConfig := pgflonats.StreamConfig{
MaxAge: viper.GetDuration("max-age"),
Replays: viper.GetBool("enable-replays"),
}

natsClient, err := pgflonats.NewNATSClient(
viper.GetString("nats-url"),
fmt.Sprintf("pgflo_%s_stream", viper.GetString("group")),
viper.GetString("group"),
natsConfig,
)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create NATS client")
}

ruleEngine := rules.NewRuleEngine()
if viper.GetString("rules-config") != "" {
rulesConfig, err := loadRulesConfig(viper.GetString("rules-config"))
if err != nil {
log.Fatal().Err(err).Msg("Failed to load rules configuration")
}
if err := ruleEngine.LoadRules(rulesConfig); err != nil {
log.Fatal().Err(err).Msg("Failed to load rules")
}
}

router := routing.NewRouter()
if viper.GetString("routing-config") != "" {
routingConfig, err := loadRoutingConfig(viper.GetString("routing-config"))
if err != nil {
log.Fatal().Err(err).Msg("Failed to load routing configuration")
}
if err := router.LoadRoutes(routingConfig); err != nil {
log.Fatal().Err(err).Msg("Failed to load routes")
}
}

sink, err := createSink(cmd.Use)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create sink")
}

w := worker.NewWorker(natsClient, ruleEngine, router, sink, viper.GetString("group"))

replayWorker := worker.NewReplayWorker(w, startTime, endTime)

if err := replayWorker.Start(cmd.Context()); err != nil {
log.Fatal().Err(err).Msg("Failed to start replay worker")
}
}
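One gap worth noting in runReplay: both timestamps are parsed but their ordering is never checked. A small guard after parsing (a sketch, not part of this diff) would reject inverted windows:

```go
// Sketch only: validate the replay window right after parsing both flags.
if !startTime.Before(endTime) {
	log.Fatal().
		Time("start_time", startTime).
		Time("end_time", endTime).
		Msg("start-time must be before end-time")
}
```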
52 changes: 45 additions & 7 deletions pkg/pgflonats/pgflonats.go
@@ -14,14 +14,35 @@ import (
const (
defaultNATSURL = "nats://localhost:4222"
envNATSURL = "PG_FLO_NATS_URL"

defaultMaxAge = 24 * time.Hour
defaultReplicas = 1
defaultReplays = false
)

// StreamConfig represents the configuration for a NATS stream
type StreamConfig struct {
MaxAge time.Duration
Replays bool
Replicas int
}

// DefaultStreamConfig returns a StreamConfig with default values
func DefaultStreamConfig() StreamConfig {
return StreamConfig{
MaxAge: defaultMaxAge,
Replays: defaultReplays,
Replicas: defaultReplicas,
}
}

// NATSClient represents a client for interacting with NATS
type NATSClient struct {
conn *nats.Conn
js nats.JetStreamContext
stream string
stateBucket string
config StreamConfig
}

// State represents the current state of the replication process
@@ -31,7 +52,15 @@ type State struct {
}

// NewNATSClient creates a new NATS client with the specified configuration, setting up the connection, main stream, and state bucket.
func NewNATSClient(url, stream, group string) (*NATSClient, error) {
func NewNATSClient(url, stream, group string, config StreamConfig) (*NATSClient, error) {
if config.MaxAge == 0 {
config.MaxAge = defaultMaxAge
}

if config.Replicas == 0 {
config.Replicas = defaultReplicas
}

if url == "" {
url = os.Getenv(envNATSURL)
if url == "" {
@@ -66,13 +95,21 @@ func NewNATSClient(url, stream, group string) (*NATSClient, error) {
return nil, fmt.Errorf("failed to create JetStream context: %w", err)
}

// Create the main stream
// Create the main stream with configurable retention
streamConfig := &nats.StreamConfig{
Name: stream,
Subjects: []string{fmt.Sprintf("pgflo.%s", group)},
Storage: nats.FileStorage,
Retention: nats.LimitsPolicy,
MaxAge: 24 * time.Hour,
Name: stream,
Storage: nats.FileStorage,
Retention: nats.LimitsPolicy,
MaxAge: config.MaxAge,
Replicas: config.Replicas,
Discard: nats.DiscardOld,
Subjects: []string{fmt.Sprintf("pgflo.%s", group)},
Description: fmt.Sprintf("pg_flo stream for group %s", group),
}
if config.Replays {
// Keep the LimitsPolicy set above: replay needs consumed messages to stay
// in the stream until MaxAge expires, whereas WorkQueuePolicy deletes them
// on ack (and permits only one consumer per subject). Lift the count cap instead.
streamConfig.MaxMsgs = -1
}
_, err = js.AddStream(streamConfig)
if err != nil && !errors.Is(err, nats.ErrStreamNameAlreadyInUse) {
@@ -100,6 +137,7 @@ func NewNATSClient(url, stream, group string) (*NATSClient, error) {
js: js,
stream: stream,
stateBucket: stateBucket,
config: config,
}, nil
}

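For callers, NewNATSClient now takes the stream settings explicitly as a fourth argument. A minimal usage sketch (the import path is assumed from the repository layout, and the values are examples rather than defaults from this PR):

```go
package main

import (
	"log"
	"time"

	"github.com/shayonj/pg_flo/pkg/pgflonats" // assumed module path
)

func main() {
	cfg := pgflonats.StreamConfig{
		MaxAge:   72 * time.Hour, // keep three days of changes around for replays
		Replays:  true,           // retain consumed messages until MaxAge expires
		Replicas: 3,              // assumes a three-node JetStream cluster
	}

	client, err := pgflonats.NewNATSClient(
		"nats://localhost:4222",
		"pgflo_mygroup_stream",
		"mygroup",
		cfg,
	)
	if err != nil {
		log.Fatalf("failed to create NATS client: %v", err)
	}
	_ = client // hand off to a replicator or worker from here
}
```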
143 changes: 143 additions & 0 deletions pkg/worker/replay_worker.go
@@ -0,0 +1,143 @@
package worker

import (
"context"
"fmt"
"time"

"github.com/nats-io/nats.go"
)

type ReplayWorker struct {
*Worker
startTime time.Time
endTime time.Time
}

func NewReplayWorker(w *Worker, startTime, endTime time.Time) *ReplayWorker {
return &ReplayWorker{
Worker: w,
startTime: startTime,
endTime: endTime,
}
}

func (w *ReplayWorker) Start(ctx context.Context) error {
stream := fmt.Sprintf("pgflo_%s_stream", w.group)
subject := fmt.Sprintf("pgflo.%s", w.group)

w.logger.Info().
Str("stream", stream).
Str("subject", subject).
Str("group", w.group).
Time("start_time", w.startTime).
Time("end_time", w.endTime).
Msg("Starting replay worker")

js := w.natsClient.JetStream()

// Create a consumer name unique to this replay run (the UnixNano suffix keeps concurrent replays distinct)
consumerName := fmt.Sprintf("pgflo_%s_replay_%d", w.group, time.Now().UnixNano())

consumerConfig := &nats.ConsumerConfig{
Durable: consumerName,
FilterSubject: subject,
AckPolicy: nats.AckExplicitPolicy,
DeliverPolicy: nats.DeliverByStartTimePolicy,
OptStartTime: &w.startTime,
ReplayPolicy: nats.ReplayOriginalPolicy,
MaxDeliver: 1, // Only deliver once
AckWait: 30 * time.Second,
Review comment from @servusdei2018 (Nov 21, 2024): This should probably be a configurable option (AckWait)

MaxAckPending: w.batchSize,
}

_, err := js.AddConsumer(stream, consumerConfig)
if err != nil {
w.logger.Error().Err(err).Msg("Failed to add replay consumer")
return fmt.Errorf("failed to add replay consumer: %w", err)
}

// Cleanup consumer when done
defer func() {
if err := js.DeleteConsumer(stream, consumerName); err != nil {
w.logger.Error().Err(err).Msg("Failed to delete replay consumer")
}
}()

sub, err := js.PullSubscribe(subject, consumerName)
if err != nil {
w.logger.Error().Err(err).Msg("Failed to subscribe to subject")
return fmt.Errorf("failed to subscribe to subject: %w", err)
}

w.wg.Add(1)
go func() {
defer w.wg.Done()
if err := w.processReplayMessages(ctx, sub); err != nil && err != context.Canceled {
w.logger.Error().Err(err).Msg("Error processing replay messages")
}
}()

<-ctx.Done()
w.logger.Info().Msg("Received shutdown signal. Initiating graceful shutdown...")

w.wg.Wait()
w.logger.Debug().Msg("All goroutines finished")

return w.flushBuffer()
}

func (w *ReplayWorker) processReplayMessages(ctx context.Context, sub *nats.Subscription) error {
flushTicker := time.NewTicker(w.flushInterval)
defer flushTicker.Stop()

for {
select {
case <-ctx.Done():
w.logger.Info().Msg("Flushing remaining messages")
return w.flushBuffer()
case <-flushTicker.C:
if err := w.flushBuffer(); err != nil {
w.logger.Error().Err(err).Msg("Failed to flush buffer on interval")
}
default:
msgs, err := sub.Fetch(w.batchSize, nats.MaxWait(500*time.Millisecond))
if err != nil {
if errors.Is(err, nats.ErrTimeout) {
continue
}
w.logger.Error().Err(err).Msg("Error fetching messages")
continue
}

for _, msg := range msgs {
metadata, err := msg.Metadata()
if err != nil {
w.logger.Error().Err(err).Msg("Failed to get message metadata")
continue
}

// Messages before startTime are already excluded by DeliverByStartTimePolicy;
// stop once a message's stream timestamp passes endTime
if metadata.Timestamp.After(w.endTime) {
w.logger.Info().Msg("Reached end time, finishing replay")
return w.flushBuffer()
}

if err := w.processMessage(msg); err != nil {
w.logger.Error().Err(err).Msg("Failed to process message")
}
if err := msg.Ack(); err != nil {
w.logger.Error().Err(err).Msg("Failed to acknowledge message")
}
}

if len(w.buffer) >= w.batchSize {
if err := w.flushBuffer(); err != nil {
w.logger.Error().Err(err).Msg("Failed to flush buffer")
}
}
}
}
}
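Following up on the AckWait review note above: one way to make it configurable without widening NewReplayWorker's signature is a functional option. A sketch under the assumption that ReplayWorker grows an ackWait field defaulting to the current 30s (WithAckWait, NewReplayWorkerWithOptions, and ackWait are hypothetical, not in this PR):

```go
// ReplayOption customizes a ReplayWorker at construction time.
type ReplayOption func(*ReplayWorker)

// WithAckWait overrides how long JetStream waits for an ack before redelivery.
func WithAckWait(d time.Duration) ReplayOption {
	return func(rw *ReplayWorker) { rw.ackWait = d }
}

// NewReplayWorkerWithOptions mirrors NewReplayWorker but applies options.
func NewReplayWorkerWithOptions(w *Worker, start, end time.Time, opts ...ReplayOption) *ReplayWorker {
	rw := &ReplayWorker{Worker: w, startTime: start, endTime: end, ackWait: 30 * time.Second}
	for _, opt := range opts {
		opt(rw)
	}
	return rw
}
```

Start would then set `AckWait: w.ackWait` in the consumer config instead of the hard-coded constant.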