diff --git a/internal/e2e_copy_and_stream.sh b/internal/e2e_copy_and_stream.sh index 83edf2e..aaf92ef 100755 --- a/internal/e2e_copy_and_stream.sh +++ b/internal/e2e_copy_and_stream.sh @@ -37,6 +37,7 @@ populate_initial_data() { decode(lpad(to_hex(generate_series(1, 4)), 8, '0'), 'hex') ;" run_sql "UPDATE public.users SET text_col = text_col || ' - Updated';" + run_sql "ANALYZE public.users;" success "Initial data populated" } @@ -121,7 +122,7 @@ test_pg_flo_cdc() { simulate_concurrent_changes log "Waiting for pg_flo to process changes..." - sleep 30 + sleep 2 stop_pg_flo_gracefully diff --git a/internal/e2e_resume.sh b/internal/e2e_resume.sh index 033371d..e4cda6f 100755 --- a/internal/e2e_resume.sh +++ b/internal/e2e_resume.sh @@ -17,7 +17,7 @@ create_users() { start_pg_flo_replication() { log "Starting pg_flo replication..." - $pg_flo_BIN stream file \ + $pg_flo_BIN replicator \ --host "$PG_HOST" \ --port "$PG_PORT" \ --dbname "$PG_DB" \ @@ -26,13 +26,25 @@ start_pg_flo_replication() { --group "group_resume" \ --tables "users" \ --schema "public" \ - --status-dir "/tmp" \ - --output-dir "$OUTPUT_DIR" >"$pg_flo_LOG" 2>&1 & + --nats-url "$NATS_URL" \ + >"$pg_flo_LOG" 2>&1 & pg_flo_PID=$! log "pg_flo started with PID: $pg_flo_PID" success "pg_flo replication started" } +start_pg_flo_worker() { + log "Starting pg_flo worker with file sink..." + $pg_flo_BIN worker file \ + --group "group_resume" \ + --nats-url "$NATS_URL" \ + --file-output-dir "$OUTPUT_DIR" \ + >"$pg_flo_WORKER_LOG" 2>&1 & + pg_flo_WORKER_PID=$! + log "pg_flo worker started with PID: $pg_flo_WORKER_PID" + success "pg_flo worker started" +} + simulate_concurrent_inserts() { log "Starting concurrent inserts..." for i in $(seq 1 $TOTAL_INSERTS); do @@ -43,20 +55,14 @@ simulate_concurrent_inserts() { } interrupt_pg_flo() { - log "Interrupting pg_flo process..." - if kill -0 $pg_flo_PID 2>/dev/null; then - kill -15 $pg_flo_PID - wait $pg_flo_PID 2>/dev/null || true - success "pg_flo process interrupted" - else - log "pg_flo process not found, it may have already stopped" - fi + log "Interrupting pg_flo processes..." + stop_pg_flo_gracefully } verify_results() { log "Verifying results..." local db_count=$(run_sql "SELECT COUNT(*) FROM public.users") - local json_count=$(jq -s '[.[] | select(.type == "INSERT")] | length' "$OUTPUT_DIR"/*.jsonl) + local json_count=$(jq -s '[.[] | select(.Type == "INSERT")] | length' "$OUTPUT_DIR"/*.jsonl) log "Database row count: $db_count" log "JSON INSERT count: $json_count" @@ -74,6 +80,7 @@ test_pg_flo_resume() { setup_postgres create_users start_pg_flo_replication + start_pg_flo_worker rm -f $INSERT_COMPLETE_FLAG @@ -90,6 +97,7 @@ test_pg_flo_resume() { sleep $RESUME_WAIT_TIME start_pg_flo_replication + start_pg_flo_worker log "Waiting for inserts to complete..." while [ ! -f $INSERT_COMPLETE_FLAG ]; do diff --git a/internal/e2e_test_local.sh b/internal/e2e_test_local.sh index 9bcf992..faa2589 100755 --- a/internal/e2e_test_local.sh +++ b/internal/e2e_test_local.sh @@ -30,7 +30,7 @@ trap cleanup EXIT make build -setup_docker +# setup_docker log "Running e2e copy & stream tests..." if CI=false ./internal/e2e_copy_and_stream.sh; then @@ -40,37 +40,37 @@ else exit 1 fi -# setup_docker +setup_docker -# log "Running new e2e stream tests with changes..." -# if ./internal/e2e_test_stream.sh; then -# success "New e2e tests with changes completed successfully" -# else -# error "New e2e tests with changes failed" -# exit 1 -# fi +log "Running new e2e stream tests with changes..." +if ./internal/e2e_test_stream.sh; then + success "New e2e tests with changes completed successfully" +else + error "New e2e tests with changes failed" + exit 1 +fi -# setup_docker +setup_docker -# # Run new e2e resume test -# log "Running new e2e resume test..." -# if ./internal/e2e_resume.sh; then -# success "E2E resume test completed successfully" -# else -# error "E2E resume test failed" -# exit 1 -# fi +# Run new e2e resume test +log "Running new e2e resume test..." +if ./internal/e2e_resume.sh; then + success "E2E resume test completed successfully" +else + error "E2E resume test failed" + exit 1 +fi -# setup_docker +setup_docker -# # Run new e2e test for transform & filter -# log "Running new e2e test for transform & filter..." -# if ./internal/e2e_transform_filter.sh; then -# success "E2E test for transform & filter test completed successfully" -# else -# error "E2E test for transform & filter test failed" -# exit 1 -# fi +# Run new e2e test for transform & filter +log "Running new e2e test for transform & filter..." +if ./internal/e2e_transform_filter.sh; then + success "E2E test for transform & filter test completed successfully" +else + error "E2E test for transform & filter test failed" + exit 1 +fi # setup_docker diff --git a/internal/how-it-works.md b/internal/how-it-works.md index 59bb179..3087402 100644 --- a/internal/how-it-works.md +++ b/internal/how-it-works.md @@ -16,7 +16,6 @@ - If no valid LSN (Log Sequence Number) is found in the target sink, `pg_flo` performs an initial bulk copy of existing data. - This process is parallelized for fast data sync: - - Tables are analyzed to optimize the copy process. - A snapshot is taken to ensure consistency. - Each table is divided into page ranges. - Multiple workers copy different ranges concurrently. diff --git a/pkg/pgflonats/pgflonats.go b/pkg/pgflonats/pgflonats.go index 76d84ad..524d6d5 100644 --- a/pkg/pgflonats/pgflonats.go +++ b/pkg/pgflonats/pgflonats.go @@ -25,7 +25,13 @@ type NATSClient struct { stateBucket string } -// NewNATSClient creates a new NATS client with the specified configuration +// State represents the current state of the replication process +type State struct { + LSN pglogrepl.LSN `json:"lsn"` + LastProcessedSeq uint64 `json:"last_processed_seq"` +} + +// NewNATSClient creates a new NATS client with the specified configuration, setting up the connection, main stream, and state bucket. func NewNATSClient(url, stream, group string) (*NATSClient, error) { if url == "" { url = os.Getenv(envNATSURL) @@ -79,7 +85,7 @@ func NewNATSClient(url, stream, group string) (*NATSClient, error) { _, err = js.CreateKeyValue(context.Background(), jetstream.KeyValueConfig{ Bucket: stateBucket, }) - if err != nil && err != jetstream.ErrConsumerNameAlreadyInUse { + if err != nil && err != jetstream.ErrBucketExists { return nil, fmt.Errorf("failed to create state bucket: %w", err) } @@ -91,7 +97,7 @@ func NewNATSClient(url, stream, group string) (*NATSClient, error) { }, nil } -// PublishMessage publishes a message to the specified NATS subject +// PublishMessage publishes a message to the specified NATS subject. func (nc *NATSClient) PublishMessage(ctx context.Context, subject string, data []byte) error { _, err := nc.js.Publish(ctx, subject, data) if err != nil { @@ -100,13 +106,13 @@ func (nc *NATSClient) PublishMessage(ctx context.Context, subject string, data [ return nil } -// Close closes the NATS connection +// Close closes the NATS connection. func (nc *NATSClient) Close() error { nc.conn.Close() return nil } -// GetStreamInfo retrieves information about the NATS stream +// GetStreamInfo retrieves information about the NATS stream. func (nc *NATSClient) GetStreamInfo(ctx context.Context) (*jetstream.StreamInfo, error) { stream, err := nc.js.Stream(ctx, nc.stream) if err != nil { @@ -115,7 +121,7 @@ func (nc *NATSClient) GetStreamInfo(ctx context.Context) (*jetstream.StreamInfo, return stream.Info(ctx) } -// PurgeStream purges all messages from the NATS stream +// PurgeStream purges all messages from the NATS stream. func (nc *NATSClient) PurgeStream(ctx context.Context) error { stream, err := nc.js.Stream(ctx, nc.stream) if err != nil { @@ -124,24 +130,24 @@ func (nc *NATSClient) PurgeStream(ctx context.Context) error { return stream.Purge(ctx) } -// DeleteStream deletes the NATS stream +// DeleteStream deletes the NATS stream. func (nc *NATSClient) DeleteStream(ctx context.Context) error { return nc.js.DeleteStream(ctx, nc.stream) } -// SaveState saves the current replication state to NATS -func (nc *NATSClient) SaveState(ctx context.Context, lsn pglogrepl.LSN) error { +// SaveState saves the current replication state to NATS. +func (nc *NATSClient) SaveState(ctx context.Context, state State) error { kv, err := nc.js.KeyValue(ctx, nc.stateBucket) if err != nil { return fmt.Errorf("failed to get KV bucket: %v", err) } - data, err := json.Marshal(lsn) + data, err := json.Marshal(state) if err != nil { return fmt.Errorf("failed to marshal state: %v", err) } - _, err = kv.Put(ctx, "lsn", data) + _, err = kv.Put(ctx, "state", data) if err != nil { return fmt.Errorf("failed to save state: %v", err) } @@ -149,30 +155,34 @@ func (nc *NATSClient) SaveState(ctx context.Context, lsn pglogrepl.LSN) error { return nil } -// GetLastState retrieves the last saved replication state from NATS -func (nc *NATSClient) GetLastState(ctx context.Context) (pglogrepl.LSN, error) { +// GetState retrieves the last saved state from NATS, initializing a new state if none is found. +func (nc *NATSClient) GetState(ctx context.Context) (State, error) { kv, err := nc.js.KeyValue(ctx, nc.stateBucket) if err != nil { - return 0, fmt.Errorf("failed to get KV bucket: %v", err) + return State{}, fmt.Errorf("failed to get KV bucket: %v", err) } - entry, err := kv.Get(ctx, "lsn") + entry, err := kv.Get(ctx, "state") if err != nil { if err == jetstream.ErrKeyNotFound { - return 0, nil // No state yet, start from the beginning + initialState := State{LastProcessedSeq: 0} + if err := nc.SaveState(ctx, initialState); err != nil { + return State{}, fmt.Errorf("failed to save initial state: %v", err) + } + return initialState, nil } - return 0, fmt.Errorf("failed to get last state: %v", err) + return State{}, fmt.Errorf("failed to get state: %v", err) } - var lsn pglogrepl.LSN - if err := json.Unmarshal(entry.Value(), &lsn); err != nil { - return 0, fmt.Errorf("failed to unmarshal state: %v", err) + var state State + if err := json.Unmarshal(entry.Value(), &state); err != nil { + return State{}, fmt.Errorf("failed to unmarshal state: %v", err) } - return lsn, nil + return state, nil } -// JetStream returns the JetStream context +// JetStream returns the JetStream context. func (nc *NATSClient) JetStream() jetstream.JetStream { return nc.js } diff --git a/pkg/replicator/base_replicator.go b/pkg/replicator/base_replicator.go index 153f405..3619b2e 100644 --- a/pkg/replicator/base_replicator.go +++ b/pkg/replicator/base_replicator.go @@ -124,11 +124,8 @@ func (r *BaseReplicator) checkPublicationExists(publicationName string) (bool, e // StartReplicationFromLSN initiates the replication process from a given LSN func (r *BaseReplicator) StartReplicationFromLSN(ctx context.Context, startLSN pglogrepl.LSN) error { - if err := r.CreatePublication(); err != nil { - return err - } - publicationName := GeneratePublicationName(r.Config.Group) + r.Logger.Info().Str("startLSN", startLSN.String()).Str("publication", publicationName).Msg("Starting replication") err := r.ReplicationConn.StartReplication(ctx, publicationName, startLSN, pglogrepl.StartReplicationOptions{ PluginArgs: []string{ "proto_version '1'", @@ -308,11 +305,12 @@ func (r *BaseReplicator) HandleInsertMessage(ctx context.Context, msg *pglogrepl } cdcMessage := utils.CDCMessage{ - Type: "INSERT", - Schema: relation.Namespace, - Table: relation.RelationName, - Columns: relation.Columns, - NewTuple: msg.Tuple, + Type: "INSERT", + Schema: relation.Namespace, + Table: relation.RelationName, + Columns: relation.Columns, + EmittedAt: time.Now(), + NewTuple: msg.Tuple, } r.AddPrimaryKeyInfo(&cdcMessage, relation.RelationName) @@ -348,11 +346,12 @@ func (r *BaseReplicator) HandleDeleteMessage(ctx context.Context, msg *pglogrepl // todo: write lsn cdcMessage := utils.CDCMessage{ - Type: "DELETE", - Schema: relation.Namespace, - Table: relation.RelationName, - Columns: relation.Columns, - OldTuple: msg.OldTuple, + Type: "DELETE", + Schema: relation.Namespace, + Table: relation.RelationName, + Columns: relation.Columns, + OldTuple: msg.OldTuple, + EmittedAt: time.Now(), } r.AddPrimaryKeyInfo(&cdcMessage, relation.RelationName) @@ -528,10 +527,33 @@ func (r *BaseReplicator) getPrimaryKeyColumn(schema, table string) (string, erro // SaveState saves the current replication state func (r *BaseReplicator) SaveState(ctx context.Context, lsn pglogrepl.LSN) error { - return r.NATSClient.SaveState(ctx, lsn) + state, err := r.NATSClient.GetState(ctx) + if err != nil { + return fmt.Errorf("failed to get current state: %w", err) + } + state.LSN = lsn + return r.NATSClient.SaveState(ctx, state) } // GetLastState retrieves the last saved replication state func (r *BaseReplicator) GetLastState(ctx context.Context) (pglogrepl.LSN, error) { - return r.NATSClient.GetLastState(ctx) + state, err := r.NATSClient.GetState(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get state: %w", err) + } + return state.LSN, nil +} + +// CheckReplicationSlotStatus checks the status of the replication slot +func (r *BaseReplicator) CheckReplicationSlotStatus(ctx context.Context) error { + publicationName := GeneratePublicationName(r.Config.Group) + var restartLSN string + err := r.StandardConn.QueryRow(ctx, + "SELECT restart_lsn FROM pg_replication_slots WHERE slot_name = $1", + publicationName).Scan(&restartLSN) + if err != nil { + return fmt.Errorf("failed to query replication slot status: %w", err) + } + r.Logger.Info().Str("slotName", publicationName).Str("restartLSN", restartLSN).Msg("Replication slot status") + return nil } diff --git a/pkg/replicator/copy_and_stream_replicator.go b/pkg/replicator/copy_and_stream_replicator.go index 3f9621f..b2075df 100644 --- a/pkg/replicator/copy_and_stream_replicator.go +++ b/pkg/replicator/copy_and_stream_replicator.go @@ -7,6 +7,7 @@ import ( "os/signal" "sync" "syscall" + "time" "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5" @@ -35,6 +36,10 @@ func (r *CopyAndStreamReplicator) StartReplication() error { go r.handleShutdownSignal(sigChan, cancel) + if err := r.BaseReplicator.CreatePublication(); err != nil { + return fmt.Errorf("failed to create publication: %v", err) + } + if err := r.BaseReplicator.CreateReplicationSlot(ctx); err != nil { return fmt.Errorf("failed to create replication slot: %v", err) } @@ -58,10 +63,10 @@ func (r *CopyAndStreamReplicator) StartReplication() error { } }() - startLSN, err := r.getStartLSN() - if err != nil { - return err + if copyErr := r.ParallelCopy(context.Background()); copyErr != nil { + return fmt.Errorf("failed to perform parallel copy: %v", copyErr) } + startLSN := r.BaseReplicator.LastLSN r.Logger.Info().Str("startLSN", startLSN.String()).Msg("Starting replication from LSN") return r.BaseReplicator.StartReplicationFromLSN(ctx, startLSN) @@ -74,24 +79,8 @@ func (r *CopyAndStreamReplicator) handleShutdownSignal(sigChan <-chan os.Signal, cancel() } -// getStartLSN determines the starting LSN for replication. -func (r *CopyAndStreamReplicator) getStartLSN() (pglogrepl.LSN, error) { - r.Logger.Info().Msg("No valid LSN found in sink, starting initial copy") - if copyErr := r.ParallelCopy(context.Background()); copyErr != nil { - return 0, fmt.Errorf("failed to perform parallel copy: %v", copyErr) - } - return r.BaseReplicator.LastLSN, nil -} - // ParallelCopy performs a parallel copy of all specified tables. func (r *CopyAndStreamReplicator) ParallelCopy(ctx context.Context) error { - - for _, table := range r.Config.Tables { - if err := r.analyzeTable(ctx, table); err != nil { - return fmt.Errorf("failed to analyze table %s: %v", table, err) - } - } - tx, err := r.startSnapshotTransaction(ctx) if err != nil { return err @@ -113,13 +102,6 @@ func (r *CopyAndStreamReplicator) ParallelCopy(ctx context.Context) error { return tx.Commit(context.Background()) } -// analyzeTable runs ANALYZE on the specified table. -func (r *CopyAndStreamReplicator) analyzeTable(ctx context.Context, tableName string) error { - r.Logger.Info().Str("table", tableName).Msg("Running ANALYZE on table") - _, err := r.BaseReplicator.StandardConn.Exec(ctx, fmt.Sprintf("ANALYZE %s", pgx.Identifier{tableName}.Sanitize())) - return err -} - // startSnapshotTransaction starts a new transaction with serializable isolation level. func (r *CopyAndStreamReplicator) startSnapshotTransaction(ctx context.Context) (pgx.Tx, error) { return r.BaseReplicator.StandardConn.BeginTx(ctx, pgx.TxOptions{ @@ -334,7 +316,7 @@ func (r *CopyAndStreamReplicator) executeCopyQuery(ctx context.Context, tx pgx.T Columns: make([]*pglogrepl.TupleDataColumn, len(values)), } for i, value := range values { - data, err := utils.ConvertToPgOutput(value, fieldDescriptions[i].DataTypeOID) + data, err := utils.ConvertToPgCompatibleOutput(value, fieldDescriptions[i].DataTypeOID) if err != nil { return 0, fmt.Errorf("error converting value: %v", err) } @@ -346,11 +328,12 @@ func (r *CopyAndStreamReplicator) executeCopyQuery(ctx context.Context, tx pgx.T } cdcMessage := utils.CDCMessage{ - Type: "INSERT", - Schema: schema, - Table: tableName, - Columns: columns, - NewTuple: tupleData, + Type: "INSERT", + Schema: schema, + Table: tableName, + Columns: columns, + NewTuple: tupleData, + EmittedAt: time.Now(), } r.BaseReplicator.AddPrimaryKeyInfo(&cdcMessage, tableName) diff --git a/pkg/replicator/ddl_replicator.go b/pkg/replicator/ddl_replicator.go index 051adfa..98563d8 100644 --- a/pkg/replicator/ddl_replicator.go +++ b/pkg/replicator/ddl_replicator.go @@ -195,15 +195,16 @@ func (d *DDLReplicator) ProcessDDLEvents(ctx context.Context) error { } cdcMessage := utils.CDCMessage{ - Type: "DDL", - Schema: schema, - Table: table, - CommitTimestamp: createdAt, + Type: "DDL", + Schema: schema, + Table: table, + EmittedAt: time.Now(), Columns: []*pglogrepl.RelationMessageColumn{ {Name: "event_type", DataType: pgtype.TextOID}, {Name: "object_type", DataType: pgtype.TextOID}, {Name: "object_identity", DataType: pgtype.TextOID}, {Name: "ddl_command", DataType: pgtype.TextOID}, + {Name: "created_at", DataType: pgtype.TimestamptzOID}, }, NewTuple: &pglogrepl.TupleData{ Columns: []*pglogrepl.TupleDataColumn{ @@ -211,6 +212,7 @@ func (d *DDLReplicator) ProcessDDLEvents(ctx context.Context) error { {Data: []byte(objectType)}, {Data: []byte(objectIdentity)}, {Data: []byte(ddlCommand)}, + {Data: []byte(createdAt.Format(time.RFC3339))}, }, }, } diff --git a/pkg/replicator/interfaces.go b/pkg/replicator/interfaces.go index 21e41f1..2e3024b 100644 --- a/pkg/replicator/interfaces.go +++ b/pkg/replicator/interfaces.go @@ -8,6 +8,7 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" "github.com/nats-io/nats.go/jetstream" + "github.com/shayonj/pg_flo/pkg/pgflonats" ) type Replicator interface { @@ -44,10 +45,7 @@ type PgxPoolConn interface { type NATSClient interface { PublishMessage(ctx context.Context, subject string, data []byte) error Close() error - GetStreamInfo(ctx context.Context) (*jetstream.StreamInfo, error) - PurgeStream(ctx context.Context) error - DeleteStream(ctx context.Context) error - SaveState(ctx context.Context, lsn pglogrepl.LSN) error - GetLastState(ctx context.Context) (pglogrepl.LSN, error) + SaveState(ctx context.Context, state pgflonats.State) error + GetState(ctx context.Context) (pgflonats.State, error) JetStream() jetstream.JetStream } diff --git a/pkg/replicator/stream_replicator.go b/pkg/replicator/stream_replicator.go index d04089c..f94c3df 100644 --- a/pkg/replicator/stream_replicator.go +++ b/pkg/replicator/stream_replicator.go @@ -26,6 +26,10 @@ func (r *StreamReplicator) StartReplication() error { go r.handleShutdownSignal(sigChan, cancel) + if err := r.BaseReplicator.CreatePublication(); err != nil { + return fmt.Errorf("failed to create publication: %v", err) + } + if err := r.BaseReplicator.CreateReplicationSlot(ctx); err != nil { return fmt.Errorf("failed to create replication slot: %v", err) } @@ -49,6 +53,10 @@ func (r *StreamReplicator) StartReplication() error { } }() + if err := r.BaseReplicator.CheckReplicationSlotStatus(ctx); err != nil { + return fmt.Errorf("failed to check replication slot status: %v", err) + } + startLSN, err := r.getStartLSN(ctx) if err != nil { return err diff --git a/pkg/replicator/tests/base_replicator_test.go b/pkg/replicator/tests/base_replicator_test.go index 7f99fd9..5ebd6cb 100644 --- a/pkg/replicator/tests/base_replicator_test.go +++ b/pkg/replicator/tests/base_replicator_test.go @@ -15,6 +15,7 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" "github.com/rs/zerolog" + "github.com/shayonj/pg_flo/pkg/pgflonats" "github.com/shayonj/pg_flo/pkg/replicator" "github.com/shayonj/pg_flo/pkg/utils" "github.com/stretchr/testify/assert" @@ -189,15 +190,6 @@ func TestBaseReplicator(t *testing.T) { mockStandardConn := new(MockStandardConnection) mockNATSClient := new(MockNATSClient) - // Mock CreatePublication - mockStandardConn.On("QueryRow", mock.Anything, "SELECT EXISTS (SELECT 1 FROM pg_publication WHERE pubname = $1)", mock.Anything). - Return(MockRow{ - scanFunc: func(dest ...interface{}) error { - *dest[0].(*bool) = true // Publication already exists - return nil - }, - }) - mockReplicationConn.On("StartReplication", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) keepaliveMsg := &pgproto3.CopyData{ @@ -251,14 +243,6 @@ func TestBaseReplicator(t *testing.T) { mockReplicationConn := new(MockReplicationConnection) mockStandardConn := new(MockStandardConnection) - mockStandardConn.On("QueryRow", mock.Anything, "SELECT EXISTS (SELECT 1 FROM pg_publication WHERE pubname = $1)", mock.Anything). - Return(MockRow{ - scanFunc: func(dest ...interface{}) error { - *dest[0].(*bool) = true - return nil - }, - }) - mockReplicationConn.On("StartReplication", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("start replication error")) br := &replicator.BaseReplicator{ @@ -871,7 +855,9 @@ func TestBaseReplicator(t *testing.T) { CommitLSN: 12345, } - mockNATSClient.On("SaveState", mock.Anything, pglogrepl.LSN(12345)).Return(nil) + mockNATSClient.On("GetState", mock.Anything).Return(pgflonats.State{}, nil) + + mockNATSClient.On("SaveState", mock.Anything, pgflonats.State{LSN: pglogrepl.LSN(12345)}).Return(nil) ctx := context.Background() err := br.HandleCommitMessage(ctx, msg) diff --git a/pkg/replicator/tests/copy_and_stream_replicator_test.go b/pkg/replicator/tests/copy_and_stream_replicator_test.go index 6fd26d9..85550d6 100644 --- a/pkg/replicator/tests/copy_and_stream_replicator_test.go +++ b/pkg/replicator/tests/copy_and_stream_replicator_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/goccy/go-json" "github.com/jackc/pglogrepl" "github.com/jackc/pgtype" "github.com/jackc/pgx/v5" @@ -55,7 +56,7 @@ func TestCopyAndStreamReplicator(t *testing.T) { return strings.Contains(query, "SELECT relpages") }), mock.Anything).Return(MockRow{ scanFunc: func(dest ...interface{}) error { - *dest[0].(*uint32) = 1 // Mock 1 page for simplicity + *dest[0].(*uint32) = 1 return nil }, }) @@ -70,14 +71,6 @@ func TestCopyAndStreamReplicator(t *testing.T) { }, }) - mockStandardConn.On("Exec", - mock.Anything, - mock.MatchedBy(func(query string) bool { - return strings.HasPrefix(query, "ANALYZE") - }), - mock.Anything, - ).Return(pgconn.CommandTag{}, nil) - mockTx.On("Commit", mock.Anything).Return(nil) csr := &replicator.CopyAndStreamReplicator{ @@ -307,10 +300,10 @@ func TestCopyAndStreamReplicator(t *testing.T) { time.Date(2023, time.May, 1, 12, 34, 56, 789000000, time.UTC), }, expected: []map[string]interface{}{ - {"name": "data", "type": "jsonb", "value": `{"key": "value"}`}, + {"name": "data", "type": "jsonb", "value": json.RawMessage(`{"key": "value"}`)}, {"name": "tags", "type": "text[]", "value": "{tag1,tag2,tag3}"}, {"name": "image", "type": "bytea", "value": `\x01020304`}, - {"name": "created_at", "type": "timestamptz", "value": time.Time(time.Date(2023, time.May, 1, 12, 34, 56, 789000000, time.UTC))}, + {"name": "created_at", "type": "timestamptz", "value": time.Date(2023, time.May, 1, 12, 34, 56, 789000000, time.UTC)}, }, }, { @@ -392,13 +385,15 @@ func TestCopyAndStreamReplicator(t *testing.T) { case "text", "varchar": assert.Equal(t, expectedVal, string(actualColumn.Data)) case "jsonb": - assert.JSONEq(t, expectedVal.(string), string(actualColumn.Data)) + assert.JSONEq(t, string(expectedVal.(json.RawMessage)), string(actualColumn.Data)) + case "text[]": + assert.Equal(t, expectedVal, string(actualColumn.Data)) case "bytea": assert.Equal(t, expectedVal, string(actualColumn.Data)) case "timestamptz": actualTime, err := time.Parse(time.RFC3339Nano, string(actualColumn.Data)) assert.NoError(t, err) - assert.Equal(t, expectedVal, actualTime) + assert.Equal(t, expectedVal.(time.Time), actualTime) case "numeric": assert.Equal(t, expectedVal, string(actualColumn.Data)) default: diff --git a/pkg/replicator/tests/mocks_test.go b/pkg/replicator/tests/mocks_test.go index da7d64b..ac6feb9 100644 --- a/pkg/replicator/tests/mocks_test.go +++ b/pkg/replicator/tests/mocks_test.go @@ -8,6 +8,7 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" "github.com/nats-io/nats.go/jetstream" + "github.com/shayonj/pg_flo/pkg/pgflonats" "github.com/shayonj/pg_flo/pkg/replicator" "github.com/stretchr/testify/mock" ) @@ -267,34 +268,16 @@ func (m *MockNATSClient) Close() error { return args.Error(0) } -// GetStreamInfo mocks the GetStreamInfo method -func (m *MockNATSClient) GetStreamInfo(ctx context.Context) (*jetstream.StreamInfo, error) { - args := m.Called(ctx) - return args.Get(0).(*jetstream.StreamInfo), args.Error(1) -} - -// PurgeStream mocks the PurgeStream method -func (m *MockNATSClient) PurgeStream(ctx context.Context) error { - args := m.Called(ctx) - return args.Error(0) -} - -// DeleteStream mocks the DeleteStream method -func (m *MockNATSClient) DeleteStream(ctx context.Context) error { - args := m.Called(ctx) - return args.Error(0) -} - // SaveState mocks the SaveState method -func (m *MockNATSClient) SaveState(ctx context.Context, lsn pglogrepl.LSN) error { - args := m.Called(ctx, lsn) +func (m *MockNATSClient) SaveState(ctx context.Context, state pgflonats.State) error { + args := m.Called(ctx, state) return args.Error(0) } -// GetLastState mocks the GetLastState method -func (m *MockNATSClient) GetLastState(ctx context.Context) (pglogrepl.LSN, error) { +// GetState mocks the GetState method +func (m *MockNATSClient) GetState(ctx context.Context) (pgflonats.State, error) { args := m.Called(ctx) - return args.Get(0).(pglogrepl.LSN), args.Error(1) + return args.Get(0).(pgflonats.State), args.Error(1) } // JetStream mocks the JetStream method diff --git a/pkg/utils/cdc_encoding.go b/pkg/utils/cdc_encoding.go index f0661cf..a85c4e2 100644 --- a/pkg/utils/cdc_encoding.go +++ b/pkg/utils/cdc_encoding.go @@ -8,24 +8,33 @@ import ( "strings" "time" - "github.com/jackc/pgtype" + "github.com/jackc/pgx/v5/pgtype" ) -// ConvertToPgOutput converts a Go value to its PostgreSQL output format -func ConvertToPgOutput(value interface{}, oid uint32) ([]byte, error) { +// ConvertToPgCompatibleOutput converts a Go value to its PostgreSQL output format. +func ConvertToPgCompatibleOutput(value interface{}, oid uint32) ([]byte, error) { if value == nil { - return nil, nil + return []byte("NULL"), nil } switch oid { case pgtype.BoolOID: - return []byte(strconv.FormatBool(value.(bool))), nil + return strconv.AppendBool(nil, value.(bool)), nil case pgtype.Int2OID, pgtype.Int4OID, pgtype.Int8OID: - return []byte(fmt.Sprintf("%d", value)), nil + switch v := value.(type) { + case int: + return []byte(strconv.FormatInt(int64(v), 10)), nil + case int32: + return []byte(strconv.FormatInt(int64(v), 10)), nil + case int64: + return []byte(strconv.FormatInt(v, 10)), nil + default: + return []byte(fmt.Sprintf("%d", value)), nil + } case pgtype.Float4OID, pgtype.Float8OID: - return []byte(fmt.Sprintf("%g", value)), nil + return []byte(strconv.FormatFloat(value.(float64), 'f', -1, 64)), nil case pgtype.NumericOID: - return []byte(value.(string)), nil + return []byte(fmt.Sprintf("%v", value)), nil case pgtype.TextOID, pgtype.VarcharOID: return []byte(value.(string)), nil case pgtype.ByteaOID: @@ -35,65 +44,72 @@ func ConvertToPgOutput(value interface{}, oid uint32) ([]byte, error) { case pgtype.DateOID: return []byte(value.(time.Time).Format("2006-01-02")), nil case pgtype.JSONOID, pgtype.JSONBOID: - if b, ok := value.([]byte); ok { - return b, nil + if jsonBytes, ok := value.([]byte); ok { + return jsonBytes, nil } return json.Marshal(value) - case pgtype.TextArrayOID, pgtype.VarcharArrayOID: - return EncodeTextArray(value) - case pgtype.Int2ArrayOID, pgtype.Int4ArrayOID, pgtype.Int8ArrayOID, pgtype.Float4ArrayOID, pgtype.Float8ArrayOID, pgtype.BoolArrayOID: + case pgtype.TextArrayOID, pgtype.VarcharArrayOID, + pgtype.Int2ArrayOID, pgtype.Int4ArrayOID, pgtype.Int8ArrayOID, + pgtype.Float4ArrayOID, pgtype.Float8ArrayOID, pgtype.BoolArrayOID: return EncodeArray(value) default: return []byte(fmt.Sprintf("%v", value)), nil } } -// EncodeTextArray encodes a slice of strings into a PostgreSQL text array format -func EncodeTextArray(value interface{}) ([]byte, error) { - v := reflect.ValueOf(value) - if v.Kind() != reflect.Slice { - return nil, fmt.Errorf("expected slice, got %T", value) - } - +// EncodeArray encodes a slice of values into a PostgreSQL array format. +func EncodeArray(value interface{}) ([]byte, error) { var elements []string - for i := 0; i < v.Len(); i++ { - elem := v.Index(i).Interface() - str, ok := elem.(string) - if !ok { - return nil, fmt.Errorf("expected string element in text array, got %T", elem) + + switch slice := value.(type) { + case []interface{}: + for _, v := range slice { + elem, err := encodeArrayElement(v) + if err != nil { + return nil, err + } + elements = append(elements, elem) } - elements = append(elements, QuoteArrayElement(str)) + case []string: + elements = append(elements, slice...) + case []int, []int32, []int64, []float32, []float64, []bool: + sliceValue := reflect.ValueOf(slice) + for i := 0; i < sliceValue.Len(); i++ { + elem, err := encodeArrayElement(sliceValue.Index(i).Interface()) + if err != nil { + return nil, err + } + elements = append(elements, elem) + } + default: + return nil, fmt.Errorf("unsupported slice type: %T", value) } return []byte("{" + strings.Join(elements, ",") + "}"), nil } -// EncodeArray encodes a slice of values into a PostgreSQL array format -func EncodeArray(value interface{}) ([]byte, error) { - v := reflect.ValueOf(value) - if v.Kind() != reflect.Slice { - return nil, fmt.Errorf("expected slice, got %T", value) +// encodeArrayElement encodes a single array element into a string representation. +func encodeArrayElement(v interface{}) (string, error) { + if v == nil { + return "NULL", nil } - var elements []string - for i := 0; i < v.Len(); i++ { - elem := v.Index(i).Interface() - if elem == nil { - elements = append(elements, "NULL") - } else { - elements = append(elements, fmt.Sprintf("%v", elem)) + switch val := v.(type) { + case string: + return val, nil + case int, int32, int64, float32, float64: + return fmt.Sprintf("%v", val), nil + case bool: + return strconv.FormatBool(val), nil + case time.Time: + return val.Format(time.RFC3339Nano), nil + case []byte: + return fmt.Sprintf("\\x%x", val), nil + default: + jsonBytes, err := json.Marshal(val) + if err != nil { + return "", fmt.Errorf("failed to marshal array element to JSON: %w", err) } + return string(jsonBytes), nil } - - return []byte("{" + strings.Join(elements, ",") + "}"), nil -} - -// QuoteArrayElement quotes a string element for use in a PostgreSQL array -func QuoteArrayElement(s string) string { - if strings.ContainsAny(s, `{},"\`) { - s = strings.ReplaceAll(s, `\`, `\\`) - s = strings.ReplaceAll(s, `"`, `\"`) - return `"` + s + `"` - } - return s } diff --git a/pkg/utils/cdc_message.go b/pkg/utils/cdc_message.go index fab3dbb..f4d2d97 100644 --- a/pkg/utils/cdc_message.go +++ b/pkg/utils/cdc_message.go @@ -24,7 +24,6 @@ func init() { gob.Register(CDCMessage{}) gob.Register(pglogrepl.TupleData{}) gob.Register(pglogrepl.TupleDataColumn{}) - gob.Register(time.Time{}) } // CDCMessage represents a full message for Change Data Capture @@ -37,7 +36,7 @@ type CDCMessage struct { OldTuple *pglogrepl.TupleData PrimaryKeyColumn string LSN pglogrepl.LSN - CommitTimestamp time.Time + EmittedAt time.Time } // MarshalBinary implements the encoding.BinaryMarshaler interface @@ -136,7 +135,7 @@ func EncodeCDCMessage(m CDCMessage) ([]byte, error) { if err := enc.Encode(m.LSN); err != nil { return nil, err } - if err := enc.Encode(m.CommitTimestamp); err != nil { + if err := enc.Encode(m.EmittedAt); err != nil { return nil, err } @@ -190,7 +189,7 @@ func DecodeCDCMessage(data []byte) (*CDCMessage, error) { if err := dec.Decode(&m.LSN); err != nil { return nil, err } - if err := dec.Decode(&m.CommitTimestamp); err != nil { + if err := dec.Decode(&m.EmittedAt); err != nil { return nil, err } @@ -302,7 +301,7 @@ func DecodeArray(data []byte, dataType uint32) (interface{}, error) { // EncodeValue encodes a Go value into a byte slice based on the PostgreSQL data type func EncodeValue(value interface{}, dataType uint32) ([]byte, error) { - return ConvertToPgOutput(value, dataType) + return ConvertToPgCompatibleOutput(value, dataType) } // GetDecodedColumnValue returns the decoded value of a column @@ -333,7 +332,7 @@ func (m *CDCMessage) GetDecodedMessage() (map[string]interface{}, error) { decodedMessage["Table"] = m.Table decodedMessage["PrimaryKeyColumn"] = m.PrimaryKeyColumn decodedMessage["LSN"] = m.LSN - decodedMessage["CommitTimestamp"] = m.CommitTimestamp + decodedMessage["EmittedAt"] = m.EmittedAt if m.NewTuple != nil { newTuple := make(map[string]interface{}) diff --git a/pkg/worker/worker.go b/pkg/worker/worker.go index 99e1525..0dcc290 100644 --- a/pkg/worker/worker.go +++ b/pkg/worker/worker.go @@ -12,7 +12,7 @@ import ( "github.com/shayonj/pg_flo/pkg/utils" ) -// Worker represents a worker that processes messages from NATS +// Worker represents a worker that processes messages from NATS. type Worker struct { natsClient *pgflonats.NATSClient ruleEngine *rules.RuleEngine @@ -23,9 +23,9 @@ type Worker struct { maxRetries int } -// NewWorker creates and returns a new Worker instance +// NewWorker creates and returns a new Worker instance with the provided NATS client, rule engine, sink, and group. func NewWorker(natsClient *pgflonats.NATSClient, ruleEngine *rules.RuleEngine, sink sinks.Sink, group string) *Worker { - logger := zerolog.New(zerolog.NewConsoleWriter()).With().Timestamp().Logger() + logger := zerolog.New(zerolog.NewConsoleWriter()).With().Timestamp().Str("component", "worker").Logger() return &Worker{ natsClient: natsClient, @@ -38,7 +38,7 @@ func NewWorker(natsClient *pgflonats.NATSClient, ruleEngine *rules.RuleEngine, s } } -// Start begins the worker's message processing loop +// Start begins the worker's message processing loop, setting up the NATS consumer and processing messages. func (w *Worker) Start(ctx context.Context) error { stream := fmt.Sprintf("pgflo_%s_stream", w.group) subject := fmt.Sprintf("pgflo.%s", w.group) @@ -49,11 +49,48 @@ func (w *Worker) Start(ctx context.Context) error { Str("group", w.group). Msg("Starting worker") + state, err := w.natsClient.GetState(ctx) + if err != nil { + return fmt.Errorf("failed to get state: %w", err) + } + js := w.natsClient.JetStream() + streamInfo, err := js.Stream(ctx, stream) + if err != nil { + w.logger.Error().Err(err).Msg("Failed to get stream info") + return fmt.Errorf("failed to get stream info: %w", err) + } - cons, err := js.OrderedConsumer(ctx, stream, jetstream.OrderedConsumerConfig{ + info, err := streamInfo.Info(ctx) + if err != nil { + w.logger.Error().Err(err).Msg("Failed to get stream info details") + return fmt.Errorf("failed to get stream info details: %w", err) + } + + w.logger.Info(). + Uint64("messages", info.State.Msgs). + Uint64("first_seq", info.State.FirstSeq). + Uint64("last_seq", info.State.LastSeq). + Msg("Stream info") + + startSeq := state.LastProcessedSeq + 1 + if startSeq < info.State.FirstSeq { + w.logger.Warn(). + Uint64("start_seq", startSeq). + Uint64("stream_first_seq", info.State.FirstSeq). + Msg("Start sequence is before the first available message, adjusting to stream's first sequence") + startSeq = info.State.FirstSeq + } + + w.logger.Info().Uint64("start_seq", startSeq).Msg("Starting consumer from sequence") + + consumerConfig := jetstream.OrderedConsumerConfig{ FilterSubjects: []string{subject}, - }) + DeliverPolicy: jetstream.DeliverByStartSequencePolicy, + OptStartSeq: startSeq, + } + + cons, err := js.OrderedConsumer(ctx, stream, consumerConfig) if err != nil { return fmt.Errorf("failed to create ordered consumer: %w", err) } @@ -61,7 +98,7 @@ func (w *Worker) Start(ctx context.Context) error { return w.processMessages(ctx, cons) } -// processMessages continuously processes messages from the NATS consumer +// processMessages continuously processes messages from the NATS consumer. func (w *Worker) processMessages(ctx context.Context, cons jetstream.Consumer) error { iter, err := cons.Messages() if err != nil { @@ -89,10 +126,21 @@ func (w *Worker) processMessages(ctx context.Context, cons jetstream.Consumer) e } } -// processMessage handles a single message, applying rules and writing to the sink +// processMessage handles a single message, applying rules, writing to the sink, and updating the last processed sequence. func (w *Worker) processMessage(msg jetstream.Msg) error { + metadata, err := msg.Metadata() + if err != nil { + w.logger.Error().Err(err).Msg("Failed to get message metadata") + return err + } + + w.logger.Debug(). + Uint64("stream_seq", metadata.Sequence.Stream). + Uint64("consumer_seq", metadata.Sequence.Consumer). + Msg("Processing message") + var cdcMessage utils.CDCMessage - err := cdcMessage.UnmarshalBinary(msg.Data()) + err = cdcMessage.UnmarshalBinary(msg.Data()) if err != nil { w.logger.Error().Err(err).Msg("Failed to unmarshal message") return err @@ -117,5 +165,20 @@ func (w *Worker) processMessage(msg jetstream.Msg) error { return err } + state, err := w.natsClient.GetState(context.Background()) + if err != nil { + w.logger.Error().Err(err).Msg("Failed to get current state") + return err + } + + if metadata.Sequence.Stream > state.LastProcessedSeq { + state.LastProcessedSeq = metadata.Sequence.Stream + if err := w.natsClient.SaveState(context.Background(), state); err != nil { + w.logger.Error().Err(err).Msg("Failed to save state") + } else { + w.logger.Debug().Uint64("last_processed_seq", state.LastProcessedSeq).Msg("Updated last processed sequence") + } + } + return nil }