Ability to replay messages within a time frame and opt-in support for extended storage in NATS
shayonj committed Oct 28, 2024
1 parent 562f6f4 commit 14b446e
Showing 3 changed files with 285 additions and 10 deletions.
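A representative invocation of the new command (the binary name pg_flo is assumed from the project name and is not shown in this diff; flag values are illustrative):

pg_flo replay-worker --start-time 2024-10-27T00:00:00Z --end-time 2024-10-28T00:00:00Z --group mygroup --nats-url nats://localhost:4222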
100 changes: 97 additions & 3 deletions cmd/root.go
@@ -7,6 +7,7 @@ import (
"os/signal"
"strings"
"syscall"
"time"

"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@@ -65,6 +66,12 @@ var (
Short: "Start the worker with webhook sink",
Run: runWorker,
}

replayCmd = &cobra.Command{
Use: "replay-worker",
Short: "Replay historical changes",
Run: runReplay,
}
)

func Execute() error {
@@ -94,6 +101,9 @@ func init() {
replicatorCmd.Flags().Bool("copy-and-stream", false, "Enable copy and stream mode (env: PG_FLO_COPY_AND_STREAM)")
replicatorCmd.Flags().Int("max-copy-workers-per-table", 4, "Maximum number of copy workers per table (env: PG_FLO_MAX_COPY_WORKERS_PER_TABLE)")
replicatorCmd.Flags().Bool("track-ddl", false, "Enable tracking of DDL changes (env: PG_FLO_TRACK_DDL)")
replicatorCmd.Flags().Duration("max-age", 24*time.Hour, "Maximum age of messages to retain (env: PG_FLO_MAX_AGE)")
replicatorCmd.Flags().Bool("enable-replays", false, "Enable replay capability - requires more storage (env: PG_FLO_ENABLE_REPLAYS)")
replicatorCmd.Flags().Int("nats-replicas", 1, "Number of stream replicas (env: PG_FLO_NATS_REPLICAS)")

markFlagRequired(replicatorCmd, "host", "port", "dbname", "user", "password", "group", "nats-url")

@@ -133,7 +143,17 @@ func init() {
// Add subcommands to worker command
workerCmd.AddCommand(stdoutWorkerCmd, fileWorkerCmd, postgresWorkerCmd, webhookWorkerCmd)

rootCmd.AddCommand(replicatorCmd, workerCmd)
// Configure and register the replay command
replayCmd.Flags().String("start-time", "", "Start time for replay (RFC3339 format)")
replayCmd.Flags().String("end-time", "", "End time for replay (RFC3339 format)")
replayCmd.Flags().String("group", "", "Group name for worker (env: PG_FLO_GROUP)")
replayCmd.Flags().String("nats-url", "", "NATS server URL (env: PG_FLO_NATS_URL)")
replayCmd.Flags().String("rules-config", "", "Path to rules configuration file (env: PG_FLO_RULES_CONFIG)")
replayCmd.Flags().String("routing-config", "", "Path to routing configuration file (env: PG_FLO_ROUTING_CONFIG)")

markFlagRequired(replayCmd, "start-time", "end-time", "group", "nats-url")

rootCmd.AddCommand(replicatorCmd, workerCmd, replayCmd)
}

func initConfig() {
@@ -157,6 +177,7 @@
bindFlags(fileWorkerCmd)
bindFlags(postgresWorkerCmd)
bindFlags(webhookWorkerCmd)
bindFlags(replayCmd)

if err := viper.ReadInConfig(); err == nil {
fmt.Println("Using config file:", viper.ConfigFileUsed())
@@ -192,7 +213,13 @@ func runReplicator(_ *cobra.Command, _ []string) {
log.Fatal().Msg("NATS URL is required")
}

natsClient, err := pgflonats.NewNATSClient(natsURL, fmt.Sprintf("pgflo_%s_stream", config.Group), config.Group)
natsConfig := pgflonats.StreamConfig{
MaxAge: viper.GetDuration("max-age"),
Replays: viper.GetBool("enable-replays"),
Replicas: viper.GetInt("nats-replicas"),
}

natsClient, err := pgflonats.NewNATSClient(natsURL, fmt.Sprintf("pgflo_%s_stream", config.Group), config.Group, natsConfig)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create NATS client")
}
@@ -222,7 +249,8 @@ func runWorker(cmd *cobra.Command, _ []string) {
sinkType := cmd.Use

// Create NATS client
natsClient, err := pgflonats.NewNATSClient(natsURL, fmt.Sprintf("pgflo_%s_stream", group), group)
natsConfig := pgflonats.DefaultStreamConfig()
natsClient, err := pgflonats.NewNATSClient(natsURL, fmt.Sprintf("pgflo_%s_stream", group), group, natsConfig)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create NATS client")
}
@@ -352,3 +380,69 @@ func markPersistentFlagRequired(cmd *cobra.Command, flags ...string) {
}
}
}

func runReplay(cmd *cobra.Command, _ []string) {
startTimeStr := viper.GetString("start-time")
endTimeStr := viper.GetString("end-time")

startTime, err := time.Parse(time.RFC3339, startTimeStr)
if err != nil {
log.Fatal().Err(err).Msg("Invalid start time format")
}

endTime, err := time.Parse(time.RFC3339, endTimeStr)
if err != nil {
log.Fatal().Err(err).Msg("Invalid end time format")
}

// Create NATS client with replay config
natsConfig := pgflonats.StreamConfig{
MaxAge: viper.GetDuration("max-age"),
Replays: viper.GetBool("enable-replays"),
}

natsClient, err := pgflonats.NewNATSClient(
viper.GetString("nats-url"),
fmt.Sprintf("pgflo_%s_stream", viper.GetString("group")),
viper.GetString("group"),
natsConfig,
)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create NATS client")
}

ruleEngine := rules.NewRuleEngine()
if viper.GetString("rules-config") != "" {
rulesConfig, err := loadRulesConfig(viper.GetString("rules-config"))
if err != nil {
log.Fatal().Err(err).Msg("Failed to load rules configuration")
}
if err := ruleEngine.LoadRules(rulesConfig); err != nil {
log.Fatal().Err(err).Msg("Failed to load rules")
}
}

router := routing.NewRouter()
if viper.GetString("routing-config") != "" {
routingConfig, err := loadRoutingConfig(viper.GetString("routing-config"))
if err != nil {
log.Fatal().Err(err).Msg("Failed to load routing configuration")
}
if err := router.LoadRoutes(routingConfig); err != nil {
log.Fatal().Err(err).Msg("Failed to load routes")
}
}

sink, err := createSink(cmd.Use)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create sink")
}

w := worker.NewWorker(natsClient, ruleEngine, router, sink, viper.GetString("group"))

replayWorker := worker.NewReplayWorker(w, startTime, endTime)

if err := replayWorker.Start(cmd.Context()); err != nil {
log.Fatal().Err(err).Msg("Failed to start replay worker")
}
}
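The replay window is supplied as two RFC3339 timestamps. A minimal, self-contained sketch of the parsing runReplay performs, adding a start-before-end check that the command itself does not make (values are illustrative):

package main

import (
	"fmt"
	"time"
)

func main() {
	start, err := time.Parse(time.RFC3339, "2024-10-27T00:00:00Z")
	if err != nil {
		panic(err)
	}
	end, err := time.Parse(time.RFC3339, "2024-10-28T00:00:00Z")
	if err != nil {
		panic(err)
	}
	// Reject inverted windows before any consumer is created
	if !end.After(start) {
		fmt.Println("end-time must be after start-time")
		return
	}
	fmt.Printf("replaying a %s window\n", end.Sub(start)) // prints 24h0m0s
}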
52 changes: 45 additions & 7 deletions pkg/pgflonats/pgflonats.go
@@ -14,14 +14,35 @@ import (
const (
defaultNATSURL = "nats://localhost:4222"
envNATSURL = "PG_FLO_NATS_URL"

defaultMaxAge = 24 * time.Hour
defaultReplicas = 1
defaultReplays = false
)

// StreamConfig represents the configuration for a NATS stream
type StreamConfig struct {
MaxAge time.Duration
Replays bool
Replicas int
}

// DefaultStreamConfig returns a StreamConfig with default values
func DefaultStreamConfig() StreamConfig {
return StreamConfig{
MaxAge: defaultMaxAge,
Replays: defaultReplays,
Replicas: defaultReplicas,
}
}

// NATSClient represents a client for interacting with NATS
type NATSClient struct {
conn *nats.Conn
js nats.JetStreamContext
stream string
stateBucket string
config StreamConfig
}

// State represents the current state of the replication process
@@ -31,7 +52,15 @@
}

// NewNATSClient creates a new NATS client with the specified configuration, setting up the connection, main stream, and state bucket.
func NewNATSClient(url, stream, group string) (*NATSClient, error) {
func NewNATSClient(url, stream, group string, config StreamConfig) (*NATSClient, error) {
if config.MaxAge == 0 {
config.MaxAge = defaultMaxAge
}

if config.Replicas == 0 {
config.Replicas = defaultReplicas
}

if url == "" {
url = os.Getenv(envNATSURL)
if url == "" {
@@ -66,13 +95,21 @@ func NewNATSClient(url, stream, group string) (*NATSClient, error) {
return nil, fmt.Errorf("failed to create JetStream context: %w", err)
}

// Create the main stream
// Create the main stream with configurable retention
streamConfig := &nats.StreamConfig{
Name: stream,
Subjects: []string{fmt.Sprintf("pgflo.%s", group)},
Storage: nats.FileStorage,
Retention: nats.LimitsPolicy,
MaxAge: 24 * time.Hour,
Name: stream,
Storage: nats.FileStorage,
Retention: nats.LimitsPolicy,
MaxAge: config.MaxAge,
Replicas: config.Replicas,
Discard: nats.DiscardOld,
Subjects: []string{fmt.Sprintf("pgflo.%s", group)},
Description: fmt.Sprintf("pg_flo stream for group %s", group),
}
if config.Replays {
// Replays need messages to outlive acknowledgment, so keep the
// limits-based retention and lift the message-count cap; a work-queue
// policy would delete messages as they are consumed and permit only
// one consumer per subject, defeating replay.
streamConfig.MaxMsgs = -1
}
_, err = js.AddStream(streamConfig)
if err != nil && !errors.Is(err, nats.ErrStreamNameAlreadyInUse) {
@@ -100,6 +137,7 @@ func NewNATSClient(url, stream, group string) (*NATSClient, error) {
js: js,
stream: stream,
stateBucket: stateBucket,
config: config,
}, nil
}

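A caller-side sketch of the new configuration surface, opting into replays (the import path is assumed from the repository layout; values are illustrative):

package main

import (
	"log"
	"time"

	"github.com/shayonj/pg_flo/pkg/pgflonats" // assumed module path
)

func main() {
	cfg := pgflonats.DefaultStreamConfig() // MaxAge: 24h, Replicas: 1, Replays: false
	cfg.MaxAge = 72 * time.Hour            // retain three days of messages
	cfg.Replays = true                     // opt in to replay storage

	client, err := pgflonats.NewNATSClient(
		"nats://localhost:4222",
		"pgflo_mygroup_stream",
		"mygroup",
		cfg,
	)
	if err != nil {
		log.Fatal(err)
	}
	_ = client // use the client to build replicators/workers as in cmd/root.go
}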
143 changes: 143 additions & 0 deletions pkg/worker/replay_worker.go
@@ -0,0 +1,143 @@
package worker

import (
"context"
"fmt"
"time"

"github.com/nats-io/nats.go"
)

type ReplayWorker struct {
*Worker
startTime time.Time
endTime time.Time
}

func NewReplayWorker(w *Worker, startTime, endTime time.Time) *ReplayWorker {
return &ReplayWorker{
Worker: w,
startTime: startTime,
endTime: endTime,
}
}

func (w *ReplayWorker) Start(ctx context.Context) error {
stream := fmt.Sprintf("pgflo_%s_stream", w.group)
subject := fmt.Sprintf("pgflo.%s", w.group)

w.logger.Info().
Str("stream", stream).
Str("subject", subject).
Str("group", w.group).
Time("start_time", w.startTime).
Time("end_time", w.endTime).
Msg("Starting replay worker")

js := w.natsClient.JetStream()

// Create a unique consumer name for this replay run
// TODO - is UnixNano unique enough for concurrent replays?
consumerName := fmt.Sprintf("pgflo_%s_replay_%d", w.group, time.Now().UnixNano())

consumerConfig := &nats.ConsumerConfig{
Durable: consumerName,
FilterSubject: subject,
AckPolicy: nats.AckExplicitPolicy,
DeliverPolicy: nats.DeliverByStartTimePolicy,
OptStartTime: &w.startTime,
ReplayPolicy: nats.ReplayOriginalPolicy,
MaxDeliver: 1, // Only deliver once
AckWait: 30 * time.Second,
MaxAckPending: w.batchSize,
}
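// Note: DeliverByStartTimePolicy positions the consumer at the first
// message stored at or after startTime; JetStream consumers have no
// corresponding end-time option, so endTime is enforced client-side in
// processReplayMessages below.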

_, err := js.AddConsumer(stream, consumerConfig)
if err != nil {
w.logger.Error().Err(err).Msg("Failed to add replay consumer")
return fmt.Errorf("failed to add replay consumer: %w", err)
}

// Cleanup consumer when done
defer func() {
if err := js.DeleteConsumer(stream, consumerName); err != nil {
w.logger.Error().Err(err).Msg("Failed to delete replay consumer")
}
}()

sub, err := js.PullSubscribe(subject, consumerName)
if err != nil {
w.logger.Error().Err(err).Msg("Failed to subscribe to subject")
return fmt.Errorf("failed to subscribe to subject: %w", err)
}

done := make(chan struct{})
w.wg.Add(1)
go func() {
defer w.wg.Done()
defer close(done)
if err := w.processReplayMessages(ctx, sub); err != nil && err != context.Canceled {
w.logger.Error().Err(err).Msg("Error processing replay messages")
}
}()

// Return once the requested window has been replayed, not only on shutdown
select {
case <-ctx.Done():
w.logger.Info().Msg("Received shutdown signal. Initiating graceful shutdown...")
case <-done:
w.logger.Info().Msg("Replay window processed")
}

w.wg.Wait()
w.logger.Debug().Msg("All goroutines finished")

return w.flushBuffer()
}

func (w *ReplayWorker) processReplayMessages(ctx context.Context, sub *nats.Subscription) error {
flushTicker := time.NewTicker(w.flushInterval)
defer flushTicker.Stop()

for {
select {
case <-ctx.Done():
w.logger.Info().Msg("Flushing remaining messages")
return w.flushBuffer()
case <-flushTicker.C:
if err := w.flushBuffer(); err != nil {
w.logger.Error().Err(err).Msg("Failed to flush buffer on interval")
}
default:
msgs, err := sub.Fetch(w.batchSize, nats.MaxWait(500*time.Millisecond))
if err != nil {
if err == nats.ErrTimeout {
continue
}
w.logger.Error().Err(err).Msg("Error fetching messages")
continue
}

for _, msg := range msgs {
metadata, err := msg.Metadata()
if err != nil {
w.logger.Error().Err(err).Msg("Failed to get message metadata")
continue
}

// Stop once a message falls past the end of the replay window
// TODO - is there a server-side alternative to this client-side check?
if metadata.Timestamp.After(w.endTime) {
w.logger.Info().Msg("Reached end time, finishing replay")
return w.flushBuffer()
}

if err := w.processMessage(msg); err != nil {
w.logger.Error().Err(err).Msg("Failed to process message")
}
if err := msg.Ack(); err != nil {
w.logger.Error().Err(err).Msg("Failed to acknowledge message")
}
}

if len(w.buffer) >= w.batchSize {
if err := w.flushBuffer(); err != nil {
w.logger.Error().Err(err).Msg("Failed to flush buffer")
}
}
}
}
}
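A minimal wiring sketch for the replay worker, assuming the import path from this repository and a base worker constructed as in cmd/root.go above:

package replayutil

import (
	"context"
	"time"

	"github.com/shayonj/pg_flo/pkg/worker" // assumed module path
)

// Replay runs a bounded replay over [start, end) against an existing worker.
func Replay(base *worker.Worker, start, end time.Time) error {
	// Bound the context as a safety net; Start returns once the window has
	// been replayed or the context is cancelled.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()

	rw := worker.NewReplayWorker(base, start, end)
	return rw.Start(ctx)
}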
