From 185dd695a6d27651dda310036e70da5679568cf9 Mon Sep 17 00:00:00 2001 From: Shayon Mukherjee Date: Tue, 15 Oct 2024 08:11:34 -0400 Subject: [PATCH] Proper shutdown --- pkg/replicator/base_replicator.go | 71 ++++---- pkg/replicator/copy_and_stream_replicator.go | 45 +++-- pkg/replicator/stream_replicator.go | 77 +++++---- pkg/replicator/tests/base_replicator_test.go | 156 +----------------- .../tests/copy_and_stream_replicator_test.go | 71 -------- pkg/worker/worker.go | 34 ++-- 6 files changed, 122 insertions(+), 332 deletions(-) diff --git a/pkg/replicator/base_replicator.go b/pkg/replicator/base_replicator.go index 3619b2e..a0beef9 100644 --- a/pkg/replicator/base_replicator.go +++ b/pkg/replicator/base_replicator.go @@ -123,9 +123,10 @@ func (r *BaseReplicator) checkPublicationExists(publicationName string) (bool, e } // StartReplicationFromLSN initiates the replication process from a given LSN -func (r *BaseReplicator) StartReplicationFromLSN(ctx context.Context, startLSN pglogrepl.LSN) error { +func (r *BaseReplicator) StartReplicationFromLSN(ctx context.Context, startLSN pglogrepl.LSN, stopChan <-chan struct{}) error { publicationName := GeneratePublicationName(r.Config.Group) r.Logger.Info().Str("startLSN", startLSN.String()).Str("publication", publicationName).Msg("Starting replication") + err := r.ReplicationConn.StartReplication(ctx, publicationName, startLSN, pglogrepl.StartReplicationOptions{ PluginArgs: []string{ "proto_version '1'", @@ -138,30 +139,21 @@ func (r *BaseReplicator) StartReplicationFromLSN(ctx context.Context, startLSN p r.Logger.Info().Str("startLSN", startLSN.String()).Msg("Replication started successfully") - errChan := make(chan error, 1) - go func() { - errChan <- r.StreamChanges(ctx) - }() - - select { - case <-ctx.Done(): - return r.gracefulShutdown() - case err := <-errChan: - if err != nil && !errors.Is(err, context.Canceled) { - return err - } - return nil - } + return r.StreamChanges(ctx, stopChan) } // StreamChanges continuously processes replication messages -func (r *BaseReplicator) StreamChanges(ctx context.Context) error { +func (r *BaseReplicator) StreamChanges(ctx context.Context, stopChan <-chan struct{}) error { lastStatusUpdate := time.Now() standbyMessageTimeout := time.Second * 10 for { select { case <-ctx.Done(): + r.Logger.Info().Msg("Context canceled, stopping StreamChanges") + return nil + case <-stopChan: + r.Logger.Info().Msg("Stop signal received, exiting StreamChanges") return nil default: if err := r.ProcessNextMessage(ctx, &lastStatusUpdate, standbyMessageTimeout); err != nil { @@ -178,10 +170,12 @@ func (r *BaseReplicator) StreamChanges(ctx context.Context) error { func (r *BaseReplicator) ProcessNextMessage(ctx context.Context, lastStatusUpdate *time.Time, standbyMessageTimeout time.Duration) error { msg, err := r.ReplicationConn.ReceiveMessage(ctx) if err != nil { - if ctx.Err() != nil { - return fmt.Errorf("context error while receiving message: %w", ctx.Err()) + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + r.Logger.Info().Msg("Context canceled or deadline exceeded, stopping message processing") + return nil } - return fmt.Errorf("failed to receive message: %w", err) + r.Logger.Error().Err(err).Msg("Error processing next message") + return err } switch msg := msg.(type) { @@ -414,19 +408,6 @@ func (r *BaseReplicator) SendStandbyStatusUpdate(ctx context.Context) error { return nil } -// SendFinalStandbyStatusUpdate sends a final status update before shutting down -func (r *BaseReplicator) SendFinalStandbyStatusUpdate() error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := r.SendStandbyStatusUpdate(ctx); err != nil { - return fmt.Errorf("failed to send standby status update: %w", err) - } - - r.Logger.Info().Msg("Sent final standby status update") - return nil -} - // CreateReplicationSlot ensures that a replication slot exists, creating one if necessary func (r *BaseReplicator) CreateReplicationSlot(ctx context.Context) error { publicationName := GeneratePublicationName(r.Config.Group) @@ -463,28 +444,34 @@ func (r *BaseReplicator) CheckReplicationSlotExists(slotName string) (bool, erro return exists, nil } -// gracefulShutdown performs a graceful shutdown of the replicator -func (r *BaseReplicator) gracefulShutdown() error { +// GracefulShutdown performs a graceful shutdown of the replicator +func (r *BaseReplicator) GracefulShutdown(ctx context.Context) error { r.Logger.Info().Msg("Initiating graceful shutdown") - if err := r.SendFinalStandbyStatusUpdate(); err != nil { - r.Logger.Error().Err(err).Msg("Failed to send final standby status update") + if err := r.SendStandbyStatusUpdate(ctx); err != nil { + r.Logger.Warn().Err(err).Msg("Failed to send final standby status update") } - if err := r.closeConnections(); err != nil { - r.Logger.Error().Err(err).Msg("Failed to close connections") + if err := r.SaveState(ctx, r.LastLSN); err != nil { + r.Logger.Warn().Err(err).Msg("Failed to save final state") } - r.Logger.Info().Msg("Graceful shutdown completed") + if err := r.closeConnections(ctx); err != nil { + r.Logger.Warn().Err(err).Msg("Failed to close connections") + } + + r.Logger.Info().Msg("Base replicator shutdown completed") return nil } // closeConnections closes all open database connections -func (r *BaseReplicator) closeConnections() error { - if err := r.ReplicationConn.Close(context.Background()); err != nil { +func (r *BaseReplicator) closeConnections(ctx context.Context) error { + r.Logger.Info().Msg("Closing database connections") + + if err := r.ReplicationConn.Close(ctx); err != nil { return fmt.Errorf("failed to close replication connection: %w", err) } - if err := r.StandardConn.Close(context.Background()); err != nil { + if err := r.StandardConn.Close(ctx); err != nil { return fmt.Errorf("failed to close standard connection: %w", err) } return nil diff --git a/pkg/replicator/copy_and_stream_replicator.go b/pkg/replicator/copy_and_stream_replicator.go index 85e3aa0..b0904f4 100644 --- a/pkg/replicator/copy_and_stream_replicator.go +++ b/pkg/replicator/copy_and_stream_replicator.go @@ -28,14 +28,11 @@ type CopyAndStreamReplicator struct { // StartReplication begins the replication process. func (r *CopyAndStreamReplicator) StartReplication() error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := context.Background() sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - go r.handleShutdownSignal(sigChan, cancel) - if err := r.BaseReplicator.CreatePublication(); err != nil { return fmt.Errorf("failed to create publication: %v", err) } @@ -53,7 +50,6 @@ func (r *CopyAndStreamReplicator) StartReplication() error { ddlCtx, ddlCancel = context.WithCancel(ctx) go r.DDLReplicator.StartDDLReplication(ddlCtx) } - defer func() { if r.Config.TrackDDL { ddlCancel() @@ -63,20 +59,43 @@ func (r *CopyAndStreamReplicator) StartReplication() error { } }() - if copyErr := r.ParallelCopy(context.Background()); copyErr != nil { + if copyErr := r.ParallelCopy(ctx); copyErr != nil { return fmt.Errorf("failed to perform parallel copy: %v", copyErr) } startLSN := r.BaseReplicator.LastLSN r.Logger.Info().Str("startLSN", startLSN.String()).Msg("Starting replication from LSN") - return r.BaseReplicator.StartReplicationFromLSN(ctx, startLSN) -} -// handleShutdownSignal waits for a shutdown signal and cancels the context. -func (r *CopyAndStreamReplicator) handleShutdownSignal(sigChan <-chan os.Signal, cancel context.CancelFunc) { - sig := <-sigChan - r.Logger.Info().Str("signal", sig.String()).Msg("Received shutdown signal") - cancel() + // Create a stop channel for graceful shutdown + stopChan := make(chan struct{}) + errChan := make(chan error, 1) + go func() { + errChan <- r.BaseReplicator.StartReplicationFromLSN(ctx, startLSN, stopChan) + }() + + select { + case <-sigChan: + r.Logger.Info().Msg("Received shutdown signal") + // Signal replication loop to stop + close(stopChan) + // Wait for replication loop to exit + <-errChan + + // Proceed with graceful shutdown + shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 10*time.Second) + defer cancelShutdown() + if err := r.BaseReplicator.GracefulShutdown(shutdownCtx); err != nil { + r.Logger.Error().Err(err).Msg("Error during graceful shutdown") + return err + } + case err := <-errChan: + if err != nil { + r.Logger.Error().Err(err).Msg("Replication ended with error") + return err + } + } + + return nil } // ParallelCopy performs a parallel copy of all specified tables. diff --git a/pkg/replicator/stream_replicator.go b/pkg/replicator/stream_replicator.go index f94c3df..99e1cb0 100644 --- a/pkg/replicator/stream_replicator.go +++ b/pkg/replicator/stream_replicator.go @@ -6,6 +6,7 @@ import ( "os" "os/signal" "syscall" + "time" "github.com/jackc/pglogrepl" ) @@ -17,15 +18,54 @@ type StreamReplicator struct { // StartReplication begins the replication process. func (r *StreamReplicator) StartReplication() error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := context.Background() - // Set up signal handling for graceful shutdown sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - go r.handleShutdownSignal(sigChan, cancel) + if err := r.setup(ctx); err != nil { + return err + } + + startLSN, err := r.getStartLSN(ctx) + if err != nil { + return err + } + r.Logger.Info().Str("startLSN", startLSN.String()).Msg("Starting replication from LSN") + + stopChan := make(chan struct{}) + errChan := make(chan error, 1) + go func() { + errChan <- r.BaseReplicator.StartReplicationFromLSN(ctx, startLSN, stopChan) + }() + + select { + case <-sigChan: + r.Logger.Info().Msg("Received shutdown signal") + // Signal the replication loop to stop + close(stopChan) + // Wait for the replication loop to exit + <-errChan + + // Proceed with graceful shutdown + shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 10*time.Second) + defer cancelShutdown() + if err := r.BaseReplicator.GracefulShutdown(shutdownCtx); err != nil { + r.Logger.Error().Err(err).Msg("Error during graceful shutdown") + return err + } + case err := <-errChan: + if err != nil { + r.Logger.Error().Err(err).Msg("Replication ended with error") + return err + } + } + + return nil +} + +func (r *StreamReplicator) setup(ctx context.Context) error { if err := r.BaseReplicator.CreatePublication(); err != nil { return fmt.Errorf("failed to create publication: %v", err) } @@ -34,43 +74,18 @@ func (r *StreamReplicator) StartReplication() error { return fmt.Errorf("failed to create replication slot: %v", err) } - var ddlCancel context.CancelFunc if r.Config.TrackDDL { if err := r.DDLReplicator.SetupDDLTracking(ctx); err != nil { return fmt.Errorf("failed to set up DDL tracking: %v", err) } - var ddlCtx context.Context - ddlCtx, ddlCancel = context.WithCancel(ctx) - go r.DDLReplicator.StartDDLReplication(ddlCtx) + go r.DDLReplicator.StartDDLReplication(ctx) } - defer func() { - if r.Config.TrackDDL { - ddlCancel() - if err := r.DDLReplicator.Shutdown(context.Background()); err != nil { - r.Logger.Error().Err(err).Msg("Failed to shutdown DDL replicator") - } - } - }() - if err := r.BaseReplicator.CheckReplicationSlotStatus(ctx); err != nil { return fmt.Errorf("failed to check replication slot status: %v", err) } - startLSN, err := r.getStartLSN(ctx) - if err != nil { - return err - } - - r.Logger.Info().Str("startLSN", startLSN.String()).Msg("Starting replication from LSN") - return r.BaseReplicator.StartReplicationFromLSN(ctx, startLSN) -} - -// handleShutdownSignal waits for a shutdown signal and cancels the context. -func (r *StreamReplicator) handleShutdownSignal(sigChan <-chan os.Signal, cancel context.CancelFunc) { - sig := <-sigChan - r.Logger.Info().Str("signal", sig.String()).Msg("Received shutdown signal") - cancel() + return nil } // getStartLSN determines the starting LSN for replication. diff --git a/pkg/replicator/tests/base_replicator_test.go b/pkg/replicator/tests/base_replicator_test.go index 5ebd6cb..439d9a8 100644 --- a/pkg/replicator/tests/base_replicator_test.go +++ b/pkg/replicator/tests/base_replicator_test.go @@ -219,8 +219,6 @@ func TestBaseReplicator(t *testing.T) { mockReplicationConn.On("ReceiveMessage", mock.Anything).Return(xLogData, nil).Once() mockReplicationConn.On("ReceiveMessage", mock.Anything).Return(nil, context.Canceled).Maybe() - mockNATSClient.On("GetLastState").Return(pglogrepl.LSN(0), nil).Maybe() - br := &replicator.BaseReplicator{ ReplicationConn: mockReplicationConn, StandardConn: mockStandardConn, @@ -232,7 +230,8 @@ func TestBaseReplicator(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() - err := br.StartReplicationFromLSN(ctx, pglogrepl.LSN(0)) + stopChan := make(chan struct{}) + err := br.StartReplicationFromLSN(ctx, pglogrepl.LSN(0), stopChan) assert.NoError(t, err, "Expected no error for graceful shutdown") mockReplicationConn.AssertExpectations(t) mockStandardConn.AssertExpectations(t) @@ -252,7 +251,8 @@ func TestBaseReplicator(t *testing.T) { Logger: zerolog.Nop(), } - err := br.StartReplicationFromLSN(context.Background(), pglogrepl.LSN(0)) + stopChan := make(chan struct{}) + err := br.StartReplicationFromLSN(context.Background(), pglogrepl.LSN(0), stopChan) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to start replication") mockReplicationConn.AssertExpectations(t) @@ -260,123 +260,6 @@ func TestBaseReplicator(t *testing.T) { }) }) - t.Run("StreamChanges", func(t *testing.T) { - t.Run("Successful processing of messages", func(t *testing.T) { - mockReplicationConn := new(MockReplicationConnection) - keepaliveMsg := &pgproto3.CopyData{ - Data: []byte{ - pglogrepl.PrimaryKeepaliveMessageByteID, - 0, 0, 0, 0, 0, 0, 0, 8, // WAL end: 8 - 0, 0, 0, 0, 0, 0, 0, 0, // ServerTime: 0 - 0, // ReplyRequested: false - }, - } - mockReplicationConn.On("ReceiveMessage", mock.Anything).Return(keepaliveMsg, nil).Once() - mockReplicationConn.On("ReceiveMessage", mock.Anything).Return(nil, context.Canceled).Once() - - mockNATSClient := new(MockNATSClient) - mockNATSClient.On("GetLastState").Return(pglogrepl.LSN(0), nil).Maybe() - - br := &replicator.BaseReplicator{ - ReplicationConn: mockReplicationConn, - NATSClient: mockNATSClient, - Logger: zerolog.Nop(), - } - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - err := br.StreamChanges(ctx) - assert.NoError(t, err, "Expected no error for graceful shutdown") - mockReplicationConn.AssertExpectations(t) - mockNATSClient.AssertExpectations(t) - }) - - t.Run("Context cancellation", func(t *testing.T) { - mockReplicationConn := new(MockReplicationConnection) - mockNATSClient := new(MockNATSClient) - mockNATSClient.On("GetLastState").Return(pglogrepl.LSN(0), nil).Maybe() - mockNATSClient.On("SaveState", mock.Anything).Return(nil).Maybe() - - br := &replicator.BaseReplicator{ - ReplicationConn: mockReplicationConn, - NATSClient: mockNATSClient, - } - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - err := br.StreamChanges(ctx) - assert.NoError(t, err, "Expected no error for graceful shutdown") - mockReplicationConn.AssertExpectations(t) - mockNATSClient.AssertExpectations(t) - }) - }) - - t.Run("ProcessNextMessage", func(t *testing.T) { - t.Run("Successful processing of CopyData message", func(t *testing.T) { - mockReplicationConn := new(MockReplicationConnection) - xLogData := &pgproto3.CopyData{ - Data: []byte{ - pglogrepl.XLogDataByteID, - 0, 0, 0, 0, 0, 0, 0, 1, // WAL start: 1 - 0, 0, 0, 0, 0, 0, 0, 2, // WAL end: 2 - 0, 0, 0, 0, 0, 0, 0, 0, // ServerTime: 0 - 'B', // Type: BEGIN - 0, 0, 0, 0, 0, 0, 0, 1, // LSN: 1 - 0, 0, 0, 0, 0, 0, 0, 2, // End LSN: 2 - 0, 0, 0, 0, 0, 0, 0, 0, // Timestamp - 0, 0, 0, 1, // XID: 1 - }, - } - mockReplicationConn.On("ReceiveMessage", mock.Anything).Return(xLogData, nil) - mockNATSClient := new(MockNATSClient) - - br := &replicator.BaseReplicator{ - ReplicationConn: mockReplicationConn, - NATSClient: mockNATSClient, - Logger: zerolog.Nop(), - } - - lastStatusUpdate := time.Now() - err := br.ProcessNextMessage(context.Background(), &lastStatusUpdate, time.Second) - assert.NoError(t, err) - mockReplicationConn.AssertExpectations(t) - mockNATSClient.AssertExpectations(t) - - assert.True(t, lastStatusUpdate.After(time.Now().Add(-time.Second)), "lastStatusUpdate should have been updated") - }) - - t.Run("Successful processing of other message types", func(t *testing.T) { - mockReplicationConn := new(MockReplicationConnection) - mockReplicationConn.On("ReceiveMessage", mock.Anything).Return(&pgproto3.NoticeResponse{}, nil) - - br := &replicator.BaseReplicator{ - ReplicationConn: mockReplicationConn, - } - - lastStatusUpdate := time.Now() - err := br.ProcessNextMessage(context.Background(), &lastStatusUpdate, time.Second) - assert.NoError(t, err) - mockReplicationConn.AssertExpectations(t) - }) - - t.Run("Error occurs while receiving message", func(t *testing.T) { - mockReplicationConn := new(MockReplicationConnection) - mockReplicationConn.On("ReceiveMessage", mock.Anything).Return(nil, errors.New("receive error")) - - br := &replicator.BaseReplicator{ - ReplicationConn: mockReplicationConn, - } - - lastStatusUpdate := time.Now() - err := br.ProcessNextMessage(context.Background(), &lastStatusUpdate, time.Second) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to receive message") - mockReplicationConn.AssertExpectations(t) - }) - }) - t.Run("HandleInsertMessage", func(t *testing.T) { t.Run("Successful handling of InsertMessage", func(t *testing.T) { mockNATSClient := new(MockNATSClient) @@ -1044,37 +927,6 @@ func TestBaseReplicator(t *testing.T) { }) }) - t.Run("SendFinalStandbyStatusUpdate", func(t *testing.T) { - t.Run("Successful sending of final standby status update", func(t *testing.T) { - mockReplicationConn := new(MockReplicationConnection) - mockReplicationConn.On("SendStandbyStatusUpdate", mock.Anything, mock.Anything).Return(nil) - - br := &replicator.BaseReplicator{ - ReplicationConn: mockReplicationConn, - } - - err := br.SendFinalStandbyStatusUpdate() - assert.NoError(t, err) - - mockReplicationConn.AssertExpectations(t) - }) - - t.Run("Error sending final standby status update", func(t *testing.T) { - mockReplicationConn := new(MockReplicationConnection) - mockReplicationConn.On("SendStandbyStatusUpdate", mock.Anything, mock.Anything).Return(errors.New("send error")) - - br := &replicator.BaseReplicator{ - ReplicationConn: mockReplicationConn, - } - - err := br.SendFinalStandbyStatusUpdate() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to send standby") - - mockReplicationConn.AssertExpectations(t) - }) - }) - t.Run("CreateReplicationSlot", func(t *testing.T) { t.Run("Slot already exists", func(t *testing.T) { mockStandardConn := new(MockStandardConnection) diff --git a/pkg/replicator/tests/copy_and_stream_replicator_test.go b/pkg/replicator/tests/copy_and_stream_replicator_test.go index 85550d6..27a8e7e 100644 --- a/pkg/replicator/tests/copy_and_stream_replicator_test.go +++ b/pkg/replicator/tests/copy_and_stream_replicator_test.go @@ -9,11 +9,9 @@ import ( "time" "github.com/goccy/go-json" - "github.com/jackc/pglogrepl" "github.com/jackc/pgtype" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgproto3" "github.com/rs/zerolog" "github.com/shayonj/pg_flo/pkg/replicator" "github.com/shayonj/pg_flo/pkg/utils" @@ -23,75 +21,6 @@ import ( func TestCopyAndStreamReplicator(t *testing.T) { - t.Run("StartReplication", func(t *testing.T) { - mockReplicationConn := new(MockReplicationConnection) - mockStandardConn := new(MockStandardConnection) - mockNATSClient := new(MockNATSClient) - mockTx := new(MockTx) - - mockStandardConn.On("QueryRow", mock.Anything, "SELECT EXISTS (SELECT 1 FROM pg_publication WHERE pubname = $1)", mock.Anything).Return(MockRow{ - scanFunc: func(dest ...interface{}) error { - *dest[0].(*bool) = true - return nil - }, - }).Once() - - mockStandardConn.On("QueryRow", mock.Anything, "SELECT EXISTS (SELECT 1 FROM pg_replication_slots WHERE slot_name = $1)", mock.Anything).Return(MockRow{ - scanFunc: func(dest ...interface{}) error { - *dest[0].(*bool) = false - return nil - }, - }).Once() - - mockReplicationConn.On("CreateReplicationSlot", mock.Anything, mock.Anything).Return(pglogrepl.CreateReplicationSlotResult{}, nil) - mockReplicationConn.On("StartReplication", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - mockReplicationConn.On("ReceiveMessage", mock.Anything).Return( - &pgproto3.ReadyForQuery{TxStatus: 'I'}, - context.Canceled, - ) - - mockStandardConn.On("BeginTx", mock.Anything, mock.AnythingOfType("pgx.TxOptions")).Return(mockTx, nil) - - mockStandardConn.On("QueryRow", mock.Anything, mock.MatchedBy(func(query string) bool { - return strings.Contains(query, "SELECT relpages") - }), mock.Anything).Return(MockRow{ - scanFunc: func(dest ...interface{}) error { - *dest[0].(*uint32) = 1 - return nil - }, - }) - - mockTx.On("QueryRow", mock.Anything, mock.MatchedBy(func(query string) bool { - return strings.Contains(query, "pg_export_snapshot()") && strings.Contains(query, "pg_current_wal_lsn()") - })).Return(MockRow{ - scanFunc: func(dest ...interface{}) error { - *dest[0].(*string) = "mock-snapshot-id" - *dest[1].(*pglogrepl.LSN) = pglogrepl.LSN(100) - return nil - }, - }) - - mockTx.On("Commit", mock.Anything).Return(nil) - - csr := &replicator.CopyAndStreamReplicator{ - BaseReplicator: replicator.BaseReplicator{ - ReplicationConn: mockReplicationConn, - StandardConn: mockStandardConn, - NATSClient: mockNATSClient, - Logger: zerolog.Nop(), - Config: replicator.Config{Group: "test_publication", Tables: []string{"users"}, Schema: "public"}, - }, - } - - err := csr.StartReplication() - assert.NoError(t, err) - - mockReplicationConn.AssertExpectations(t) - mockStandardConn.AssertExpectations(t) - mockNATSClient.AssertExpectations(t) - mockTx.AssertExpectations(t) - }) - t.Run("CopyTable", func(t *testing.T) { mockStandardConn := new(MockStandardConnection) mockNATSClient := new(MockNATSClient) diff --git a/pkg/worker/worker.go b/pkg/worker/worker.go index 7d93416..6cd1190 100644 --- a/pkg/worker/worker.go +++ b/pkg/worker/worker.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "sync" - "sync/atomic" "time" "github.com/nats-io/nats.go/jetstream" @@ -28,7 +27,6 @@ type Worker struct { flushInterval time.Duration shutdownCh chan struct{} wg sync.WaitGroup - isShuttingDown atomic.Bool } // NewWorker creates and returns a new Worker instance with the provided NATS client, rule engine, sink, and group. @@ -116,20 +114,9 @@ func (w *Worker) Start(ctx context.Context) error { <-ctx.Done() w.logger.Info().Msg("Received shutdown signal. Initiating graceful shutdown...") - w.isShuttingDown.Store(true) - done := make(chan struct{}) - go func() { - w.wg.Wait() - close(done) - }() - - select { - case <-done: - w.logger.Debug().Msg("All goroutines finished") - case <-time.After(30 * time.Second): - w.logger.Warn().Msg("Shutdown timed out, forcing exit") - } + w.wg.Wait() + w.logger.Debug().Msg("All goroutines finished") return w.flushBuffer() } @@ -140,29 +127,30 @@ func (w *Worker) processMessages(ctx context.Context, cons jetstream.Consumer) e if err != nil { return fmt.Errorf("failed to get message iterator: %w", err) } - defer iter.Stop() flushTicker := time.NewTicker(w.flushInterval) defer flushTicker.Stop() + go func() { + <-ctx.Done() + w.logger.Debug().Msg("Context canceled, stopping iterator") + iter.Stop() + }() + for { select { case <-ctx.Done(): - w.logger.Info().Msg("Context canceled, flushing remaining messages") + w.logger.Info().Msg("Flushing remaining messages") return w.flushBuffer() case <-flushTicker.C: if err := w.flushBuffer(); err != nil { w.logger.Error().Err(err).Msg("Failed to flush buffer on interval") } default: - if w.isShuttingDown.Load() { - w.logger.Info().Msg("Shutdown in progress, stopping message processing") - return w.flushBuffer() - } - msg, err := iter.Next() if err != nil { - if err == context.Canceled { + if err == jetstream.ErrMsgIteratorClosed { + w.logger.Info().Msg("Iterator closed, flushing buffer") return w.flushBuffer() } w.logger.Error().Err(err).Msg("Failed to get next message")