From 33ee371936eb752096f3494014d2028fb90476dd Mon Sep 17 00:00:00 2001 From: Shayon Mukherjee Date: Sat, 16 Nov 2024 15:44:30 -0500 Subject: [PATCH] Move towards simpler interfaces and good lifecycle management hygiene (#52) * Move towards simpler interfaces and good lifecycle management hygiene * wip --- cmd/root.go | 52 ++++++++- internal/scripts/e2e_resume_test.rb | 10 +- internal/scripts/e2e_test_local.sh | 6 +- pkg/replicator/base_replicator.go | 94 +++++++++++++++- pkg/replicator/copy_and_stream_replicator.go | 103 ++++-------------- pkg/replicator/ddl_replicator.go | 62 ++++++----- pkg/replicator/errors.go | 11 +- pkg/replicator/factory.go | 51 +-------- pkg/replicator/interfaces.go | 3 +- pkg/replicator/stream_replicator.go | 103 +++--------------- .../tests/copy_and_stream_replicator_test.go | 97 +++++++++++++---- pkg/replicator/tests/ddl_replicator_test.go | 18 ++- 12 files changed, 323 insertions(+), 287 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 285a877..cbabab0 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -7,6 +7,7 @@ import ( "os/signal" "strings" "syscall" + "time" "github.com/pgflo/pg_flo/pkg/pgflonats" "github.com/pgflo/pg_flo/pkg/replicator" @@ -295,13 +296,60 @@ func runReplicator(_ *cobra.Command, _ []string) { factory = &replicator.StreamReplicatorFactory{} } + // Create base context for the entire application + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + rep, err := factory.CreateReplicator(config, natsClient) if err != nil { log.Fatal().Err(err).Msg("Failed to create replicator") } - if err := rep.StartReplication(); err != nil { - log.Fatal().Err(err).Msg("Failed to start replication") + // Setup signal handling + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + // Error channel to capture any replication errors + errCh := make(chan error, 1) + + // Start replication in a goroutine + go func() { + if err := rep.Start(ctx); err != nil { + errCh <- err + } + }() + + // Wait for either a signal or an error + select { + case sig := <-sigCh: + log.Info().Str("signal", sig.String()).Msg("Received shutdown signal") + + // Create a new context with timeout for graceful shutdown + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer shutdownCancel() + + // Cancel the main context first + cancel() + + // Then call Stop with the timeout context + if err := rep.Stop(shutdownCtx); err != nil { + log.Error().Err(err).Msg("Error during replication shutdown") + os.Exit(1) + } + log.Info().Msg("Replication stopped successfully") + + case err := <-errCh: + log.Error().Err(err).Msg("Replication error occurred") + + // Create shutdown context for cleanup + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer shutdownCancel() + + // Attempt cleanup even if there was an error + if stopErr := rep.Stop(shutdownCtx); stopErr != nil { + log.Error().Err(stopErr).Msg("Additional error during shutdown") + } + os.Exit(1) } } diff --git a/internal/scripts/e2e_resume_test.rb b/internal/scripts/e2e_resume_test.rb index 0857adb..dafaf7f 100644 --- a/internal/scripts/e2e_resume_test.rb +++ b/internal/scripts/e2e_resume_test.rb @@ -16,10 +16,10 @@ require 'securerandom' class ResumeTest - TOTAL_INSERTS = 3000 - INSERTS_BEFORE_INTERRUPT = 1000 - RESUME_WAIT_TIME = 2 # seconds - NUM_GROUPS = 3 + TOTAL_INSERTS = 5000 + INSERTS_BEFORE_INTERRUPT = 1500 + RESUME_WAIT_TIME = 1 # seconds + NUM_GROUPS = 4 PG_HOST = 'localhost' PG_PORT 
= 5433 @@ -284,7 +284,7 @@ def test_resume @logger.info "Waiting for all inserts to complete..." threads.each(&:join) - sleep 5 + sleep 20 @logger.info "Sending final SIGTERM to cleanup..." @replicator_pids.each do |pid| diff --git a/internal/scripts/e2e_test_local.sh b/internal/scripts/e2e_test_local.sh index 596e858..465a05b 100755 --- a/internal/scripts/e2e_test_local.sh +++ b/internal/scripts/e2e_test_local.sh @@ -33,9 +33,9 @@ make build setup_docker -log "Running e2e postgres tests..." -if CI=false ./internal/scripts/e2e_postgres.sh; then - success "e2e postgres tests completed successfully" +log "Running e2e ddl tests..." +if CI=false ruby ./internal/scripts/e2e_resume_test.rb; then + success "e2e ddl tests completed successfully" else error "Original e2e tests failed" exit 1 diff --git a/pkg/replicator/base_replicator.go b/pkg/replicator/base_replicator.go index 6dade52..dce0664 100644 --- a/pkg/replicator/base_replicator.go +++ b/pkg/replicator/base_replicator.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strings" + "sync" "time" "errors" @@ -26,12 +27,17 @@ type BaseReplicator struct { Config Config ReplicationConn ReplicationConnection StandardConn StandardConnection + DDLReplicator *DDLReplicator Relations map[uint32]*pglogrepl.RelationMessage Logger utils.Logger TableDetails map[string][]string LastLSN pglogrepl.LSN NATSClient NATSClient TableReplicationKeys map[string]utils.ReplicationKey + stopChan chan struct{} + started bool + stopped bool + mu sync.RWMutex } // NewBaseReplicator creates a new BaseReplicator instance @@ -52,6 +58,15 @@ func NewBaseReplicator(config Config, replicationConn ReplicationConnection, sta NATSClient: natsClient, } + if config.TrackDDL { + ddlRepl, err := NewDDLReplicator(config, br, standardConn) + if err != nil { + br.Logger.Error().Err(err).Msg("Failed to initialize DDL replicator") + } else { + br.DDLReplicator = ddlRepl + } + } + // Initialize the OID map with custom types from the database if err := InitializeOIDMap(context.Background(), standardConn); err != nil { br.Logger.Error().Err(err).Msg("Failed to initialize OID map") @@ -465,10 +480,19 @@ func (r *BaseReplicator) CheckReplicationSlotExists(slotName string) (bool, erro func (r *BaseReplicator) GracefulShutdown(ctx context.Context) error { r.Logger.Info().Msg("Initiating graceful shutdown") + // Send final status update before DDL shutdown if err := r.SendStandbyStatusUpdate(ctx); err != nil { r.Logger.Warn().Err(err).Msg("Failed to send final standby status update") } + // Shutdown DDL replicator if it exists + if r.DDLReplicator != nil { + if err := r.DDLReplicator.Shutdown(ctx); err != nil { + r.Logger.Warn().Err(err).Msg("Failed to shutdown DDL replicator") + } + } + + // Save state and close connections if err := r.SaveState(r.LastLSN); err != nil { r.Logger.Warn().Err(err).Msg("Failed to save final state") } @@ -485,12 +509,30 @@ func (r *BaseReplicator) GracefulShutdown(ctx context.Context) error { func (r *BaseReplicator) closeConnections(ctx context.Context) error { r.Logger.Info().Msg("Closing database connections") - if err := r.ReplicationConn.Close(ctx); err != nil { - return fmt.Errorf("failed to close replication connection: %w", err) + // Close replication connection first + if r.ReplicationConn != nil { + if err := r.ReplicationConn.Close(ctx); err != nil { + r.Logger.Error().Err(err).Msg("Failed to close replication connection") + } + r.ReplicationConn = nil + } + + // Close standard connection + if r.StandardConn != nil { + if err := r.StandardConn.Close(ctx); err 
!= nil { + r.Logger.Error().Err(err).Msg("Failed to close standard connection") + } + r.StandardConn = nil } - if err := r.StandardConn.Close(ctx); err != nil { - return fmt.Errorf("failed to close standard connection: %w", err) + + // Close DDL connection if exists + if r.DDLReplicator != nil && r.DDLReplicator.DDLConn != nil { + if err := r.DDLReplicator.DDLConn.Close(ctx); err != nil { + r.Logger.Error().Err(err).Msg("Failed to close DDL connection") + } + r.DDLReplicator.DDLConn = nil } + return nil } @@ -526,3 +568,47 @@ func (r *BaseReplicator) CheckReplicationSlotStatus(ctx context.Context) error { r.Logger.Info().Str("slotName", publicationName).Str("restartLSN", restartLSN).Msg("Replication slot status") return nil } + +func (r *BaseReplicator) Start(ctx context.Context) error { + r.mu.Lock() + if r.started { + r.mu.Unlock() + return ErrReplicatorAlreadyStarted + } + r.started = true + r.stopChan = make(chan struct{}) + r.mu.Unlock() + + // Create publication first (uses standard connection) + if err := r.CreatePublication(); err != nil { + return fmt.Errorf("failed to create publication: %w", err) + } + + // Create replication slot (uses standard connection) + if err := r.CreateReplicationSlot(ctx); err != nil { + return fmt.Errorf("failed to create replication slot: %w", err) + } + + // Setup and start DDL tracking if enabled + if r.Config.TrackDDL && r.DDLReplicator != nil { + if err := r.DDLReplicator.SetupDDLTracking(ctx); err != nil { + return fmt.Errorf("failed to setup DDL tracking: %w", err) + } + go r.DDLReplicator.StartDDLReplication(ctx) + } + + return nil +} + +func (r *BaseReplicator) Stop(ctx context.Context) error { + r.mu.Lock() + if !r.started || r.stopped { + r.mu.Unlock() + return ErrReplicatorNotStarted + } + r.stopped = true + close(r.stopChan) + r.mu.Unlock() + + return r.GracefulShutdown(ctx) +} diff --git a/pkg/replicator/copy_and_stream_replicator.go b/pkg/replicator/copy_and_stream_replicator.go index 078261f..b12ce16 100644 --- a/pkg/replicator/copy_and_stream_replicator.go +++ b/pkg/replicator/copy_and_stream_replicator.go @@ -3,10 +3,7 @@ package replicator import ( "context" "fmt" - "os" - "os/signal" "sync" - "syscall" "time" "github.com/jackc/pglogrepl" @@ -14,63 +11,30 @@ import ( "github.com/pgflo/pg_flo/pkg/utils" ) -func (r *CopyAndStreamReplicator) NewBaseReplicator() *BaseReplicator { - return &r.BaseReplicator -} - // CopyAndStreamReplicator implements a replication strategy that first copies existing data // and then streams changes. type CopyAndStreamReplicator struct { - BaseReplicator + *BaseReplicator MaxCopyWorkersPerTable int - DDLReplicator DDLReplicator CopyOnly bool } -// StartReplication begins the replication process. 
-func (r *CopyAndStreamReplicator) StartReplication() error { - ctx := context.Background() - - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - - if !r.CopyOnly { - if err := r.BaseReplicator.CreatePublication(); err != nil { - return fmt.Errorf("failed to create publication: %v", err) - } - - if err := r.BaseReplicator.CreateReplicationSlot(ctx); err != nil { - return fmt.Errorf("failed to create replication slot: %v", err) - } +func NewCopyAndStreamReplicator(base *BaseReplicator, maxWorkers int, copyOnly bool) *CopyAndStreamReplicator { + return &CopyAndStreamReplicator{ + BaseReplicator: base, + MaxCopyWorkersPerTable: maxWorkers, + CopyOnly: copyOnly, } +} - // Start DDL replication with its own cancellable context and wait group - var ddlWg sync.WaitGroup - var ddlCancel context.CancelFunc - if r.Config.TrackDDL { - if err := r.DDLReplicator.SetupDDLTracking(ctx); err != nil { - return fmt.Errorf("failed to set up DDL tracking: %v", err) - } - ddlCtx, cancel := context.WithCancel(ctx) - ddlCancel = cancel - ddlWg.Add(1) - go func() { - defer ddlWg.Done() - r.DDLReplicator.StartDDLReplication(ddlCtx) - }() +// StartReplication begins the replication process. +func (r *CopyAndStreamReplicator) Start(ctx context.Context) error { + if err := r.BaseReplicator.Start(ctx); err != nil { + return err } - defer func() { - if r.Config.TrackDDL { - ddlCancel() - ddlWg.Wait() - if err := r.DDLReplicator.Shutdown(ctx); err != nil { - r.Logger.Error().Err(err).Msg("Failed to shutdown DDL replicator") - } - } - }() - if copyErr := r.ParallelCopy(ctx); copyErr != nil { - return fmt.Errorf("failed to perform parallel copy: %v", copyErr) + if err := r.ParallelCopy(ctx); err != nil { + return fmt.Errorf("failed to perform parallel copy: %w", err) } if r.CopyOnly { @@ -78,47 +42,22 @@ func (r *CopyAndStreamReplicator) StartReplication() error { return nil } - startLSN := r.BaseReplicator.LastLSN - - r.Logger.Info().Str("startLSN", startLSN.String()).Msg("Starting replication from LSN") - - // Create a stop channel for graceful shutdown - stopChan := make(chan struct{}) + startLSN := r.LastLSN errChan := make(chan error, 1) go func() { - errChan <- r.BaseReplicator.StartReplicationFromLSN(ctx, startLSN, stopChan) + errChan <- r.StartReplicationFromLSN(ctx, startLSN, r.stopChan) }() select { - case <-sigChan: - r.Logger.Info().Msg("Received shutdown signal") - // Signal replication loop to stop - close(stopChan) - // Wait for replication loop to exit - <-errChan - - // Signal DDL replication to stop and wait for it to finish - if r.Config.TrackDDL { - ddlCancel() - ddlWg.Wait() - if err := r.DDLReplicator.Shutdown(ctx); err != nil { - r.Logger.Error().Err(err).Msg("Failed to shutdown DDL replicator") - } - } - - // Proceed with graceful shutdown - if err := r.BaseReplicator.GracefulShutdown(ctx); err != nil { - r.Logger.Error().Err(err).Msg("Error during graceful shutdown") - return err - } + case <-ctx.Done(): + return ctx.Err() case err := <-errChan: - if err != nil { - r.Logger.Error().Err(err).Msg("Replication ended with error") - return err - } + return err } +} - return nil +func (r *CopyAndStreamReplicator) Stop(ctx context.Context) error { + return r.BaseReplicator.Stop(ctx) } // ParallelCopy performs a parallel copy of all specified tables. 
diff --git a/pkg/replicator/ddl_replicator.go b/pkg/replicator/ddl_replicator.go index c1cd43e..15caa09 100644 --- a/pkg/replicator/ddl_replicator.go +++ b/pkg/replicator/ddl_replicator.go @@ -120,30 +120,14 @@ func (d *DDLReplicator) StartDDLReplication(ctx context.Context) { for { select { case <-ctx.Done(): - d.BaseRepl.Logger.Info().Msg("Finishing processing of DDL events... (this can take a while)") - - // Create a new context without cancellation to process remaining events - // Alternatively, when a shutdown is received we can trigger a ROLLBACK too ? - shutdownCtx := context.Background() - - for { - hasEvents, err := d.HasPendingDDLEvents(shutdownCtx) - if err != nil { - d.BaseRepl.Logger.Error().Err(err).Msg("Failed to check for pending DDL events during shutdown") - break - } - if !hasEvents { - break - } - if err := d.ProcessDDLEvents(shutdownCtx); err != nil { - d.BaseRepl.Logger.Error().Err(err).Msg("Failed to process DDL events during shutdown") - } - - time.Sleep(100 * time.Millisecond) - } + d.BaseRepl.Logger.Info().Msg("DDL replication stopping...") return case <-ticker.C: if err := d.ProcessDDLEvents(ctx); err != nil { + if ctx.Err() != nil { + // Context canceled, exit gracefully + return + } d.BaseRepl.Logger.Error().Err(err).Msg("Failed to process DDL events") } } @@ -153,9 +137,9 @@ func (d *DDLReplicator) StartDDLReplication(ctx context.Context) { // ProcessDDLEvents processes DDL events from the log table func (d *DDLReplicator) ProcessDDLEvents(ctx context.Context) error { rows, err := d.DDLConn.Query(ctx, ` - SELECT id, event_type, object_type, object_identity, table_name, ddl_command, created_at - FROM internal_pg_flo.ddl_log - ORDER BY created_at ASC + SELECT id, event_type, object_type, object_identity, table_name, ddl_command, created_at + FROM internal_pg_flo.ddl_log + ORDER BY created_at ASC `) if err != nil { d.BaseRepl.Logger.Error().Err(err).Msg("Failed to query DDL log") @@ -261,14 +245,34 @@ func (d *DDLReplicator) Close(ctx context.Context) error { // Shutdown performs a graceful shutdown of the DDL replicator func (d *DDLReplicator) Shutdown(ctx context.Context) error { d.BaseRepl.Logger.Info().Msg("Shutting down DDL replicator") - if ctx.Err() != nil { - ctx = context.Background() - } + + // Process remaining events with the provided context if err := d.ProcessDDLEvents(ctx); err != nil { d.BaseRepl.Logger.Error().Err(err).Msg("Failed to process final DDL events") - return err + // Continue with shutdown even if processing fails + } + + // Wait for any pending events with respect to context deadline + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + d.BaseRepl.Logger.Warn().Msg("Context deadline exceeded while waiting for DDL events") + return ctx.Err() + case <-ticker.C: + hasEvents, err := d.HasPendingDDLEvents(ctx) + if err != nil { + d.BaseRepl.Logger.Error().Err(err).Msg("Failed to check pending DDL events") + return err + } + if !hasEvents { + d.BaseRepl.Logger.Info().Msg("All DDL events processed") + return d.Close(ctx) + } + } } - return d.Close(ctx) } // HasPendingDDLEvents checks if there are pending DDL events in the log diff --git a/pkg/replicator/errors.go b/pkg/replicator/errors.go index 8e3df8a..eb040be 100644 --- a/pkg/replicator/errors.go +++ b/pkg/replicator/errors.go @@ -1,6 +1,15 @@ package replicator -import "fmt" +import ( + "errors" + "fmt" +) + +var ( + ErrReplicatorAlreadyStarted = errors.New("replicator already started") + ErrReplicatorNotStarted = 
errors.New("replicator not started") + ErrReplicatorAlreadyStopped = errors.New("replicator already stopped") +) // ReplicationError represents an error that occurred during replication. type ReplicationError struct { diff --git a/pkg/replicator/factory.go b/pkg/replicator/factory.go index 974a7b0..4c67498 100644 --- a/pkg/replicator/factory.go +++ b/pkg/replicator/factory.go @@ -41,28 +41,7 @@ func (f *StreamReplicatorFactory) CreateReplicator(config Config, natsClient NAT } baseReplicator := NewBaseReplicator(config, replicationConn, standardConn, natsClient) - - var ddlReplicator *DDLReplicator - if config.TrackDDL { - ddlConn, err := NewStandardConnection(config) - if err != nil { - return nil, fmt.Errorf("failed to create DDL connection: %v", err) - } - ddlReplicator, err = NewDDLReplicator(config, baseReplicator, ddlConn) - if err != nil { - return nil, fmt.Errorf("failed to create DDL replicator: %v", err) - } - } - - streamReplicator := &StreamReplicator{ - BaseReplicator: *baseReplicator, - } - - if ddlReplicator != nil { - streamReplicator.DDLReplicator = *ddlReplicator - } - - return streamReplicator, nil + return &StreamReplicator{BaseReplicator: baseReplicator}, nil } // CopyAndStreamReplicatorFactory creates `CopyAndStreamReplicator` instances @@ -81,31 +60,13 @@ func (f *CopyAndStreamReplicatorFactory) CreateReplicator(config Config, natsCli baseReplicator := NewBaseReplicator(config, replicationConn, standardConn, natsClient) - var ddlReplicator *DDLReplicator - if config.TrackDDL { - ddlConn, err := NewStandardConnection(config) - if err != nil { - return nil, fmt.Errorf("failed to create DDL connection: %v", err) - } - ddlReplicator, err = NewDDLReplicator(config, baseReplicator, ddlConn) - if err != nil { - return nil, fmt.Errorf("failed to create DDL replicator: %v", err) - } - } - if f.MaxCopyWorkersPerTable <= 0 { f.MaxCopyWorkersPerTable = 4 } - copyAndStreamReplicator := &CopyAndStreamReplicator{ - BaseReplicator: *baseReplicator, - MaxCopyWorkersPerTable: f.MaxCopyWorkersPerTable, - CopyOnly: f.CopyOnly, - } - - if ddlReplicator != nil { - copyAndStreamReplicator.DDLReplicator = *ddlReplicator - } - - return copyAndStreamReplicator, nil + return NewCopyAndStreamReplicator( + baseReplicator, + f.MaxCopyWorkersPerTable, + f.CopyOnly, + ), nil } diff --git a/pkg/replicator/interfaces.go b/pkg/replicator/interfaces.go index 7689f32..cacd1c8 100644 --- a/pkg/replicator/interfaces.go +++ b/pkg/replicator/interfaces.go @@ -12,7 +12,8 @@ import ( ) type Replicator interface { - StartReplication() error + Start(ctx context.Context) error + Stop(ctx context.Context) error } type ReplicationConnection interface { diff --git a/pkg/replicator/stream_replicator.go b/pkg/replicator/stream_replicator.go index bdbac58..4e475f1 100644 --- a/pkg/replicator/stream_replicator.go +++ b/pkg/replicator/stream_replicator.go @@ -2,113 +2,46 @@ package replicator import ( "context" - "fmt" - "os" - "os/signal" - "sync" - "syscall" "github.com/jackc/pglogrepl" ) type StreamReplicator struct { - BaseReplicator - DDLReplicator DDLReplicator + *BaseReplicator } -// StartReplication begins the replication process. 
-func (r *StreamReplicator) StartReplication() error { - ctx := context.Background() - - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) +func NewStreamReplicator(base *BaseReplicator) *StreamReplicator { + return &StreamReplicator{ + BaseReplicator: base, + } +} - if err := r.setup(ctx); err != nil { +func (r *StreamReplicator) Start(ctx context.Context) error { + if err := r.BaseReplicator.Start(ctx); err != nil { return err } - startLSN, err := r.getStartLSN() + startLSN, err := r.GetLastState() if err != nil { - return err + r.Logger.Warn().Err(err).Msg("Failed to get last LSN, starting from 0") + startLSN = pglogrepl.LSN(0) } - r.Logger.Info().Str("startLSN", startLSN.String()).Msg("Starting replication from LSN") + r.Logger.Info().Str("startLSN", startLSN.String()).Msg("Starting replication") - // Start DDL replication with its own cancellable context and wait group - var ddlWg sync.WaitGroup - var ddlCancel context.CancelFunc - if r.Config.TrackDDL { - ddlCtx, cancel := context.WithCancel(ctx) - ddlCancel = cancel - if err := r.DDLReplicator.SetupDDLTracking(ctx); err != nil { - return fmt.Errorf("failed to set up DDL tracking: %v", err) - } - ddlWg.Add(1) - go func() { - defer ddlWg.Done() - r.DDLReplicator.StartDDLReplication(ddlCtx) - }() - } - - stopChan := make(chan struct{}) errChan := make(chan error, 1) go func() { - errChan <- r.BaseReplicator.StartReplicationFromLSN(ctx, startLSN, stopChan) + errChan <- r.StartReplicationFromLSN(ctx, startLSN, r.stopChan) }() select { - case <-sigChan: - r.Logger.Info().Msg("Received shutdown signal") - // Signal the replication loop to stop - close(stopChan) - // Wait for the replication loop to exit - <-errChan - - // Signal DDL replication to stop and wait for it to finish - if r.Config.TrackDDL { - ddlCancel() - ddlWg.Wait() - if err := r.DDLReplicator.Shutdown(context.Background()); err != nil { - r.Logger.Error().Err(err).Msg("Failed to shutdown DDL replicator") - } - } - - if err := r.BaseReplicator.GracefulShutdown(ctx); err != nil { - r.Logger.Error().Err(err).Msg("Error during graceful shutdown") - return err - } + case <-ctx.Done(): + return ctx.Err() case err := <-errChan: - if err != nil { - r.Logger.Error().Err(err).Msg("Replication ended with error") - return err - } - } - - return nil -} - -func (r *StreamReplicator) setup(ctx context.Context) error { - if err := r.BaseReplicator.CreatePublication(); err != nil { - return fmt.Errorf("failed to create publication: %v", err) - } - - if err := r.BaseReplicator.CreateReplicationSlot(ctx); err != nil { - return fmt.Errorf("failed to create replication slot: %v", err) - } - - if err := r.BaseReplicator.CheckReplicationSlotStatus(ctx); err != nil { - return fmt.Errorf("failed to check replication slot status: %v", err) + return err } - - return nil } -// getStartLSN determines the starting LSN for replication. 
-func (r *StreamReplicator) getStartLSN() (pglogrepl.LSN, error) { - startLSN, err := r.BaseReplicator.GetLastState() - if err != nil { - r.Logger.Warn().Err(err).Msg("Failed to get last LSN, starting from 0") - return pglogrepl.LSN(0), nil - } - return startLSN, nil +func (r *StreamReplicator) Stop(ctx context.Context) error { + return r.BaseReplicator.Stop(ctx) } diff --git a/pkg/replicator/tests/copy_and_stream_replicator_test.go b/pkg/replicator/tests/copy_and_stream_replicator_test.go index 9609ab0..e9f2aa9 100644 --- a/pkg/replicator/tests/copy_and_stream_replicator_test.go +++ b/pkg/replicator/tests/copy_and_stream_replicator_test.go @@ -14,7 +14,6 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/pgflo/pg_flo/pkg/replicator" "github.com/pgflo/pg_flo/pkg/utils" - "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -38,6 +37,24 @@ func TestCopyAndStreamReplicator(t *testing.T) { }, }) + mockOIDRows := new(MockRows) + mockOIDRows.On("Next").Return(false) + mockOIDRows.On("Err").Return(nil) + mockOIDRows.On("Close").Return() + + mockPKRows := new(MockRows) + mockPKRows.On("Next").Return(false) + mockPKRows.On("Err").Return(nil) + mockPKRows.On("Close").Return() + + mockStandardConn.On("Query", mock.Anything, mock.MatchedBy(func(q string) bool { + return strings.Contains(q, "pg_type") + }), mock.Anything).Return(mockOIDRows, nil) + + mockStandardConn.On("Query", mock.Anything, mock.MatchedBy(func(q string) bool { + return strings.Contains(q, "table_info") + }), mock.Anything).Return(mockPKRows, nil) + mockPoolConn.On("BeginTx", mock.Anything, mock.MatchedBy(func(txOptions pgx.TxOptions) bool { return txOptions.IsoLevel == pgx.Serializable && txOptions.AccessMode == pgx.ReadOnly })).Return(mockTx, nil) @@ -91,11 +108,8 @@ func TestCopyAndStreamReplicator(t *testing.T) { })).Return(nil) csr := &replicator.CopyAndStreamReplicator{ - BaseReplicator: replicator.BaseReplicator{ - StandardConn: mockStandardConn, - NATSClient: mockNATSClient, - Logger: utils.NewZerologLogger(zerolog.New(nil)), - Config: replicator.Config{ + BaseReplicator: replicator.NewBaseReplicator( + replicator.Config{ Tables: []string{"users"}, Schema: "public", Host: "localhost", @@ -105,7 +119,10 @@ func TestCopyAndStreamReplicator(t *testing.T) { Database: "testdb", Group: "test_group", }, - }, + nil, + mockStandardConn, + mockNATSClient, + ), MaxCopyWorkersPerTable: 2, } @@ -120,6 +137,8 @@ func TestCopyAndStreamReplicator(t *testing.T) { mockTx.AssertExpectations(t) mockRows.AssertExpectations(t) mockNATSClient.AssertExpectations(t) + mockOIDRows.AssertExpectations(t) + mockPKRows.AssertExpectations(t) }) t.Run("CopyTableRange", func(t *testing.T) { @@ -146,6 +165,24 @@ func TestCopyAndStreamReplicator(t *testing.T) { }, }) + mockOIDRows := new(MockRows) + mockOIDRows.On("Next").Return(false) + mockOIDRows.On("Err").Return(nil) + mockOIDRows.On("Close").Return() + + mockPKRows := new(MockRows) + mockPKRows.On("Next").Return(false) + mockPKRows.On("Err").Return(nil) + mockPKRows.On("Close").Return() + + mockStandardConn.On("Query", mock.Anything, mock.MatchedBy(func(q string) bool { + return strings.Contains(q, "pg_type") + }), mock.Anything).Return(mockOIDRows, nil) + + mockStandardConn.On("Query", mock.Anything, mock.MatchedBy(func(q string) bool { + return strings.Contains(q, "table_info") + }), mock.Anything).Return(mockPKRows, nil) + mockRows.On("Next").Return(true).Once().On("Next").Return(false) mockRows.On("Err").Return(nil) mockRows.On("Close").Return() 
@@ -162,11 +199,8 @@ func TestCopyAndStreamReplicator(t *testing.T) { mockNATSClient.On("PublishMessage", "pgflo.test_group", mock.Anything).Return(nil) csr := &replicator.CopyAndStreamReplicator{ - BaseReplicator: replicator.BaseReplicator{ - StandardConn: mockStandardConn, - NATSClient: mockNATSClient, - Logger: utils.NewZerologLogger(zerolog.New(nil)), - Config: replicator.Config{ + BaseReplicator: replicator.NewBaseReplicator( + replicator.Config{ Tables: []string{"users"}, Schema: "public", Host: "localhost", @@ -176,7 +210,10 @@ func TestCopyAndStreamReplicator(t *testing.T) { Database: "testdb", Group: "test_group", }, - }, + nil, + mockStandardConn, + mockNATSClient, + ), } rowsCopied, err := csr.CopyTableRange(context.Background(), "users", 0, 1000, "snapshot-1", 0) @@ -188,6 +225,8 @@ func TestCopyAndStreamReplicator(t *testing.T) { mockTx.AssertExpectations(t) mockRows.AssertExpectations(t) mockNATSClient.AssertExpectations(t) + mockOIDRows.AssertExpectations(t) + mockPKRows.AssertExpectations(t) }) t.Run("CopyTableRange with diverse data types", func(t *testing.T) { testCases := []struct { @@ -271,6 +310,24 @@ func TestCopyAndStreamReplicator(t *testing.T) { }, }) + mockOIDRows := new(MockRows) + mockOIDRows.On("Next").Return(false) + mockOIDRows.On("Err").Return(nil) + mockOIDRows.On("Close").Return() + + mockPKRows := new(MockRows) + mockPKRows.On("Next").Return(false) + mockPKRows.On("Err").Return(nil) + mockPKRows.On("Close").Return() + + mockStandardConn.On("Query", mock.Anything, mock.MatchedBy(func(q string) bool { + return strings.Contains(q, "pg_type") + }), mock.Anything).Return(mockOIDRows, nil) + + mockStandardConn.On("Query", mock.Anything, mock.MatchedBy(func(q string) bool { + return strings.Contains(q, "table_info") + }), mock.Anything).Return(mockPKRows, nil) + mockRows.On("Next").Return(true).Once().On("Next").Return(false) mockRows.On("Err").Return(nil) mockRows.On("Close").Return() @@ -336,11 +393,8 @@ func TestCopyAndStreamReplicator(t *testing.T) { })).Return(nil) csr := &replicator.CopyAndStreamReplicator{ - BaseReplicator: replicator.BaseReplicator{ - StandardConn: mockStandardConn, - NATSClient: mockNATSClient, - Logger: utils.NewZerologLogger(zerolog.New(nil)), - Config: replicator.Config{ + BaseReplicator: replicator.NewBaseReplicator( + replicator.Config{ Tables: []string{"test_table"}, Schema: "public", Host: "localhost", @@ -350,7 +404,10 @@ func TestCopyAndStreamReplicator(t *testing.T) { Database: "testdb", Group: "test_group", }, - }, + nil, + mockStandardConn, + mockNATSClient, + ), } rowsCopied, err := csr.CopyTableRange(context.Background(), "test_table", 0, 1000, "snapshot-1", 0) @@ -363,6 +420,8 @@ func TestCopyAndStreamReplicator(t *testing.T) { mockTx.AssertExpectations(t) mockRows.AssertExpectations(t) mockNATSClient.AssertExpectations(t) + mockOIDRows.AssertExpectations(t) + mockPKRows.AssertExpectations(t) }) } }) diff --git a/pkg/replicator/tests/ddl_replicator_test.go b/pkg/replicator/tests/ddl_replicator_test.go index c62a1d5..b0de2a1 100644 --- a/pkg/replicator/tests/ddl_replicator_test.go +++ b/pkg/replicator/tests/ddl_replicator_test.go @@ -77,10 +77,6 @@ func TestDDLReplicator(t *testing.T) { defer cancel() mockRows := &MockRows{} - mockRows.On("Next").Return(false) - mockRows.On("Err").Return(nil) - mockRows.On("Close").Return() - mockStandardConn.On("Query", mock.Anything, mock.MatchedBy(func(sql string) bool { expectedParts := []string{ "SELECT id, event_type, object_type, object_identity, table_name, ddl_command, 
created_at", @@ -93,25 +89,25 @@ func TestDDLReplicator(t *testing.T) { } } return true - }), mock.Anything).Return(mockRows, nil) + }), mock.Anything).Return(mockRows, nil).Maybe() + + mockRows.On("Next").Return(false).Maybe() + mockRows.On("Err").Return(nil).Maybe() + mockRows.On("Close").Return().Maybe() mockStandardConn.On("QueryRow", mock.Anything, mock.MatchedBy(func(sql string) bool { return strings.Contains(sql, "SELECT COUNT(*) FROM internal_pg_flo.ddl_log") }), mock.Anything).Return(&MockRow{ scanFunc: func(dest ...interface{}) error { - *dest[0].(*int) = 1 + *dest[0].(*int) = 0 return nil }, - }) + }).Maybe() go ddlReplicator.StartDDLReplication(ctx) time.Sleep(100 * time.Millisecond) - hasPending, err := ddlReplicator.HasPendingDDLEvents(ctx) - assert.NoError(t, err) - assert.True(t, hasPending) - cancel() time.Sleep(100 * time.Millisecond)