Skip to content

Commit

Permalink
Move towards simpler interfaces and good lifecycle management hygiene (
Browse files Browse the repository at this point in the history
…#52)

* Move towards simpler interfaces and good lifecycle management hygiene

* wip
  • Loading branch information
shayonj authored Nov 16, 2024
1 parent 24be6e3 commit 33ee371
Show file tree
Hide file tree
Showing 12 changed files with 323 additions and 287 deletions.
52 changes: 50 additions & 2 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os/signal"
"strings"
"syscall"
"time"

"github.com/pgflo/pg_flo/pkg/pgflonats"
"github.com/pgflo/pg_flo/pkg/replicator"
Expand Down Expand Up @@ -295,13 +296,60 @@ func runReplicator(_ *cobra.Command, _ []string) {
factory = &replicator.StreamReplicatorFactory{}
}

// Create base context for the entire application
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

rep, err := factory.CreateReplicator(config, natsClient)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create replicator")
}

if err := rep.StartReplication(); err != nil {
log.Fatal().Err(err).Msg("Failed to start replication")
// Setup signal handling
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

// Error channel to capture any replication errors
errCh := make(chan error, 1)

// Start replication in a goroutine
go func() {
if err := rep.Start(ctx); err != nil {
errCh <- err
}
}()

// Wait for either a signal or an error
select {
case sig := <-sigCh:
log.Info().Str("signal", sig.String()).Msg("Received shutdown signal")

// Create a new context with timeout for graceful shutdown
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer shutdownCancel()

// Cancel the main context first
cancel()

// Then call Stop with the timeout context
if err := rep.Stop(shutdownCtx); err != nil {
log.Error().Err(err).Msg("Error during replication shutdown")
os.Exit(1)
}
log.Info().Msg("Replication stopped successfully")

case err := <-errCh:
log.Error().Err(err).Msg("Replication error occurred")

// Create shutdown context for cleanup
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer shutdownCancel()

// Attempt cleanup even if there was an error
if stopErr := rep.Stop(shutdownCtx); stopErr != nil {
log.Error().Err(stopErr).Msg("Additional error during shutdown")
}
os.Exit(1)
}
}

Expand Down
10 changes: 5 additions & 5 deletions internal/scripts/e2e_resume_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
require 'securerandom'

class ResumeTest
TOTAL_INSERTS = 3000
INSERTS_BEFORE_INTERRUPT = 1000
RESUME_WAIT_TIME = 2 # seconds
NUM_GROUPS = 3
TOTAL_INSERTS = 5000
INSERTS_BEFORE_INTERRUPT = 1500
RESUME_WAIT_TIME = 1 # seconds
NUM_GROUPS = 4

PG_HOST = 'localhost'
PG_PORT = 5433
Expand Down Expand Up @@ -284,7 +284,7 @@ def test_resume
@logger.info "Waiting for all inserts to complete..."
threads.each(&:join)

sleep 5
sleep 20

@logger.info "Sending final SIGTERM to cleanup..."
@replicator_pids.each do |pid|
Expand Down
6 changes: 3 additions & 3 deletions internal/scripts/e2e_test_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ make build

setup_docker

log "Running e2e postgres tests..."
if CI=false ./internal/scripts/e2e_postgres.sh; then
success "e2e postgres tests completed successfully"
log "Running e2e ddl tests..."
if CI=false ruby ./internal/scripts/e2e_resume_test.rb; then
success "e2e ddl tests completed successfully"
else
error "Original e2e tests failed"
exit 1
Expand Down
94 changes: 90 additions & 4 deletions pkg/replicator/base_replicator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"strings"
"sync"
"time"

"errors"
Expand All @@ -26,12 +27,17 @@ type BaseReplicator struct {
Config Config
ReplicationConn ReplicationConnection
StandardConn StandardConnection
DDLReplicator *DDLReplicator
Relations map[uint32]*pglogrepl.RelationMessage
Logger utils.Logger
TableDetails map[string][]string
LastLSN pglogrepl.LSN
NATSClient NATSClient
TableReplicationKeys map[string]utils.ReplicationKey
stopChan chan struct{}
started bool
stopped bool
mu sync.RWMutex
}

// NewBaseReplicator creates a new BaseReplicator instance
Expand All @@ -52,6 +58,15 @@ func NewBaseReplicator(config Config, replicationConn ReplicationConnection, sta
NATSClient: natsClient,
}

if config.TrackDDL {
ddlRepl, err := NewDDLReplicator(config, br, standardConn)
if err != nil {
br.Logger.Error().Err(err).Msg("Failed to initialize DDL replicator")
} else {
br.DDLReplicator = ddlRepl
}
}

// Initialize the OID map with custom types from the database
if err := InitializeOIDMap(context.Background(), standardConn); err != nil {
br.Logger.Error().Err(err).Msg("Failed to initialize OID map")
Expand Down Expand Up @@ -465,10 +480,19 @@ func (r *BaseReplicator) CheckReplicationSlotExists(slotName string) (bool, erro
func (r *BaseReplicator) GracefulShutdown(ctx context.Context) error {
r.Logger.Info().Msg("Initiating graceful shutdown")

// Send final status update before DDL shutdown
if err := r.SendStandbyStatusUpdate(ctx); err != nil {
r.Logger.Warn().Err(err).Msg("Failed to send final standby status update")
}

// Shutdown DDL replicator if it exists
if r.DDLReplicator != nil {
if err := r.DDLReplicator.Shutdown(ctx); err != nil {
r.Logger.Warn().Err(err).Msg("Failed to shutdown DDL replicator")
}
}

// Save state and close connections
if err := r.SaveState(r.LastLSN); err != nil {
r.Logger.Warn().Err(err).Msg("Failed to save final state")
}
Expand All @@ -485,12 +509,30 @@ func (r *BaseReplicator) GracefulShutdown(ctx context.Context) error {
func (r *BaseReplicator) closeConnections(ctx context.Context) error {
r.Logger.Info().Msg("Closing database connections")

if err := r.ReplicationConn.Close(ctx); err != nil {
return fmt.Errorf("failed to close replication connection: %w", err)
// Close replication connection first
if r.ReplicationConn != nil {
if err := r.ReplicationConn.Close(ctx); err != nil {
r.Logger.Error().Err(err).Msg("Failed to close replication connection")
}
r.ReplicationConn = nil
}

// Close standard connection
if r.StandardConn != nil {
if err := r.StandardConn.Close(ctx); err != nil {
r.Logger.Error().Err(err).Msg("Failed to close standard connection")
}
r.StandardConn = nil
}
if err := r.StandardConn.Close(ctx); err != nil {
return fmt.Errorf("failed to close standard connection: %w", err)

// Close DDL connection if exists
if r.DDLReplicator != nil && r.DDLReplicator.DDLConn != nil {
if err := r.DDLReplicator.DDLConn.Close(ctx); err != nil {
r.Logger.Error().Err(err).Msg("Failed to close DDL connection")
}
r.DDLReplicator.DDLConn = nil
}

return nil
}

Expand Down Expand Up @@ -526,3 +568,47 @@ func (r *BaseReplicator) CheckReplicationSlotStatus(ctx context.Context) error {
r.Logger.Info().Str("slotName", publicationName).Str("restartLSN", restartLSN).Msg("Replication slot status")
return nil
}

// Start marks the replicator as started and runs its one-time setup: it
// creates the publication, creates the replication slot and, when DDL tracking
// is enabled in the config and a DDL replicator is present, sets up DDL
// tracking and launches StartDDLReplication in its own goroutine.
//
// It returns ErrReplicatorAlreadyStarted when called on a replicator that has
// already been started. The started flag is set before any setup step runs, so
// it remains set even if one of those steps fails.
func (r *BaseReplicator) Start(ctx context.Context) error {
	// Flip the started flag and allocate the stop channel under the lock; the
	// lock is released again before any of the setup calls below are made.
	alreadyStarted := func() bool {
		r.mu.Lock()
		defer r.mu.Unlock()

		if r.started {
			return true
		}
		r.started = true
		r.stopChan = make(chan struct{})
		return false
	}()
	if alreadyStarted {
		return ErrReplicatorAlreadyStarted
	}

	// The publication is created before the replication slot.
	pubErr := r.CreatePublication()
	if pubErr != nil {
		return fmt.Errorf("failed to create publication: %w", pubErr)
	}

	slotErr := r.CreateReplicationSlot(ctx)
	if slotErr != nil {
		return fmt.Errorf("failed to create replication slot: %w", slotErr)
	}

	// DDL tracking is optional: nothing more to do unless it is both enabled
	// and backed by a DDL replicator instance.
	if !r.Config.TrackDDL || r.DDLReplicator == nil {
		return nil
	}

	if ddlErr := r.DDLReplicator.SetupDDLTracking(ctx); ddlErr != nil {
		return fmt.Errorf("failed to setup DDL tracking: %w", ddlErr)
	}
	go r.DDLReplicator.StartDDLReplication(ctx)

	return nil
}

// Stop marks the replicator as stopped, closes the stop channel and then runs
// GracefulShutdown with the supplied context, returning its result.
//
// It returns ErrReplicatorNotStarted when the replicator was never started or
// has already been stopped.
func (r *BaseReplicator) Stop(ctx context.Context) error {
	// Perform the state transition and close the stop channel under the lock;
	// the shutdown itself runs after the lock has been released.
	transitioned := func() bool {
		r.mu.Lock()
		defer r.mu.Unlock()

		if r.stopped || !r.started {
			return false
		}
		r.stopped = true
		close(r.stopChan)
		return true
	}()
	if !transitioned {
		return ErrReplicatorNotStarted
	}

	return r.GracefulShutdown(ctx)
}
103 changes: 21 additions & 82 deletions pkg/replicator/copy_and_stream_replicator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,122 +3,61 @@ package replicator
import (
"context"
"fmt"
"os"
"os/signal"
"sync"
"syscall"
"time"

"github.com/jackc/pglogrepl"
"github.com/jackc/pgx/v5"
"github.com/pgflo/pg_flo/pkg/utils"
)

// NewBaseReplicator returns the address of the receiver's BaseReplicator field.
// NOTE(review): this only type-checks while BaseReplicator is embedded by value;
// confirm against the current struct definition.
func (r *CopyAndStreamReplicator) NewBaseReplicator() *BaseReplicator {
return &r.BaseReplicator
}

// CopyAndStreamReplicator implements a replication strategy that first copies existing data
// and then streams changes.
type CopyAndStreamReplicator struct {
BaseReplicator
*BaseReplicator
MaxCopyWorkersPerTable int
DDLReplicator DDLReplicator
CopyOnly bool
}

// StartReplication begins the replication process.
func (r *CopyAndStreamReplicator) StartReplication() error {
ctx := context.Background()

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

if !r.CopyOnly {
if err := r.BaseReplicator.CreatePublication(); err != nil {
return fmt.Errorf("failed to create publication: %v", err)
}

if err := r.BaseReplicator.CreateReplicationSlot(ctx); err != nil {
return fmt.Errorf("failed to create replication slot: %v", err)
}
// NewCopyAndStreamReplicator returns a CopyAndStreamReplicator that wraps base,
// with MaxCopyWorkersPerTable set to maxWorkers and CopyOnly set to copyOnly.
func NewCopyAndStreamReplicator(base *BaseReplicator, maxWorkers int, copyOnly bool) *CopyAndStreamReplicator {
	rep := new(CopyAndStreamReplicator)
	rep.BaseReplicator = base
	rep.MaxCopyWorkersPerTable = maxWorkers
	rep.CopyOnly = copyOnly
	return rep
}

// Start DDL replication with its own cancellable context and wait group
var ddlWg sync.WaitGroup
var ddlCancel context.CancelFunc
if r.Config.TrackDDL {
if err := r.DDLReplicator.SetupDDLTracking(ctx); err != nil {
return fmt.Errorf("failed to set up DDL tracking: %v", err)
}
ddlCtx, cancel := context.WithCancel(ctx)
ddlCancel = cancel
ddlWg.Add(1)
go func() {
defer ddlWg.Done()
r.DDLReplicator.StartDDLReplication(ddlCtx)
}()
// Start begins the replication process.
func (r *CopyAndStreamReplicator) Start(ctx context.Context) error {
if err := r.BaseReplicator.Start(ctx); err != nil {
return err
}
defer func() {
if r.Config.TrackDDL {
ddlCancel()
ddlWg.Wait()
if err := r.DDLReplicator.Shutdown(ctx); err != nil {
r.Logger.Error().Err(err).Msg("Failed to shutdown DDL replicator")
}
}
}()

if copyErr := r.ParallelCopy(ctx); copyErr != nil {
return fmt.Errorf("failed to perform parallel copy: %v", copyErr)
if err := r.ParallelCopy(ctx); err != nil {
return fmt.Errorf("failed to perform parallel copy: %w", err)
}

if r.CopyOnly {
r.Logger.Info().Msg("Copy-only mode: finished copying data")
return nil
}

startLSN := r.BaseReplicator.LastLSN

r.Logger.Info().Str("startLSN", startLSN.String()).Msg("Starting replication from LSN")

// Create a stop channel for graceful shutdown
stopChan := make(chan struct{})
startLSN := r.LastLSN
errChan := make(chan error, 1)
go func() {
errChan <- r.BaseReplicator.StartReplicationFromLSN(ctx, startLSN, stopChan)
errChan <- r.StartReplicationFromLSN(ctx, startLSN, r.stopChan)
}()

select {
case <-sigChan:
r.Logger.Info().Msg("Received shutdown signal")
// Signal replication loop to stop
close(stopChan)
// Wait for replication loop to exit
<-errChan

// Signal DDL replication to stop and wait for it to finish
if r.Config.TrackDDL {
ddlCancel()
ddlWg.Wait()
if err := r.DDLReplicator.Shutdown(ctx); err != nil {
r.Logger.Error().Err(err).Msg("Failed to shutdown DDL replicator")
}
}

// Proceed with graceful shutdown
if err := r.BaseReplicator.GracefulShutdown(ctx); err != nil {
r.Logger.Error().Err(err).Msg("Error during graceful shutdown")
return err
}
case <-ctx.Done():
return ctx.Err()
case err := <-errChan:
if err != nil {
r.Logger.Error().Err(err).Msg("Replication ended with error")
return err
}
return err
}
}

return nil
// Stop delegates to the embedded BaseReplicator's Stop and returns its result.
// The explicit BaseReplicator selector is required: calling r.Stop here would
// resolve to this method and recurse.
func (r *CopyAndStreamReplicator) Stop(ctx context.Context) error {
	stopErr := r.BaseReplicator.Stop(ctx)
	return stopErr
}

// ParallelCopy performs a parallel copy of all specified tables.
Expand Down
Loading

0 comments on commit 33ee371

Please sign in to comment.