diff --git a/cmd/root.go b/cmd/root.go
index affafe2..d8f546f 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -1,9 +1,12 @@
 package cmd
 
 import (
+    "context"
     "fmt"
     "os"
+    "os/signal"
     "strings"
+    "syscall"
 
     "github.com/rs/zerolog/log"
     "github.com/shayonj/pg_flo/pkg/pgflonats"
@@ -106,6 +109,7 @@ func init() {
     postgresWorkerCmd.Flags().String("postgres-user", "", "Target PostgreSQL user (env: PG_FLO_POSTGRES_USER)")
     postgresWorkerCmd.Flags().String("postgres-password", "", "Target PostgreSQL password (env: PG_FLO_POSTGRES_PASSWORD)")
     postgresWorkerCmd.Flags().Bool("postgres-sync-schema", false, "Sync schema from source to target (env: PG_FLO_POSTGRES_SYNC_SCHEMA)")
+    postgresWorkerCmd.Flags().Bool("postgres-disable-foreign-keys", false, "Disable foreign key checks during write (env: PG_FLO_POSTGRES_DISABLE_FOREIGN_KEYS)")
 
     markFlagRequired(postgresWorkerCmd, "postgres-host", "postgres-dbname", "postgres-user", "postgres-password")
 
@@ -228,9 +232,30 @@ func runWorker(cmd *cobra.Command, _ []string) {
     }
 
     w := worker.NewWorker(natsClient, ruleEngine, sink, group)
-    if err := w.Start(cmd.Context()); err != nil {
-        log.Fatal().Err(err).Msg("Worker failed")
+
+    ctx, cancel := context.WithCancel(cmd.Context())
+    defer cancel()
+
+    // Set up signal handling
+    sigCh := make(chan os.Signal, 1)
+    signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
+
+    go func() {
+        sig := <-sigCh
+        log.Info().Msgf("Received shutdown signal: %v. Canceling context...", sig)
+        cancel()
+    }()
+
+    log.Info().Msg("Starting worker...")
+    if err := w.Start(ctx); err != nil {
+        if err == context.Canceled {
+            log.Info().Msg("Worker shut down gracefully")
+        } else {
+            log.Error().Err(err).Msg("Worker encountered an error during shutdown")
+        }
     }
+
+    log.Info().Msg("Worker process exiting")
 }
 
 func loadRulesConfig(filePath string) (rules.Config, error) {
@@ -267,6 +292,7 @@ func createSink(sinkType string) (sinks.Sink, error) {
             viper.GetString("dbname"),
             viper.GetString("user"),
             viper.GetString("password"),
+            viper.GetBool("postgres-disable-foreign-keys"),
         )
     case "webhook":
         return sinks.NewWebhookSink(
diff --git a/internal/e2e_test_local.sh b/internal/e2e_test_local.sh
index faa2589..48ec1b0 100755
--- a/internal/e2e_test_local.sh
+++ b/internal/e2e_test_local.sh
@@ -30,7 +30,7 @@ trap cleanup EXIT
 
 make build
 
-# setup_docker
+setup_docker
 
 log "Running e2e copy & stream tests..."
 if CI=false ./internal/e2e_copy_and_stream.sh; then
diff --git a/internal/e2e_test_stream.sh b/internal/e2e_test_stream.sh
index fccefe9..7ab5890 100755
--- a/internal/e2e_test_stream.sh
+++ b/internal/e2e_test_stream.sh
@@ -46,14 +46,17 @@ simulate_changes() {
   local update_count=500
   local delete_count=250
 
+  log "Simulating inserts..."
   for i in $(seq 1 $insert_count); do
     run_sql "INSERT INTO public.users (data) VALUES ('Data $i');"
   done
 
+  log "Simulating updates..."
   for i in $(seq 1 $update_count); do
     run_sql "UPDATE public.users SET data = 'Updated data $i' WHERE id = $i;"
   done
 
+  log "Simulating deletes..."
   for i in $(seq 1 $delete_count); do
     run_sql "DELETE FROM public.users WHERE id = $i;"
   done
diff --git a/pkg/replicator/copy_and_stream_replicator.go b/pkg/replicator/copy_and_stream_replicator.go
index b2075df..85e3aa0 100644
--- a/pkg/replicator/copy_and_stream_replicator.go
+++ b/pkg/replicator/copy_and_stream_replicator.go
@@ -190,7 +190,7 @@ func (r *CopyAndStreamReplicator) getRelPages(ctx context.Context, tableName str
 // generateRanges creates a set of page ranges for copying.
 func (r *CopyAndStreamReplicator) generateRanges(relPages uint32) [][2]uint32 {
     var ranges [][2]uint32
-    batchSize := uint32(100)
+    batchSize := uint32(1000)
     for start := uint32(0); start < relPages; start += batchSize {
         end := start + batchSize
         if end >= relPages {
diff --git a/pkg/sinks/postgres.go b/pkg/sinks/postgres.go
index 4e5a0d8..1d2240d 100644
--- a/pkg/sinks/postgres.go
+++ b/pkg/sinks/postgres.go
@@ -14,11 +14,12 @@ import (
 
 // PostgresSink represents a sink for PostgreSQL database
 type PostgresSink struct {
-    conn *pgx.Conn
+    conn                    *pgx.Conn
+    disableForeignKeyChecks bool
 }
 
 // NewPostgresSink creates a new PostgresSink instance
-func NewPostgresSink(targetHost string, targetPort int, targetDBName, targetUser, targetPassword string, syncSchema bool, sourceHost string, sourcePort int, sourceDBName, sourceUser, sourcePassword string) (*PostgresSink, error) {
+func NewPostgresSink(targetHost string, targetPort int, targetDBName, targetUser, targetPassword string, syncSchema bool, sourceHost string, sourcePort int, sourceDBName, sourceUser, sourcePassword string, disableForeignKeyChecks bool) (*PostgresSink, error) {
     connConfig, err := pgx.ParseConfig(fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s", targetHost, targetPort, targetDBName, targetUser, targetPassword))
     if err != nil {
         return nil, fmt.Errorf("failed to parse connection config: %v", err)
@@ -30,7 +31,8 @@ func NewPostgresSink(targetHost string, targetPort int, targetDBName, targetUser
     }
 
     sink := &PostgresSink{
-        conn: conn,
+        conn:                    conn,
+        disableForeignKeyChecks: disableForeignKeyChecks,
     }
 
     if syncSchema {
@@ -162,18 +164,42 @@ func (s *PostgresSink) handleDDL(tx pgx.Tx, message *utils.CDCMessage) error {
     return err
 }
 
+// disableForeignKeys disables foreign key checks
+func (s *PostgresSink) disableForeignKeys(ctx context.Context) error {
+    _, err := s.conn.Exec(ctx, "SET session_replication_role = 'replica';")
+    return err
+}
+
+// enableForeignKeys enables foreign key checks
+func (s *PostgresSink) enableForeignKeys(ctx context.Context) error {
+    _, err := s.conn.Exec(ctx, "SET session_replication_role = 'origin';")
+    return err
+}
+
 // WriteBatch writes a batch of CDC messages to the target database
 func (s *PostgresSink) WriteBatch(messages []*utils.CDCMessage) error {
-    tx, err := s.conn.Begin(context.Background())
+    ctx := context.Background()
+    tx, err := s.conn.Begin(ctx)
     if err != nil {
         return fmt.Errorf("failed to begin transaction: %v", err)
     }
     defer func() {
-        if err := tx.Rollback(context.Background()); err != nil && err != pgx.ErrTxClosed {
+        if err := tx.Rollback(ctx); err != nil && err != pgx.ErrTxClosed {
             log.Error().Err(err).Msg("failed to rollback transaction")
         }
     }()
 
+    if s.disableForeignKeyChecks {
+        if err := s.disableForeignKeys(ctx); err != nil {
+            return fmt.Errorf("failed to disable foreign key checks: %v", err)
+        }
+        defer func() {
+            if err := s.enableForeignKeys(ctx); err != nil {
+                log.Error().Err(err).Msg("failed to re-enable foreign key checks")
+            }
+        }()
+    }
+
     for _, message := range messages {
         var err error
         switch message.Type {
@@ -194,7 +220,7 @@ func (s *PostgresSink) WriteBatch(messages []*utils.CDCMessage) error {
         }
     }
 
-    if err := tx.Commit(context.Background()); err != nil {
+    if err := tx.Commit(ctx); err != nil {
         return fmt.Errorf("failed to commit transaction: %v", err)
     }
 
diff --git a/pkg/worker/worker.go b/pkg/worker/worker.go
index 0dcc290..8633ba6 100644
--- a/pkg/worker/worker.go
+++ b/pkg/worker/worker.go
@@ -3,6 +3,9 @@ package worker
 import (
     "context"
     "fmt"
+    "sync"
+    "sync/atomic"
+    "time"
 
     "github.com/nats-io/nats.go/jetstream"
     "github.com/rs/zerolog"
@@ -14,13 +17,20 @@ import (
 
 // Worker represents a worker that processes messages from NATS.
 type Worker struct {
-    natsClient *pgflonats.NATSClient
-    ruleEngine *rules.RuleEngine
-    sink       sinks.Sink
-    group      string
-    logger     zerolog.Logger
-    batchSize  int
-    maxRetries int
+    natsClient     *pgflonats.NATSClient
+    ruleEngine     *rules.RuleEngine
+    sink           sinks.Sink
+    group          string
+    logger         zerolog.Logger
+    batchSize      int
+    maxRetries     int
+    buffer         []*utils.CDCMessage
+    lastSavedState uint64
+    flushInterval  time.Duration
+    done           chan struct{}
+    shutdownCh     chan struct{}
+    wg             sync.WaitGroup
+    isShuttingDown atomic.Bool
 }
 
 // NewWorker creates and returns a new Worker instance with the provided NATS client, rule engine, sink, and group.
@@ -28,13 +38,18 @@ func NewWorker(natsClient *pgflonats.NATSClient, ruleEngine *rules.RuleEngine, s
     logger := zerolog.New(zerolog.NewConsoleWriter()).With().Timestamp().Str("component", "worker").Logger()
 
     return &Worker{
-        natsClient: natsClient,
-        ruleEngine: ruleEngine,
-        sink:       sink,
-        group:      group,
-        logger:     logger,
-        batchSize:  100,
-        maxRetries: 3,
+        natsClient:     natsClient,
+        ruleEngine:     ruleEngine,
+        sink:           sink,
+        group:          group,
+        logger:         logger,
+        batchSize:      1000,
+        maxRetries:     3,
+        buffer:         make([]*utils.CDCMessage, 0, 1000),
+        lastSavedState: 0,
+        flushInterval:  5 * time.Second,
+        done:           make(chan struct{}),
+        shutdownCh:     make(chan struct{}),
     }
 }
 
@@ -95,7 +110,34 @@ func (w *Worker) Start(ctx context.Context) error {
         return fmt.Errorf("failed to create ordered consumer: %w", err)
     }
 
-    return w.processMessages(ctx, cons)
+    w.wg.Add(1)
+    go func() {
+        defer w.wg.Done()
+        if err := w.processMessages(ctx, cons); err != nil && err != context.Canceled {
+            w.logger.Error().Err(err).Msg("Error processing messages")
+        }
+    }()
+
+    <-ctx.Done()
+    w.logger.Info().Msg("Received shutdown signal. Initiating graceful shutdown...")
+    w.isShuttingDown.Store(true)
+    close(w.shutdownCh)
+
+    // Wait for processMessages to finish with a short timeout
+    done := make(chan struct{})
+    go func() {
+        w.wg.Wait()
+        close(done)
+    }()
+
+    select {
+    case <-done:
+        w.logger.Info().Msg("All goroutines finished")
+    case <-time.After(5 * time.Second):
+        w.logger.Warn().Msg("Shutdown timed out, forcing exit")
+    }
+
+    return w.flushBuffer()
 }
 
 // processMessages continuously processes messages from the NATS consumer.
@@ -106,22 +148,61 @@ func (w *Worker) processMessages(ctx context.Context, cons jetstream.Consumer) e
     }
     defer iter.Stop()
 
+    flushTicker := time.NewTicker(w.flushInterval)
+    defer flushTicker.Stop()
+
     for {
         select {
         case <-ctx.Done():
-            return ctx.Err()
+            w.logger.Info().Msg("Context canceled, flushing remaining messages")
+            return w.flushBuffer()
+        case <-w.shutdownCh:
+            w.logger.Info().Msg("Shutdown signal received, flushing remaining messages")
+            return w.flushBuffer()
+        case <-flushTicker.C:
+            if err := w.flushBuffer(); err != nil {
+                w.logger.Error().Err(err).Msg("Failed to flush buffer on interval")
+            }
        default:
-            msg, err := iter.Next()
-            if err != nil {
+            if w.isShuttingDown.Load() {
+                w.logger.Info().Msg("Shutdown in progress, stopping message processing")
+                return w.flushBuffer()
+            }
+
+            // Use a timeout for fetching the next message
+            msgCh := make(chan jetstream.Msg, 1)
+            errCh := make(chan error, 1)
+            go func() {
+                msg, err := iter.Next()
+                if err != nil {
+                    errCh <- err
+                } else {
+                    msgCh <- msg
+                }
+            }()
+
+            select {
+            case msg := <-msgCh:
+                if err := w.processMessage(msg); err != nil {
+                    w.logger.Error().Err(err).Msg("Failed to process message")
+                }
+            case err := <-errCh:
+                if err == context.Canceled {
+                    return w.flushBuffer()
+                }
                 w.logger.Error().Err(err).Msg("Failed to get next message")
-                continue
             }
-            if err := w.processMessage(msg); err != nil {
-                w.logger.Error().Err(err).Msg("Failed to process message")
-                // Note: OrderedConsumer doesn't support Nak()
+
+            if w.isShuttingDown.Load() {
+                w.logger.Info().Msg("Shutdown detected, stopping message processing")
+                return w.flushBuffer()
+            }
+
+            if len(w.buffer) >= w.batchSize {
+                if err := w.flushBuffer(); err != nil {
+                    w.logger.Error().Err(err).Msg("Failed to flush buffer")
+                }
             }
-            // Note: OrderedConsumer doesn't require explicit Ack()
         }
     }
 }
 
@@ -134,11 +215,6 @@ func (w *Worker) processMessage(msg jetstream.Msg) error {
         return err
     }
 
-    w.logger.Debug().
-        Uint64("stream_seq", metadata.Sequence.Stream).
-        Uint64("consumer_seq", metadata.Sequence.Consumer).
-        Msg("Processing message")
-
     var cdcMessage utils.CDCMessage
     err = cdcMessage.UnmarshalBinary(msg.Data())
     if err != nil {
@@ -159,26 +235,41 @@ func (w *Worker) processMessage(msg jetstream.Msg) error {
         cdcMessage = *processedMessage
     }
 
-    err = w.sink.WriteBatch([]*utils.CDCMessage{&cdcMessage})
+    w.buffer = append(w.buffer, &cdcMessage)
+    w.lastSavedState = metadata.Sequence.Stream
+
+    return nil
+}
+
+func (w *Worker) flushBuffer() error {
+    if len(w.buffer) == 0 {
+        return nil
+    }
+
+    w.logger.Debug().Int("messages", len(w.buffer)).Msg("Flushing buffer")
+
+    // Use a context with timeout for the flush operation
+    ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+    defer cancel()
+
+    err := w.sink.WriteBatch(w.buffer)
     if err != nil {
-        w.logger.Error().Err(err).Msg("Failed to write to sink")
+        w.logger.Error().Err(err).Msg("Failed to write batch to sink")
         return err
     }
 
-    state, err := w.natsClient.GetState(context.Background())
+    state, err := w.natsClient.GetState(ctx)
     if err != nil {
         w.logger.Error().Err(err).Msg("Failed to get current state")
         return err
     }
 
-    if metadata.Sequence.Stream > state.LastProcessedSeq {
-        state.LastProcessedSeq = metadata.Sequence.Stream
-        if err := w.natsClient.SaveState(context.Background(), state); err != nil {
-            w.logger.Error().Err(err).Msg("Failed to save state")
-        } else {
-            w.logger.Debug().Uint64("last_processed_seq", state.LastProcessedSeq).Msg("Updated last processed sequence")
-        }
+    state.LastProcessedSeq = w.lastSavedState
+    if err := w.natsClient.SaveState(ctx, state); err != nil {
+        w.logger.Error().Err(err).Msg("Failed to save state")
+        return err
     }
 
+    w.buffer = w.buffer[:0]
     return nil
 }