diff --git a/internal/docker-compose.yml b/internal/docker-compose.yml
index e97f2e9..29c85a2 100644
--- a/internal/docker-compose.yml
+++ b/internal/docker-compose.yml
@@ -2,7 +2,7 @@ version: "3.8"
 services:
   postgres:
-    image: postgres:14
+    image: debezium/postgres:14-alpine
     container_name: pg_logical_replication
     environment:
       POSTGRES_USER: myuser
@@ -23,7 +23,7 @@ services:
     restart: unless-stopped

   target_postgres:
-    image: postgres:14
+    image: postgres:14-alpine
     container_name: pg_target
     environment:
       POSTGRES_USER: targetuser
diff --git a/internal/scripts/e2e_copy_and_stream.sh b/internal/scripts/e2e_copy_and_stream.sh
index 5e3017b..bbbd4a9 100755
--- a/internal/scripts/e2e_copy_and_stream.sh
+++ b/internal/scripts/e2e_copy_and_stream.sh
@@ -216,7 +216,7 @@ test_pg_flo_cdc() {
   simulate_concurrent_changes

   log "Waiting for changes to replicate..."
-  sleep 90
+  sleep 180

   stop_pg_flo_gracefully
   compare_row_counts || return 1
   verify_large_json || return 1
diff --git a/internal/scripts/e2e_copy_only.sh b/internal/scripts/e2e_copy_only.sh
index f56cb0b..e9151c3 100755
--- a/internal/scripts/e2e_copy_only.sh
+++ b/internal/scripts/e2e_copy_only.sh
@@ -88,6 +88,7 @@ start_pg_flo_worker() {
     --target-dbname "$TARGET_PG_DB" \
     --target-user "$TARGET_PG_USER" \
     --target-password "$TARGET_PG_PASSWORD" \
+    --batch-size 5000 \
     --target-sync-schema \
     >"$pg_flo_WORKER_LOG" 2>&1 &
   pg_flo_WORKER_PID=$!
diff --git a/internal/scripts/e2e_postgres.sh b/internal/scripts/e2e_postgres.sh
index 69621a0..4e80f47 100755
--- a/internal/scripts/e2e_postgres.sh
+++ b/internal/scripts/e2e_postgres.sh
@@ -91,7 +91,7 @@ start_pg_flo_worker() {
 simulate_changes() {
   log "Simulating changes..."
-  local insert_count=6000
+  local insert_count=1000

   for i in $(seq 1 "$insert_count"); do
     run_sql "INSERT INTO public.users (data, nullable_column, toasted_column) VALUES ('Data $i', 'Nullable $i', 'Toasted $i');"
diff --git a/internal/scripts/e2e_postgres_uniqueness_test.rb b/internal/scripts/e2e_postgres_uniqueness_test.rb
index f32d726..2e04f19 100644
--- a/internal/scripts/e2e_postgres_uniqueness_test.rb
+++ b/internal/scripts/e2e_postgres_uniqueness_test.rb
@@ -4,6 +4,7 @@
 require 'logger'
 require 'securerandom'
 require 'json'
+require 'digest'

 class PostgresUniquenessTest
   # Database configuration
@@ -200,16 +201,18 @@ def test_unique_operations
     @logger.info "Testing operations on unique_test..."
     uuid = SecureRandom.uuid

-    # Test INSERT
+    logo_path = "internal/pg_flo_logo.png"
+    binary_data = File.binread(logo_path).force_encoding('BINARY')
+
     @source_db.exec_params(
       "INSERT INTO public.unique_test (unique_col, binary_data, interval_data, data)
-      VALUES ($1, $2, $3, $4)",
-      [uuid, '\xdeadbeef', '1 year 2 months 3 days', '{"value": "test"}']
+      VALUES ($1, $2::bytea, $3, $4)",
+      [uuid, { value: binary_data, format: 1 }, '1 year 2 months 3 days', '{"value": "test"}']
     )
     sleep 1

     verify_table_data("unique_test", "unique_col = '#{uuid}'",
-      "1 | #{uuid} | \\xdeadbeef | 1 year 2 mons 3 days | {\"value\": \"test\"}")
+      "1 | #{uuid} | #{Digest::MD5.hexdigest(binary_data)} | 1 year 2 mons 3 days | {\"value\": \"test\"}")

     # Test UPDATE
     @source_db.exec_params(
@@ -219,8 +222,7 @@ def test_unique_operations
     sleep 1

     verify_table_data("unique_test", "unique_col = '#{uuid}'",
-      "1 | #{uuid} | \\xdeadbeef | 1 year 2 mons 3 days | {\"value\": \"updated_data\"}")
-
+      "1 | #{uuid} | #{Digest::MD5.hexdigest(binary_data)} | 1 year 2 mons 3 days | {\"value\": \"updated_data\"}")

     # Test DELETE
     @source_db.exec_params("DELETE FROM public.unique_test WHERE unique_col = $1", [uuid])
@@ -345,7 +347,7 @@ def build_verification_query(table, condition)
       SELECT (
         id::text || ' | ' ||
         unique_col::text || ' | ' ||
-        '\\x' || encode(binary_data, 'hex') || ' | ' ||
+        MD5(binary_data) || ' | ' ||
         interval_data::text || ' | ' ||
         data::text
       ) AS row_data
diff --git a/internal/scripts/e2e_resume_test.rb b/internal/scripts/e2e_resume_test.rb
index dafaf7f..1540f06 100644
--- a/internal/scripts/e2e_resume_test.rb
+++ b/internal/scripts/e2e_resume_test.rb
@@ -284,7 +284,7 @@ def test_resume
     @logger.info "Waiting for all inserts to complete..."
     threads.each(&:join)
-    sleep 20
+    sleep 60

     @logger.info "Sending final SIGTERM to cleanup..."
     @replicator_pids.each do |pid|
diff --git a/internal/scripts/e2e_stream_only.sh b/internal/scripts/e2e_stream_only.sh
index e6527fa..836fde7 100755
--- a/internal/scripts/e2e_stream_only.sh
+++ b/internal/scripts/e2e_stream_only.sh
@@ -66,9 +66,9 @@ simulate_changes() {
 verify_changes() {
   log "Verifying changes in ${OUTPUT_DIR}..."

-  local insert_count=$(jq -s '[.[] | select(.Type == "INSERT")] | length' "$OUTPUT_DIR"/*.jsonl)
-  local update_count=$(jq -s '[.[] | select(.Type == "UPDATE")] | length' "$OUTPUT_DIR"/*.jsonl)
-  local delete_count=$(jq -s '[.[] | select(.Type == "DELETE")] | length' "$OUTPUT_DIR"/*.jsonl)
+  local insert_count=$(jq -s '[.[] | select(.operation == "INSERT")] | length' "$OUTPUT_DIR"/*.jsonl)
+  local update_count=$(jq -s '[.[] | select(.operation == "UPDATE")] | length' "$OUTPUT_DIR"/*.jsonl)
+  local delete_count=$(jq -s '[.[] | select(.operation == "DELETE")] | length' "$OUTPUT_DIR"/*.jsonl)

   log "INSERT count: $insert_count (expected 1000)"
   log "UPDATE count: $update_count (expected 500)"
diff --git a/internal/scripts/e2e_test_local.sh b/internal/scripts/e2e_test_local.sh
index 465a05b..bdc80cf 100755
--- a/internal/scripts/e2e_test_local.sh
+++ b/internal/scripts/e2e_test_local.sh
@@ -34,7 +34,7 @@ make build
 setup_docker

 log "Running e2e ddl tests..."
-if CI=false ruby ./internal/scripts/e2e_resume_test.rb; then
+if CI=false ./internal/scripts/e2e_copy_only.sh; then
   success "e2e ddl tests completed successfully"
 else
   error "Original e2e tests failed"
diff --git a/internal/scripts/e2e_transform_filter.sh b/internal/scripts/e2e_transform_filter.sh
index 364fc90..de41a0a 100755
--- a/internal/scripts/e2e_transform_filter.sh
+++ b/internal/scripts/e2e_transform_filter.sh
@@ -11,6 +11,10 @@ create_users() {
     email text,
     phone text,
     age int,
+    balance numeric(10,2),
+    score bigint,
+    rating real,
+    weight double precision,
     ssn text,
     created_at timestamp DEFAULT current_timestamp
   );"
@@ -58,12 +62,12 @@ start_pg_flo_worker() {
 simulate_changes() {
   log "Simulating changes..."
-  run_sql "INSERT INTO public.users (email, phone, age, ssn) VALUES
-    ('john@example.com', '1234567890', 25, '123-45-6789'),
-    ('jane@example.com', '9876543210', 17, '987-65-4321'),
-    ('bob@example.com', '5551234567', 30, '555-12-3456');"
+  run_sql "INSERT INTO public.users (email, phone, age, balance, score, rating, weight, ssn) VALUES
+    ('john@example.com', '1234567890', 25, 100.50, 1000000000, 4.5, 75.5, '123-45-6789'),
+    ('jane@example.com', '9876543210', 17, 50.25, 2000000000, 3.8, 65.3, '987-65-4321'),
+    ('bob@example.com', '5551234567', 30, 75.75, 3000000000, 4.2, 80.1, '555-12-3456');"

-  run_sql "UPDATE public.users SET email = 'updated@example.com', phone = '1112223333' WHERE id = 1;"
+  run_sql "UPDATE public.users SET email = 'updated@example.com', phone = '1112223333', balance = 150.75 WHERE id = 1;"

   run_sql "DELETE FROM public.users WHERE age = 30;"
   run_sql "DELETE FROM public.users WHERE age = 17;"
@@ -72,28 +76,51 @@ simulate_changes() {
 verify_changes() {
   log "Verifying changes..."
-  local insert_count=$(jq -s '[.[] | select(.Type == "INSERT")] | length' "$OUTPUT_DIR"/*.jsonl)
-  local update_count=$(jq -s '[.[] | select(.Type == "UPDATE")] | length' "$OUTPUT_DIR"/*.jsonl)
-  local delete_count=$(jq -s '[.[] | select(.Type == "DELETE")] | length' "$OUTPUT_DIR"/*.jsonl)
-
-  log "INSERT count: $insert_count (expected 2)"
+  local insert_count=$(jq -s '[.[] | select(.operation == "INSERT")] | length' "$OUTPUT_DIR"/*.jsonl)
+  local update_count=$(jq -s '[.[] | select(.operation == "UPDATE")] | length' "$OUTPUT_DIR"/*.jsonl)
+  local delete_count=$(jq -s '[.[] | select(.operation == "DELETE")] | length' "$OUTPUT_DIR"/*.jsonl)
+
+  # We expect:
+  # - 1 INSERT (id=1, age=25 passes all filters)
+  # - 1 UPDATE (for id=1)
+  # - 2 DELETEs (for age=30 and age=17)
+  log "INSERT count: $insert_count (expected 1)"
   log "UPDATE count: $update_count (expected 1)"
   log "DELETE count: $delete_count (expected 2)"

-  if [ "$insert_count" -eq 2 ] && [ "$update_count" -eq 1 ] && [ "$delete_count" -eq 2 ]; then
+  if [ "$insert_count" -eq 1 ] && [ "$update_count" -eq 1 ] && [ "$delete_count" -eq 2 ]; then
     success "Change counts match expected values"
   else
     error "Change counts do not match expected values"
     return 1
   fi

+  # Verify numeric filters
+  local filtered_records=$(jq -r '.operation as $op |
+    select($op == "INSERT") |
+    select(
+      (.data.balance < 75.00) or
+      (.data.score >= 2500000000) or
+      (.data.rating <= 4.0) or
+      (.data.weight > 80.0)
+    ) | .data.id' "$OUTPUT_DIR"/*.jsonl)
+
+  if [[ -z "$filtered_records" ]]; then
+    success "Numeric filters working for all types"
+  else
+    error "Numeric filters not working correctly"
+    log "Records that should have been filtered: $filtered_records"
+    jq -r 'select(.data.id == '"$filtered_records"') | {id: .data.id, balance: .data.balance, score: .data.score, rating: .data.rating, weight: .data.weight}' "$OUTPUT_DIR"/*.jsonl
+    return 1
+  fi
+
   # Verify transformations and filters
-  local masked_email=$(jq -r 'select(.Type == "INSERT" and .NewTuple.id == 1) | .NewTuple.email' "$OUTPUT_DIR"/*.jsonl)
-  local formatted_phone=$(jq -r 'select(.Type == "INSERT" and .NewTuple.id == 1) | .NewTuple.phone' "$OUTPUT_DIR"/*.jsonl)
-  local filtered_insert=$(jq -r 'select(.Type == "INSERT" and .NewTuple.id == 2) | .NewTuple.id' "$OUTPUT_DIR"/*.jsonl)
-  local updated_email=$(jq -r 'select(.Type == "UPDATE") | .NewTuple.email' "$OUTPUT_DIR"/*.jsonl)
-  local masked_ssn=$(jq -r 'select(.Type == "INSERT" and .NewTuple.id == 1) | .NewTuple.ssn' "$OUTPUT_DIR"/*.jsonl)
-  local filtered_age=$(jq -r 'select(.Type == "INSERT" and .NewTuple.id == 2) | .NewTuple.age' "$OUTPUT_DIR"/*.jsonl)
+  local masked_email=$(jq -r 'select(.operation == "INSERT" and .data.id == 1) | .data.email' "$OUTPUT_DIR"/*.jsonl)
+  local formatted_phone=$(jq -r 'select(.operation == "INSERT" and .data.id == 1) | .data.phone' "$OUTPUT_DIR"/*.jsonl)
+  local filtered_insert=$(jq -r 'select(.operation == "INSERT" and .data.id == 2) | .data.id' "$OUTPUT_DIR"/*.jsonl)
+  local updated_email=$(jq -r 'select(.operation == "UPDATE") | .data.email' "$OUTPUT_DIR"/*.jsonl)
+  local masked_ssn=$(jq -r 'select(.operation == "INSERT" and .data.id == 1) | .data.ssn' "$OUTPUT_DIR"/*.jsonl)
+  local filtered_age=$(jq -r 'select(.operation == "INSERT" and .data.id == 2) | .data.age' "$OUTPUT_DIR"/*.jsonl)

   if [[ "$masked_email" == "j**************m" ]] &&
     [[ "$formatted_phone" == "(123) 456-7890" ]] &&
diff --git a/internal/scripts/rules.yml b/internal/scripts/rules.yml
index 8405b7d..d0b0ec2 100644
--- a/internal/scripts/rules.yml
+++ b/internal/scripts/rules.yml
@@ -29,3 +29,31 @@ tables:
         mask_char: "X"
       allow_empty_deletes: true
       operations: [INSERT, UPDATE, DELETE]
+    - type: filter
+      column: balance
+      parameters:
+        operator: gte
+        value: 75.00
+      allow_empty_deletes: true
+      operations: [INSERT, UPDATE, DELETE]
+    - type: filter
+      column: score
+      parameters:
+        operator: lt
+        value: 2500000000
+      allow_empty_deletes: true
+      operations: [INSERT, UPDATE, DELETE]
+    - type: filter
+      column: rating
+      parameters:
+        operator: gt
+        value: 4.0
+      allow_empty_deletes: true
+      operations: [INSERT, UPDATE, DELETE]
+    - type: filter
+      column: weight
+      parameters:
+        operator: lte
+        value: 80.0
+      allow_empty_deletes: true
+      operations: [INSERT, UPDATE, DELETE]
diff --git a/pkg/replicator/base_replicator.go b/pkg/replicator/base_replicator.go
index 7b79aae..4ed9169 100644
--- a/pkg/replicator/base_replicator.go
+++ b/pkg/replicator/base_replicator.go
@@ -2,7 +2,9 @@ package replicator

 import (
     "context"
+    "encoding/json"
     "fmt"
+    "strconv"
     "strings"
     "sync"
     "time"
@@ -153,20 +155,30 @@ func (r *BaseReplicator) checkPublicationExists(publicationName string) (bool, error) {

 // StartReplicationFromLSN initiates the replication process from a given LSN
 func (r *BaseReplicator) StartReplicationFromLSN(ctx context.Context, startLSN pglogrepl.LSN, stopChan <-chan struct{}) error {
     publicationName := GeneratePublicationName(r.Config.Group)
-    r.Logger.Info().Str("startLSN", startLSN.String()).Str("publication", publicationName).Msg("Starting replication")

-    err := r.ReplicationConn.StartReplication(ctx, publicationName, startLSN, pglogrepl.StartReplicationOptions{
+    tables, err := r.GetConfiguredTables(ctx)
+    if err != nil {
+        return fmt.Errorf("failed to get configured tables: %w", err)
+    }
+
+    tableList := strings.Join(tables, ",")
+
+    err = r.ReplicationConn.StartReplication(ctx, publicationName, startLSN, pglogrepl.StartReplicationOptions{
         PluginArgs: []string{
-            "proto_version '1'",
-            fmt.Sprintf("publication_names '%s'", publicationName),
+            "\"pretty-print\" 'true'",
+            "\"include-types\" 'true'",
+            "\"include-timestamp\" 'true'",
+            "\"include-pk\" 'true'",
+            "\"format-version\" '2'",
+            "\"include-column-positions\" 'true'",
+            "\"actions\" 'insert,update,delete'",
+            fmt.Sprintf("\"add-tables\" '%s'", tableList),
         },
     })
     if err != nil {
         return fmt.Errorf("failed to start replication: %w", err)
     }

-    r.Logger.Info().Str("startLSN", startLSN.String()).Msg("Replication started successfully")
-
     return r.StreamChanges(ctx, stopChan)
 }
@@ -281,136 +293,102 @@ func (r *BaseReplicator) handlePrimaryKeepaliveMessage(ctx context.Context, data
     return nil
 }

-// processWALData handles different types of WAL messages
+// processWALData processes the WAL data from wal2json
 func (r *BaseReplicator) processWALData(walData []byte, lsn pglogrepl.LSN) error {
-    logicalMsg, err := pglogrepl.Parse(walData)
-    if err != nil {
-        return fmt.Errorf("failed to parse WAL data: %w", err)
-    }
-
-    switch msg := logicalMsg.(type) {
-    case *pglogrepl.RelationMessage:
-        r.handleRelationMessage(msg)
-    case *pglogrepl.BeginMessage:
-        return r.HandleBeginMessage(msg)
-    case *pglogrepl.InsertMessage:
-        return r.HandleInsertMessage(msg, lsn)
-    case *pglogrepl.UpdateMessage:
-        return r.HandleUpdateMessage(msg, lsn)
-    case *pglogrepl.DeleteMessage:
-        return r.HandleDeleteMessage(msg, lsn)
-    case *pglogrepl.CommitMessage:
-        return r.HandleCommitMessage(msg)
-    default:
-        r.Logger.Warn().Type("message", msg).Msg("Received unexpected logical replication message")
-    }
-
-    return nil
-}
-
-// handleRelationMessage handles RelationMessage messages
-func (r *BaseReplicator) handleRelationMessage(msg *pglogrepl.RelationMessage) {
-    r.Relations[msg.RelationID] = msg
-    r.Logger.Info().Str("table", msg.RelationName).Uint32("id", msg.RelationID).Msg("Relation message received")
-}
+    var msg utils.Wal2JsonMessage
+    if err := json.Unmarshal(walData, &msg); err != nil {
+        return fmt.Errorf("failed to parse wal2json message: %w", err)
+    }

-// HandleBeginMessage handles BeginMessage messages
-func (r *BaseReplicator) HandleBeginMessage(msg *pglogrepl.BeginMessage) error {
-    r.currentTxBuffer = make([]utils.CDCMessage, 0)
-    r.currentTxLSN = msg.FinalLSN
-    return nil
-}
+    switch msg.Action {
+    case "B": // Begin
+        r.mu.Lock()
+        r.currentTxBuffer = make([]utils.CDCMessage, 0)
+        r.mu.Unlock()
+        return nil

-// HandleInsertMessage handles InsertMessage messages
-func (r *BaseReplicator) HandleInsertMessage(msg *pglogrepl.InsertMessage, lsn pglogrepl.LSN) error {
-    relation, ok := r.Relations[msg.RelationID]
-    if !ok {
-        return fmt.Errorf("unknown relation ID: %d", msg.RelationID)
-    }
+    case "C": // Commit
+        r.mu.Lock()
+        defer r.mu.Unlock()

-    cdcMessage := utils.CDCMessage{
-        Type:      utils.OperationInsert,
-        Schema:    relation.Namespace,
-        Table:     relation.RelationName,
-        Columns:   relation.Columns,
-        EmittedAt: time.Now(),
-        NewTuple:  msg.Tuple,
-        LSN:       lsn.String(),
-    }
+        for _, cdcMsg := range r.currentTxBuffer {
+            if err := r.PublishToNATS(cdcMsg); err != nil {
+                return fmt.Errorf("failed to publish message: %w", err)
+            }
+        }

-    r.AddPrimaryKeyInfo(&cdcMessage, relation.RelationName)
-    r.currentTxBuffer = append(r.currentTxBuffer, cdcMessage)
-    return nil
-}
+        r.currentTxBuffer = nil
+        r.LastLSN = lsn
+        return r.SaveState(lsn)
+
+    case "I", "U", "D": // Insert, Update, Delete
+        cdcMsg := &utils.CDCMessage{
+            Schema:      msg.Schema,
+            Table:       msg.Table,
+            LSN:         lsn.String(),
+            EmittedAt:   time.Now(),
+            Data:        make(map[string]interface{}),
+            OldData:     make(map[string]interface{}),
+            ColumnTypes: make(map[string]string),
+            Columns:     make([]utils.Column, len(msg.Columns)),
+        }

-// HandleUpdateMessage handles UpdateMessage messages
-func (r *BaseReplicator) HandleUpdateMessage(msg *pglogrepl.UpdateMessage, lsn pglogrepl.LSN) error {
-    relation, ok := r.Relations[msg.RelationID]
-    if !ok {
-        return fmt.Errorf("unknown relation ID: %d", msg.RelationID)
-    }
-
-    cdcMessage := utils.CDCMessage{
-        Type:           utils.OperationUpdate,
-        Schema:         relation.Namespace,
-        Table:          relation.RelationName,
-        Columns:        relation.Columns,
-        NewTuple:       msg.NewTuple,
-        OldTuple:       msg.OldTuple,
-        LSN:            lsn.String(),
-        EmittedAt:      time.Now(),
-        ToastedColumns: make(map[string]bool),
-    }
-
-    for i, col := range relation.Columns {
-        if msg.NewTuple != nil {
-            newVal := msg.NewTuple.Columns[i]
-            cdcMessage.ToastedColumns[col.Name] = newVal.DataType == 'u'
+        switch msg.Action {
+        case "I":
+            cdcMsg.Operation = utils.OperationInsert
+        case "U":
+            cdcMsg.Operation = utils.OperationUpdate
+        case "D":
+            cdcMsg.Operation = utils.OperationDelete
         }
-    }

-    r.AddPrimaryKeyInfo(&cdcMessage, relation.RelationName)
-    r.currentTxBuffer = append(r.currentTxBuffer, cdcMessage)
-    return nil
-}
+        r.AddPrimaryKeyInfo(cdcMsg, msg.Table)

-// HandleDeleteMessage handles DeleteMessage messages
-func (r *BaseReplicator) HandleDeleteMessage(msg *pglogrepl.DeleteMessage, lsn pglogrepl.LSN) error {
-    relation, ok := r.Relations[msg.RelationID]
-    if !ok {
-        return fmt.Errorf("unknown relation ID: %d", msg.RelationID)
-    }
+        for i, col := range msg.Columns {
+            value := col.Value

-    cdcMessage := utils.CDCMessage{
-        Type:      utils.OperationDelete,
-        Schema:    relation.Namespace,
-        Table:     relation.RelationName,
-        Columns:   relation.Columns,
-        OldTuple:  msg.OldTuple,
-        EmittedAt: time.Now(),
-        LSN:       lsn.String(),
-    }
+            // TODO: verify this conversion works in all cases; consider merging it with the convert function
+            if col.Type == "bigint" && col.Value != nil {
+                if strVal, ok := col.Value.(string); ok {
+                    if parsed, err := strconv.ParseInt(strVal, 10, 64); err == nil {
+                        value = parsed
+                    }
+                }
+            }

-    r.AddPrimaryKeyInfo(&cdcMessage, relation.RelationName)
-    r.currentTxBuffer = append(r.currentTxBuffer, cdcMessage)
-    return nil
-}
+            cdcMsg.Data[col.Name] = value
+            cdcMsg.ColumnTypes[col.Name] = col.Type
+            cdcMsg.Columns[i] = utils.Column{
+                Name:     col.Name,
+                DataType: utils.GetOIDFromTypeName(col.Type),
+            }
+        }

-// HandleCommitMessage processes a commit message and publishes it to NATS
-func (r *BaseReplicator) HandleCommitMessage(msg *pglogrepl.CommitMessage) error {
-    for _, cdcMessage := range r.currentTxBuffer {
-        if err := r.PublishToNATS(cdcMessage); err != nil {
-            return fmt.Errorf("failed to publish message: %w", err)
+        // Process old values from identity field
+        if len(msg.Identity) > 0 {
+            for _, col := range msg.Identity {
+                value := col.Value
+
+                // Handle bigint values for old data too
+                if col.Type == "bigint" && col.Value != nil {
+                    if strVal, ok := col.Value.(string); ok {
+                        if parsed, err := strconv.ParseInt(strVal, 10, 64); err == nil {
+                            value = parsed
+                        }
+                    }
+                }
+
+                cdcMsg.OldData[col.Name] = value
+            }
         }
-    }

-    r.LastLSN = msg.CommitLSN
-    if err := r.SaveState(msg.CommitLSN); err != nil {
-        r.Logger.Error().Err(err).Msg("Failed to save replication state")
-        return err
+        r.mu.Lock()
+        r.currentTxBuffer = append(r.currentTxBuffer, *cdcMsg)
+        r.mu.Unlock()
     }

-    r.currentTxBuffer = nil
     return nil
 }
@@ -422,6 +400,7 @@ func (r *BaseReplicator) PublishToNATS(data utils.CDCMessage) error {
     }

     subject := fmt.Sprintf("pgflo.%s", r.Config.Group)
+
     err = r.NATSClient.PublishMessage(subject, binaryData)
     if err != nil {
         r.Logger.Error().
diff --git a/pkg/replicator/copy_and_stream_replicator.go b/pkg/replicator/copy_and_stream_replicator.go
index 116a306..543c6b4 100644
--- a/pkg/replicator/copy_and_stream_replicator.go
+++ b/pkg/replicator/copy_and_stream_replicator.go
@@ -214,8 +214,8 @@ func (r *CopyAndStreamReplicator) CopyTableRange(ctx context.Context, tableName
         }
     }()

-    if setSnapshotErr := r.setTransactionSnapshot(tx, snapshotID); setSnapshotErr != nil {
-        return 0, setSnapshotErr
+    if err := r.setTransactionSnapshot(tx, snapshotID); err != nil {
+        return 0, err
     }

     schema, err := r.getSchemaName(tx, tableName)
@@ -224,55 +224,19 @@ func (r *CopyAndStreamReplicator) CopyTableRange(ctx context.Context, tableName
     query := r.buildCopyQuery(tableName, startPage, endPage)
-    return r.executeCopyQuery(ctx, tx, query, schema, tableName, workerID)
-}
-
-// setTransactionSnapshot sets the transaction snapshot.
-func (r *CopyAndStreamReplicator) setTransactionSnapshot(tx pgx.Tx, snapshotID string) error {
-    _, err := tx.Exec(context.Background(), fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s'", snapshotID))
-    if err != nil {
-        return fmt.Errorf("failed to set transaction snapshot: %v", err)
-    }
-    return nil
-}
-
-// getSchemaName retrieves the schema name for a given table.
-func (r *CopyAndStreamReplicator) getSchemaName(tx pgx.Tx, tableName string) (string, error) {
-    var schema string
-    err := tx.QueryRow(context.Background(), "SELECT schemaname FROM pg_tables WHERE tablename = $1", tableName).Scan(&schema)
-    if err != nil {
-        return "", fmt.Errorf("failed to get schema name: %v", err)
-    }
-    return schema, nil
-}
-
-// buildCopyQuery constructs the SQL query for copying a range of pages from a table.
-func (r *CopyAndStreamReplicator) buildCopyQuery(tableName string, startPage, endPage uint32) string {
-    query := fmt.Sprintf(`
-        SELECT *
-        FROM %s
-        WHERE ctid >= '(%d,0)'::tid AND ctid < '(%d,0)'::tid`,
-        pgx.Identifier{tableName}.Sanitize(), startPage, endPage)
-    return query
-}
-
-// executeCopyQuery executes the copy query and publishes the results to NATS.
-func (r *CopyAndStreamReplicator) executeCopyQuery(ctx context.Context, tx pgx.Tx, query, schema, tableName string, workerID int) (int64, error) {
-    r.Logger.Debug().Str("copyQuery", query).Int("workerID", workerID).Msg("Executing initial copy query")
-
-    rows, err := tx.Query(context.Background(), query)
+    rows, err := tx.Query(ctx, query)
     if err != nil {
-        return 0, fmt.Errorf("failed to execute initial copy query: %v", err)
+        return 0, fmt.Errorf("failed to execute copy query: %v", err)
     }
     defer rows.Close()

     fieldDescriptions := rows.FieldDescriptions()
-    columns := make([]*pglogrepl.RelationMessageColumn, len(fieldDescriptions))
+    columnTypes := make(map[string]uint32)
+    columnNames := make([]string, len(fieldDescriptions))
+
     for i, fd := range fieldDescriptions {
-        columns[i] = &pglogrepl.RelationMessageColumn{
-            Name:     fd.Name,
-            DataType: fd.DataTypeOID,
-        }
+        columnNames[i] = fd.Name
+        columnTypes[fd.Name] = fd.DataTypeOID
     }

     var copyCount int64
@@ -282,33 +246,39 @@ func (r *CopyAndStreamReplicator) executeCopyQuery(ctx context.Context, tx pgx.Tx,
             return 0, fmt.Errorf("error reading row: %v", err)
         }

-        tupleData := &pglogrepl.TupleData{
-            Columns: make([]*pglogrepl.TupleDataColumn, len(values)),
-        }
+        data := make(map[string]interface{})
         for i, value := range values {
-            data, err := utils.ConvertToPgCompatibleOutput(value, fieldDescriptions[i].DataTypeOID)
-            if err != nil {
-                return 0, fmt.Errorf("error converting value: %v", err)
-            }
-
-            tupleData.Columns[i] = &pglogrepl.TupleDataColumn{
-                DataType: uint8(fieldDescriptions[i].DataTypeOID),
-                Data:     data,
+            if value != nil {
+                convertedValue, err := utils.ConvertToPgCompatibleOutput(value, columnTypes[columnNames[i]])
+                if err != nil {
+                    return 0, fmt.Errorf("error converting value: %v", err)
+                }
+                data[columnNames[i]] = convertedValue
+            } else {
+                data[columnNames[i]] = nil
             }
         }

         cdcMessage := utils.CDCMessage{
-            Type:      utils.OperationInsert,
+            Operation: utils.OperationInsert,
             Schema:    schema,
             Table:     tableName,
-            Columns:   columns,
-            NewTuple:  tupleData,
+            Data:      data,
+            LSN:       r.LastLSN.String(),
             EmittedAt: time.Now(),
+            Columns:   make([]utils.Column, len(fieldDescriptions)),
         }

-        r.BaseReplicator.AddPrimaryKeyInfo(&cdcMessage, tableName)
-        if err := r.BaseReplicator.PublishToNATS(cdcMessage); err != nil {
-            return 0, fmt.Errorf("failed to publish insert event to NATS: %v", err)
+        for i, fd := range fieldDescriptions {
+            cdcMessage.Columns[i] = utils.Column{
+                Name:     fd.Name,
+                DataType: fd.DataTypeOID,
+            }
+        }
+
+        r.AddPrimaryKeyInfo(&cdcMessage, tableName)
+        if err := r.PublishToNATS(cdcMessage); err != nil {
+            return 0, fmt.Errorf("failed to publish message: %v", err)
         }

         copyCount++
@@ -320,11 +290,36 @@ func (r *CopyAndStreamReplicator) executeCopyQuery(ctx context.Context, tx pgx.Tx,
         }
     }

-    if err := rows.Err(); err != nil {
-        return 0, fmt.Errorf("error during row iteration: %v", err)
+    return copyCount, rows.Err()
+}
+
+// setTransactionSnapshot sets the transaction snapshot.
+func (r *CopyAndStreamReplicator) setTransactionSnapshot(tx pgx.Tx, snapshotID string) error {
+    _, err := tx.Exec(context.Background(), fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s'", snapshotID))
+    if err != nil {
+        return fmt.Errorf("failed to set transaction snapshot: %v", err)
+    }
+    return nil
+}
+
+// getSchemaName retrieves the schema name for a given table.
+func (r *CopyAndStreamReplicator) getSchemaName(tx pgx.Tx, tableName string) (string, error) {
+    var schema string
+    err := tx.QueryRow(context.Background(), "SELECT schemaname FROM pg_tables WHERE tablename = $1", tableName).Scan(&schema)
+    if err != nil {
+        return "", fmt.Errorf("failed to get schema name: %v", err)
     }
+    return schema, nil
+}

-    return copyCount, nil
+// buildCopyQuery constructs the SQL query for copying a range of pages from a table.
+func (r *CopyAndStreamReplicator) buildCopyQuery(tableName string, startPage, endPage uint32) string {
+    query := fmt.Sprintf(`
+        SELECT *
+        FROM %s
+        WHERE ctid >= '(%d,0)'::tid AND ctid < '(%d,0)'::tid`,
+        pgx.Identifier{tableName}.Sanitize(), startPage, endPage)
+    return query
 }

 // collectErrors collects errors from the error channel and returns them as a single error.
diff --git a/pkg/replicator/ddl_replicator.go b/pkg/replicator/ddl_replicator.go
index dec4df3..5ae1824 100644
--- a/pkg/replicator/ddl_replicator.go
+++ b/pkg/replicator/ddl_replicator.go
@@ -7,8 +7,6 @@ import (
     "strings"
     "time"

-    "github.com/jackc/pglogrepl"
-    "github.com/jackc/pgtype"
     "github.com/pgflo/pg_flo/pkg/utils"
 )

@@ -179,25 +177,16 @@ func (d *DDLReplicator) ProcessDDLEvents(ctx context.Context) error {
         }

         cdcMessage := utils.CDCMessage{
-            Type:      utils.OperationDDL,
+            Operation: utils.OperationDDL,
             Schema:    schema,
             Table:     table,
             EmittedAt: time.Now(),
-            Columns: []*pglogrepl.RelationMessageColumn{
-                {Name: "event_type", DataType: pgtype.TextOID},
-                {Name: "object_type", DataType: pgtype.TextOID},
-                {Name: "object_identity", DataType: pgtype.TextOID},
-                {Name: "ddl_command", DataType: pgtype.TextOID},
-                {Name: "created_at", DataType: pgtype.TimestamptzOID},
-            },
-            NewTuple: &pglogrepl.TupleData{
-                Columns: []*pglogrepl.TupleDataColumn{
-                    {Data: []byte(eventType)},
-                    {Data: []byte(objectType)},
-                    {Data: []byte(objectIdentity)},
-                    {Data: []byte(ddlCommand)},
-                    {Data: []byte(createdAt.Format(time.RFC3339))},
-                },
+            Data: map[string]interface{}{
+                "event_type":      eventType,
+                "object_type":     objectType,
+                "object_identity": objectIdentity,
+                "ddl_command":     ddlCommand,
+                "created_at":      createdAt.Format(time.RFC3339),
             },
         }
diff --git a/pkg/replicator/replication_connection.go b/pkg/replicator/replication_connection.go
index 911dd39..3458da1 100644
--- a/pkg/replicator/replication_connection.go
+++ b/pkg/replicator/replication_connection.go
@@ -54,7 +54,9 @@ func (rc *PostgresReplicationConnection) Close(ctx context.Context) error {

 // CreateReplicationSlot creates a new replication slot in the PostgreSQL database.
 func (rc *PostgresReplicationConnection) CreateReplicationSlot(ctx context.Context, slotName string) (pglogrepl.CreateReplicationSlotResult, error) {
-    return pglogrepl.CreateReplicationSlot(ctx, rc.Conn, slotName, "pgoutput", pglogrepl.CreateReplicationSlotOptions{Temporary: false})
+    return pglogrepl.CreateReplicationSlot(ctx, rc.Conn, slotName, "wal2json", pglogrepl.CreateReplicationSlotOptions{
+        Temporary: false,
+    })
 }

 // StartReplication initiates the replication process from the specified LSN.
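
Note: with the slot created on the wal2json plugin (above) and StartReplicationFromLSN requesting "format-version" '2', processWALData now receives one JSON document per action rather than pgoutput's binary protocol. As a rough sketch of what the parser expects — field values here are illustrative, and the exact shape depends on the wal2json version and the options passed — an INSERT on the e2e users table would arrive as something like:

    {
      "action": "I",
      "schema": "public",
      "table": "users",
      "timestamp": "2024-01-01 12:00:00.000000+00",
      "columns": [
        {"name": "id", "type": "integer", "position": 1, "value": 1},
        {"name": "email", "type": "text", "position": 2, "value": "john@example.com"}
      ],
      "pk": [{"name": "id", "type": "integer"}]
    }

Transaction boundaries arrive as separate {"action": "B"} and {"action": "C"} documents, which is why processWALData buffers between Begin and Commit; UPDATE/DELETE events additionally carry old key values in an "identity" array. These map onto the Wal2JsonMessage and wal2JsonCol structs added in pkg/utils/cdc_message.go below.
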
diff --git a/pkg/replicator/table_handling.go b/pkg/replicator/table_handling.go
index 39d2564..f1ca165 100644
--- a/pkg/replicator/table_handling.go
+++ b/pkg/replicator/table_handling.go
@@ -10,10 +10,17 @@ import (
 // AddPrimaryKeyInfo adds replication key information to the CDCMessage
 func (r *BaseReplicator) AddPrimaryKeyInfo(message *utils.CDCMessage, table string) {
     if key, ok := r.TableReplicationKeys[table]; ok {
-        message.ReplicationKey = key
+        replicationKey := key
+        message.ReplicationKey = &replicationKey
     } else {
         r.Logger.Error().
             Str("table", table).
+            Any("message", message.Data).
+            Any("schema", message.Schema).
+            Any("message_table", message.Table).
+            Any("operation", message.Operation).
+            Any("lsn", message.LSN).
+            Any("emitted_at", message.EmittedAt).
             Msg("No replication key information found for table. This should not happen as validation is done during initialization")
     }
 }
diff --git a/pkg/routing/router.go b/pkg/routing/router.go
index 41a3de8..f349a63 100644
--- a/pkg/routing/router.go
+++ b/pkg/routing/router.go
@@ -3,7 +3,6 @@ package routing
 import (
     "sync"

-    "github.com/jackc/pglogrepl"
     "github.com/pgflo/pg_flo/pkg/utils"
     "github.com/rs/zerolog"
     "github.com/rs/zerolog/log"
@@ -48,7 +47,7 @@ func (r *Router) ApplyRouting(message *utils.CDCMessage) (*utils.CDCMessage, error) {
         return message, nil
     }

-    if !ContainsOperation(route.Operations, message.Type) {
+    if !ContainsOperation(route.Operations, message.Operation) {
         return nil, nil
     }

@@ -56,18 +55,45 @@ func (r *Router) ApplyRouting(message *utils.CDCMessage) (*utils.CDCMessage, error) {
     routedMessage.Table = route.DestinationTable

     if len(route.ColumnMappings) > 0 {
-        newColumns := make([]*pglogrepl.RelationMessageColumn, len(message.Columns))
+        newColumns := make([]utils.Column, len(message.Columns))
         for i, col := range message.Columns {
-            newCol := *col
+            newCol := col
             mappedName := GetMappedColumnName(route.ColumnMappings, col.Name)
             if mappedName != "" {
                 newCol.Name = mappedName
             }
-            newColumns[i] = &newCol
+            newColumns[i] = newCol
         }
         routedMessage.Columns = newColumns

-        if routedMessage.ReplicationKey.Type != utils.ReplicationKeyFull {
+        newData := make(map[string]interface{})
+        newOldData := make(map[string]interface{})
+        newColumnTypes := make(map[string]string)
+
+        for oldName, value := range message.Data {
+            newName := GetMappedColumnName(route.ColumnMappings, oldName)
+            if newName == "" {
+                newName = oldName
+            }
+            newData[newName] = value
+            if typeVal, exists := message.ColumnTypes[oldName]; exists {
+                newColumnTypes[newName] = typeVal
+            }
+        }
+
+        for oldName, value := range message.OldData {
+            newName := GetMappedColumnName(route.ColumnMappings, oldName)
+            if newName == "" {
+                newName = oldName
+            }
+            newOldData[newName] = value
+        }
+
+        routedMessage.Data = newData
+        routedMessage.OldData = newOldData
+        routedMessage.ColumnTypes = newColumnTypes
+
+        if routedMessage.ReplicationKey != nil && routedMessage.ReplicationKey.Type != utils.ReplicationKeyFull {
             mappedColumns := make([]string, len(routedMessage.ReplicationKey.Columns))
             for i, keyCol := range routedMessage.ReplicationKey.Columns {
                 mappedName := GetMappedColumnName(route.ColumnMappings, keyCol)
diff --git a/pkg/rules/engine.go b/pkg/rules/engine.go
index 482606d..6dafe4a 100644
--- a/pkg/rules/engine.go
+++ b/pkg/rules/engine.go
@@ -25,7 +25,7 @@ func (re *RuleEngine) ApplyRules(message *utils.CDCMessage) (*utils.CDCMessage,

     logger.Info().
         Str("table", message.Table).
-        Str("operation", string(message.Type)).
+        Str("operation", string(message.Operation)).
         Int("ruleCount", len(rules)).
         Msg("Applying rules")
diff --git a/pkg/rules/rules.go b/pkg/rules/rules.go
index ded51bc..579e82b 100644
--- a/pkg/rules/rules.go
+++ b/pkg/rules/rules.go
@@ -1,15 +1,13 @@
 package rules

 import (
+    "encoding/json"
     "fmt"
-    "reflect"
     "regexp"
     "strings"
-    "time"

     "os"

-    "github.com/jackc/pgtype"
     "github.com/pgflo/pg_flo/pkg/utils"
     "github.com/rs/zerolog"
     "github.com/shopspring/decimal"
@@ -109,7 +107,7 @@ func NewMaskTransformRule(table, column string, params map[string]interface{}) (Rule, error) {
     }

     transform := func(m *utils.CDCMessage) (*utils.CDCMessage, error) {
-        useOldValues := m.Type == utils.OperationDelete
+        useOldValues := m.Operation == utils.OperationDelete
         value, err := m.GetColumnValue(column, useOldValues)
         if err != nil {
             return m, nil
@@ -180,80 +178,103 @@ func NewFilterRule(table, column string, params map[string]interface{}) (Rule, error) {
 // NewComparisonCondition creates a new comparison condition function
 func NewComparisonCondition(column, operator string, value interface{}) func(*utils.CDCMessage) bool {
     return func(m *utils.CDCMessage) bool {
-        useOldValues := m.Type == utils.OperationDelete
+        useOldValues := m.Operation == utils.OperationDelete
         columnValue, err := m.GetColumnValue(column, useOldValues)
         if err != nil {
             return false
         }

-        colIndex := m.GetColumnIndex(column)
-        if colIndex == -1 {
-            return false
-        }
+        var colVal, compareVal decimal.Decimal

-        columnType := m.Columns[colIndex].DataType
-
-        switch columnType {
-        case pgtype.Int2OID, pgtype.Int4OID, pgtype.Int8OID:
-            intVal, ok := utils.ToInt64(columnValue)
-            if !ok {
-                return false
-            }
-            compareVal, ok := utils.ToInt64(value)
-            if !ok {
-                return false
-            }
-            return compareValues(intVal, compareVal, operator)
-        case pgtype.Float4OID, pgtype.Float8OID:
-            floatVal, ok := utils.ToFloat64(columnValue)
-            if !ok {
-                return false
-            }
-            compareVal, ok := utils.ToFloat64(value)
-            if !ok {
-                return false
-            }
-            return compareValues(floatVal, compareVal, operator)
-        case pgtype.TextOID, pgtype.VarcharOID:
-            strVal, ok := columnValue.(string)
-            if !ok {
-                return false
-            }
-            compareVal, ok := value.(string)
-            if !ok {
-                return false
-            }
-            return compareValues(strVal, compareVal, operator)
-        case pgtype.TimestampOID, pgtype.TimestamptzOID:
-            timeVal, ok := columnValue.(time.Time)
-            if !ok {
-                return false
-            }
-            compareVal, err := utils.ParseTimestamp(fmt.Sprintf("%v", value))
+        // Handle column value from JSON
+        switch v := columnValue.(type) {
+        case float64:
+            colVal = decimal.NewFromFloat(v)
+            logger.Debug().
+                Float64("value", v).
+                Str("column", column).
+                Msg("Converting float64")
+        case json.Number:
+            var err error
+            colVal, err = decimal.NewFromString(v.String())
             if err != nil {
+                logger.Debug().
+                    Err(err).
+                    Str("value", v.String()).
+                    Str("column", column).
+                    Msg("Failed to parse json.Number")
                 return false
             }
-            return compareValues(timeVal.UTC(), compareVal.UTC(), operator)
-        case pgtype.BoolOID:
-            boolVal, ok := utils.ToBool(columnValue)
-            if !ok {
-                return false
-            }
-            compareVal, ok := value.(bool)
-            if !ok {
-                return false
-            }
-            return compareValues(boolVal, compareVal, operator)
-        case pgtype.NumericOID:
-            numVal, ok := columnValue.(string)
-            if !ok {
+            logger.Debug().
+                Str("value", v.String()).
+                Str("column", column).
+                Msg("Converting json.Number")
+        case int64:
+            colVal = decimal.NewFromInt(v)
+            logger.Debug().
+                Int64("value", v).
+                Str("column", column).
+                Msg("Converting int64")
+        case int:
+            colVal = decimal.NewFromInt(int64(v))
+            logger.Debug().
+                Int("value", v).
+                Str("column", column).
+                Msg("Converting int")
+        case string:
+            // Handle numeric strings (sometimes wal2json sends these)
+            var err error
+            colVal, err = decimal.NewFromString(v)
+            if err != nil {
+                logger.Debug().
+                    Err(err).
+                    Str("value", v).
+                    Str("column", column).
+                    Msg("Failed to parse numeric string")
                 return false
             }
-            compareVal, ok := value.(string)
-            if !ok {
+            logger.Debug().
+                Str("value", v).
+                Str("column", column).
+                Msg("Converting numeric string")
+        default:
+            logger.Debug().
+                Str("type", fmt.Sprintf("%T", v)).
+                Any("value", v).
+                Str("column", column).
+                Msg("Unsupported numeric type")
+            return false
+        }
+
+        // Handle comparison value from YAML
+        switch v := value.(type) {
+        case float64:
+            compareVal = decimal.NewFromFloat(v)
+        case json.Number:
+            var err error
+            compareVal, err = decimal.NewFromString(v.String())
+            if err != nil {
                 return false
             }
-            return compareNumericValues(numVal, compareVal, operator)
+        case int:
+            compareVal = decimal.NewFromInt(int64(v))
+        default:
+            return false
+        }
+
+        switch operator {
+        case "eq":
+            return colVal.Equal(compareVal)
+        case "ne":
+            return !colVal.Equal(compareVal)
+        case "gt":
+            return colVal.GreaterThan(compareVal)
+        case "lt":
+            return colVal.LessThan(compareVal)
+        case "gte":
+            return colVal.GreaterThanOrEqual(compareVal)
+        case "lte":
+            return colVal.LessThanOrEqual(compareVal)
         default:
             return false
         }
@@ -263,7 +284,7 @@ func NewComparisonCondition(column, operator string, value interface{}) func(*utils.CDCMessage) bool {
 // NewContainsCondition creates a new contains condition function
 func NewContainsCondition(column string, value interface{}) func(*utils.CDCMessage) bool {
     return func(m *utils.CDCMessage) bool {
-        useOldValues := m.Type == utils.OperationDelete
+        useOldValues := m.Operation == utils.OperationDelete
         columnValue, err := m.GetColumnValue(column, useOldValues)
         if err != nil {
             return false
@@ -280,91 +301,14 @@ func NewContainsCondition(column string, value interface{}) func(*utils.CDCMessage) bool {
     }
 }

-// compareValues compares two values based on the provided operator
-func compareValues(a, b interface{}, operator string) bool {
-    switch operator {
-    case "eq":
-        return reflect.DeepEqual(a, b)
-    case "ne":
-        return !reflect.DeepEqual(a, b)
-    case "gt":
-        return compareGreaterThan(a, b)
-    case "lt":
-        return compareLessThan(a, b)
-    case "gte":
-        return compareGreaterThan(a, b) || reflect.DeepEqual(a, b)
-    case "lte":
-        return compareLessThan(a, b) || reflect.DeepEqual(a, b)
-    }
-    return false
-}
-
-// compareGreaterThan checks if 'a' is greater than 'b'
-func compareGreaterThan(a, b interface{}) bool {
-    switch a := a.(type) {
-    case int64:
-        return a > b.(int64)
-    case float64:
-        return a > b.(float64)
-    case string:
-        return a > b.(string)
-    case time.Time:
-        return a.After(b.(time.Time))
-    default:
-        return false
-    }
-}
-
-// compareLessThan checks if 'a' is less than 'b'
-func compareLessThan(a, b interface{}) bool {
-    switch a := a.(type) {
-    case int64:
-        return a < b.(int64)
-    case float64:
-        return a < b.(float64)
-    case string:
-        return a < b.(string)
-    case time.Time:
-        return a.Before(b.(time.Time))
-    default:
-        return false
-    }
-}
-
-// compareNumericValues compares two numeric values based on the provided operator
-func compareNumericValues(a, b string, operator string) bool {
-    aNum, err1 := decimal.NewFromString(a)
-    bNum, err2 := decimal.NewFromString(b)
-    if err1 != nil || err2 != nil {
-        return false
-    }
-
-    switch operator {
-    case "eq":
-        return aNum.Equal(bNum)
-    case "ne":
-        return !aNum.Equal(bNum)
-    case "gt":
-        return aNum.GreaterThan(bNum)
-    case "lt":
-        return aNum.LessThan(bNum)
-    case "gte":
-        return aNum.GreaterThanOrEqual(bNum)
-    case "lte":
-        return aNum.LessThanOrEqual(bNum)
-    default:
-        return false
-    }
-}
-
 // Apply applies the transform rule to the provided data
 func (r *TransformRule) Apply(message *utils.CDCMessage) (*utils.CDCMessage, error) {
-    if !containsOperation(r.Operations, message.Type) {
+    if !containsOperation(r.Operations, message.Operation) {
         return message, nil
     }

     // Don't apply rule if asked not to
-    if message.Type == utils.OperationDelete && r.AllowEmptyDeletes {
+    if message.Operation == utils.OperationDelete && r.AllowEmptyDeletes {
         return message, nil
     }

@@ -373,12 +317,13 @@ func (r *TransformRule) Apply(message *utils.CDCMessage) (*utils.CDCMessage, error) {
 // Apply applies the filter rule to the provided data
 func (r *FilterRule) Apply(message *utils.CDCMessage) (*utils.CDCMessage, error) {
-    if !containsOperation(r.Operations, message.Type) {
+
+    if !containsOperation(r.Operations, message.Operation) {
         return message, nil
     }

     // Don't apply rule if asked not to
-    if message.Type == utils.OperationDelete && r.AllowEmptyDeletes {
+    if message.Operation == utils.OperationDelete && r.AllowEmptyDeletes {
         return message, nil
     }

@@ -386,7 +331,7 @@ func (r *FilterRule) Apply(message *utils.CDCMessage) (*utils.CDCMessage, error) {
     logger.Debug().
         Str("column", r.ColumnName).
-        Any("operation", message.Type).
+        Any("operation", message.Operation).
         Bool("passes", passes).
         Bool("allowEmptyDeletes", r.AllowEmptyDeletes).
         Msg("Filter condition result")
diff --git a/pkg/sinks/file.go b/pkg/sinks/file.go
index 2c6ea62..d3cb59e 100644
--- a/pkg/sinks/file.go
+++ b/pkg/sinks/file.go
@@ -84,12 +84,7 @@ func (s *FileSink) WriteBatch(messages []*utils.CDCMessage) error {
     defer s.mutex.Unlock()

     for _, message := range messages {
-        decodedMessage, err := buildDecodedMessage(message)
-        if err != nil {
-            return fmt.Errorf("failed to build decoded message: %v", err)
-        }
-
-        jsonData, err := json.Marshal(decodedMessage)
+        jsonData, err := json.Marshal(message)
         if err != nil {
             return fmt.Errorf("failed to marshal data to JSON: %v", err)
         }
diff --git a/pkg/sinks/postgres.go b/pkg/sinks/postgres.go
index c6a8534..fe1a564 100644
--- a/pkg/sinks/postgres.go
+++ b/pkg/sinks/postgres.go
@@ -10,6 +10,8 @@ import (
     "sync"
     "time"

+    "encoding/hex"
+
     "github.com/jackc/pgx/v5"
     "github.com/jackc/pgx/v5/pgconn"
     "github.com/pgflo/pg_flo/pkg/utils"
@@ -109,18 +111,22 @@ func (s *PostgresSink) syncSchema(sourceHost string, sourcePort int, sourceDBName

 // handleInsert processes an insert operation
 func (s *PostgresSink) handleInsert(tx pgx.Tx, message *utils.CDCMessage) error {
-    columns := make([]string, 0, len(message.Columns))
-    placeholders := make([]string, 0, len(message.Columns))
-    values := make([]interface{}, 0, len(message.Columns))
-
-    for i, col := range message.Columns {
-        value, err := message.GetColumnValue(col.Name, false)
+    columns := make([]string, 0, len(message.Data))
+    placeholders := make([]string, 0, len(message.Data))
+    values := make([]interface{}, 0, len(message.Data))
+    i := 1
+
+    for colName, value := range message.Data {
+        pgType := message.GetColumnType(colName)
+        convertedValue, err := s.convertValue(value, pgType)
         if err != nil {
-            return fmt.Errorf("failed to get column value: %w", err)
+            return fmt.Errorf("failed to convert value for column %s: %w", colName, err)
         }
-        columns = append(columns, col.Name)
-        placeholders = append(placeholders, fmt.Sprintf("$%d", i+1))
-        values = append(values, value)
+
+        columns = append(columns, colName)
+        placeholders = append(placeholders, fmt.Sprintf("$%d", i))
+        values = append(values, convertedValue)
+        i++
     }

     query := fmt.Sprintf(
@@ -131,8 +137,7 @@ func (s *PostgresSink) handleInsert(tx pgx.Tx, message *utils.CDCMessage) error {
         strings.Join(placeholders, ", "),
     )

-    // Add ON CONFLICT clause for PK/UNIQUE keys
-    if message.ReplicationKey.Type != utils.ReplicationKeyFull && len(message.ReplicationKey.Columns) > 0 {
+    if message.ReplicationKey != nil && message.ReplicationKey.Type != utils.ReplicationKeyFull && len(message.ReplicationKey.Columns) > 0 {
         query += fmt.Sprintf(" ON CONFLICT (%s) DO NOTHING",
             strings.Join(message.ReplicationKey.Columns, ", "))
     }
@@ -149,7 +154,6 @@ func getWhereConditions(message *utils.CDCMessage, useOldValues bool, startingIndex
     var conditions []string
     var values []interface{}
     valueIndex := startingIndex
-
     switch message.ReplicationKey.Type {
     case utils.ReplicationKeyFull:
         // For FULL, use all non-null values
@@ -230,26 +234,16 @@ func (s *PostgresSink) handleUpdate(tx pgx.Tx, message *utils.CDCMessage) error {
         return fmt.Errorf("invalid replication key configuration")
     }

-    setClauses := make([]string, 0, len(message.Columns))
-    setValues := make([]interface{}, 0, len(message.Columns))
+    setClauses := make([]string, 0, len(message.Data))
+    setValues := make([]interface{}, 0, len(message.Data))
     valueIndex := 1

-    for _, column := range message.Columns {
-        // Skip toasted columns
-        if message.IsColumnToasted(column.Name) {
-            continue
-        }
-
-        // Get the new value for the column
-        value, err := message.GetColumnValue(column.Name, false)
-        if err != nil {
-            return fmt.Errorf("failed to get column value: %w", err)
-        }
-
+    // Iterate over Data map directly
+    for colName, value := range message.Data {
         if value == nil {
-            setClauses = append(setClauses, fmt.Sprintf("%s = NULL", column.Name))
+            setClauses = append(setClauses, fmt.Sprintf("%s = NULL", colName))
         } else {
-            setClauses = append(setClauses, fmt.Sprintf("%s = $%d", column.Name, valueIndex))
+            setClauses = append(setClauses, fmt.Sprintf("%s = $%d", colName, valueIndex))
             setValues = append(setValues, value)
             valueIndex++
         }
@@ -412,7 +406,7 @@ func (s *PostgresSink) writeBatchInternal(ctx context.Context, messages []*utils.CDCMessage) error {
         }
     }

-    switch message.Type {
+    switch message.Operation {
     case utils.OperationInsert:
         operationErr = s.handleInsert(tx, message)
     case utils.OperationUpdate:
@@ -424,7 +418,7 @@ func (s *PostgresSink) writeBatchInternal(ctx context.Context, messages []*utils.CDCMessage) error {
         newTx, operationErr = s.handleDDL(tx, message)
         tx = newTx
     default:
-        operationErr = fmt.Errorf("unknown operation type: %s", message.Type)
+        operationErr = fmt.Errorf("unknown operation type: %s", message.Operation)
     }

     if operationErr != nil && isConnectionError(operationErr) {
@@ -441,7 +435,7 @@ func (s *PostgresSink) writeBatchInternal(ctx context.Context, messages []*utils.CDCMessage) error {
         }
         tx = nil
         return fmt.Errorf("failed to handle %s for table %s.%s: %v-%v",
-            message.Type,
+            message.Operation,
             message.Schema,
             message.Table,
             err,
@@ -486,3 +480,31 @@ func isConnectionError(err error) bool {
 func (s *PostgresSink) Close() error {
     return s.conn.Close(context.Background())
 }
+
+// convertValue converts a value from CDCMessage to PostgreSQL compatible format
+func (s *PostgresSink) convertValue(value interface{}, pgType string) (interface{}, error) {
+    if value == nil {
+        return nil, nil
+    }
+
+    // TODO: Rethink bytea handling (source or sink responsibility?)
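+    // Worked example for clarity: PostgreSQL (and therefore wal2json) renders bytea
+    // values as text in hex format, so the bytes {0xde, 0xad, 0xbe, 0xef} arrive as
+    // the string "\xdeadbeef". The branches below strip that prefix before
+    // hex-decoding back to raw bytes for the target column.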
+ switch pgType { + case "bytea": + switch v := value.(type) { + case string: + if strings.HasPrefix(v, "\\x") { + // Decode hex string to bytes + hexStr := v[2:] // Remove \x prefix + return hex.DecodeString(hexStr) + } + // If no \x prefix, assume it's already hex encoded + return hex.DecodeString(v) + case []byte: + return v, nil + default: + return nil, fmt.Errorf("unsupported bytea value type: %T", v) + } + default: + return value, nil + } +} diff --git a/pkg/sinks/shared.go b/pkg/sinks/shared.go deleted file mode 100644 index 7466777..0000000 --- a/pkg/sinks/shared.go +++ /dev/null @@ -1,39 +0,0 @@ -package sinks - -import "github.com/pgflo/pg_flo/pkg/utils" - -func buildDecodedMessage(message *utils.CDCMessage) (map[string]interface{}, error) { - decodedMessage := make(map[string]interface{}) - decodedMessage["Type"] = message.Type - decodedMessage["Schema"] = message.Schema - decodedMessage["Table"] = message.Table - decodedMessage["ReplicationKey"] = message.ReplicationKey - decodedMessage["LSN"] = message.LSN - decodedMessage["EmittedAt"] = message.EmittedAt - - if message.NewTuple != nil { - newTuple := make(map[string]interface{}) - for _, col := range message.Columns { - value, err := message.GetColumnValue(col.Name, false) - if err != nil { - return nil, err - } - newTuple[col.Name] = value - } - decodedMessage["NewTuple"] = newTuple - } - - if message.OldTuple != nil { - oldTuple := make(map[string]interface{}) - for _, col := range message.Columns { - value, err := message.GetColumnValue(col.Name, true) - if err != nil { - return nil, err - } - oldTuple[col.Name] = value - } - decodedMessage["OldTuple"] = oldTuple - } - - return decodedMessage, nil -} diff --git a/pkg/sinks/stdout.go b/pkg/sinks/stdout.go index 57d9114..cb6da98 100644 --- a/pkg/sinks/stdout.go +++ b/pkg/sinks/stdout.go @@ -18,12 +18,7 @@ func NewStdoutSink() (*StdoutSink, error) { // WriteBatch writes a batch of data to standard output func (s *StdoutSink) WriteBatch(messages []*utils.CDCMessage) error { for _, message := range messages { - decodedMessage, err := buildDecodedMessage(message) - if err != nil { - return fmt.Errorf("failed to build decoded message: %v", err) - } - - jsonData, err := json.Marshal(decodedMessage) + jsonData, err := json.Marshal(message) if err != nil { return fmt.Errorf("failed to marshal data to JSON: %v", err) } diff --git a/pkg/sinks/webhooks.go b/pkg/sinks/webhooks.go index 7722572..1ab8060 100644 --- a/pkg/sinks/webhooks.go +++ b/pkg/sinks/webhooks.go @@ -38,12 +38,7 @@ func NewWebhookSink(webhookURL string) (*WebhookSink, error) { // WriteBatch sends a batch of data to the webhook endpoint func (s *WebhookSink) WriteBatch(messages []*utils.CDCMessage) error { for _, message := range messages { - decodedMessage, err := buildDecodedMessage(message) - if err != nil { - return fmt.Errorf("failed to build decoded message: %v", err) - } - - jsonData, err := json.Marshal(decodedMessage) + jsonData, err := json.Marshal(message) if err != nil { return fmt.Errorf("failed to marshal data to JSON: %v", err) } diff --git a/pkg/utils/cdc_encoding.go b/pkg/utils/cdc_encoding.go index 528aeb1..1a6995d 100644 --- a/pkg/utils/cdc_encoding.go +++ b/pkg/utils/cdc_encoding.go @@ -12,60 +12,72 @@ import ( ) // ConvertToPgCompatibleOutput converts a Go value to its PostgreSQL output format. 
-func ConvertToPgCompatibleOutput(value interface{}, oid uint32) ([]byte, error) { +func ConvertToPgCompatibleOutput(value interface{}, oid uint32) (string, error) { if value == nil { - return nil, nil + return "", nil } switch oid { case pgtype.BoolOID: - return strconv.AppendBool(nil, value.(bool)), nil + return strconv.FormatBool(value.(bool)), nil case pgtype.Int2OID, pgtype.Int4OID, pgtype.Int8OID: switch v := value.(type) { case int: - return []byte(strconv.FormatInt(int64(v), 10)), nil + return strconv.FormatInt(int64(v), 10), nil case int32: - return []byte(strconv.FormatInt(int64(v), 10)), nil + return strconv.FormatInt(int64(v), 10), nil case int64: - return []byte(strconv.FormatInt(v, 10)), nil + return strconv.FormatInt(v, 10), nil default: - return []byte(fmt.Sprintf("%d", value)), nil + return fmt.Sprintf("%d", value), nil } case pgtype.Float4OID, pgtype.Float8OID: - return []byte(strconv.FormatFloat(value.(float64), 'f', -1, 64)), nil + return strconv.FormatFloat(value.(float64), 'f', -1, 64), nil case pgtype.NumericOID: - return []byte(fmt.Sprintf("%v", value)), nil + return fmt.Sprintf("%v", value), nil case pgtype.TextOID, pgtype.VarcharOID: - return []byte(value.(string)), nil + return value.(string), nil case pgtype.ByteaOID: if byteaData, ok := value.([]byte); ok { - return byteaData, nil + return fmt.Sprintf("\\x%x", byteaData), nil } - return nil, fmt.Errorf("invalid bytea data type") + return "", fmt.Errorf("invalid bytea data type") case pgtype.TimestampOID, pgtype.TimestamptzOID: - return []byte(value.(time.Time).Format(time.RFC3339Nano)), nil + return value.(time.Time).Format(time.RFC3339Nano), nil case pgtype.DateOID: - return []byte(value.(time.Time).Format("2006-01-02")), nil + return value.(time.Time).Format("2006-01-02"), nil case pgtype.JSONOID: switch v := value.(type) { case string: - return []byte(v), nil - case []byte: return v, nil + case []byte: + return string(v), nil default: - return nil, fmt.Errorf("unsupported type for JSON data: %T", value) + return "", fmt.Errorf("unsupported type for JSON data: %T", value) } case pgtype.JSONBOID: if jsonBytes, ok := value.([]byte); ok { - return jsonBytes, nil + return string(jsonBytes), nil } - return json.Marshal(value) + jsonBytes, err := json.Marshal(value) + if err != nil { + return "", err + } + return string(jsonBytes), nil case pgtype.TextArrayOID, pgtype.VarcharArrayOID, pgtype.Int2ArrayOID, pgtype.Int4ArrayOID, pgtype.Int8ArrayOID, pgtype.Float4ArrayOID, pgtype.Float8ArrayOID, pgtype.BoolArrayOID: - return EncodeArray(value) + arrayBytes, err := EncodeArray(value) + if err != nil { + return "", err + } + return string(arrayBytes), nil default: - return []byte(fmt.Sprintf("%v", value)), nil + jsonBytes, err := json.Marshal(value) + if err != nil { + return "", fmt.Errorf("failed to marshal value to JSON: %w", err) + } + return string(jsonBytes), nil } } diff --git a/pkg/utils/cdc_message.go b/pkg/utils/cdc_message.go index 0ef3aad..0bae268 100644 --- a/pkg/utils/cdc_message.go +++ b/pkg/utils/cdc_message.go @@ -1,349 +1,98 @@ package utils import ( - "bytes" - "encoding/gob" - "encoding/hex" "encoding/json" "fmt" - "strconv" - "strings" "time" - - "github.com/jackc/pglogrepl" - "github.com/jackc/pgx/v5/pgtype" ) -// init registers types with the gob package for encoding/decoding -func init() { - gob.Register(json.RawMessage{}) - gob.Register(time.Time{}) - gob.Register(map[string]interface{}{}) - gob.Register(pglogrepl.RelationMessageColumn{}) - gob.Register(pglogrepl.LSN(0)) - - 
gob.Register(CDCMessage{}) - gob.Register(pglogrepl.TupleData{}) - gob.Register(pglogrepl.TupleDataColumn{}) +// Column represents a database column +type Column struct { + Name string `json:"name"` + DataType uint32 `json:"-"` } -// CDCMessage represents a full message for Change Data Capture +// CDCMessage represents a change data capture message type CDCMessage struct { - Type OperationType - Schema string - Table string - Columns []*pglogrepl.RelationMessageColumn - NewTuple *pglogrepl.TupleData - OldTuple *pglogrepl.TupleData - ReplicationKey ReplicationKey - LSN string - EmittedAt time.Time - ToastedColumns map[string]bool + Operation OperationType `json:"operation"` + Schema string `json:"schema"` + Table string `json:"table"` + Columns []Column `json:"columns"` + Data map[string]interface{} `json:"data"` + OldData map[string]interface{} `json:"old_data,omitempty"` + PrimaryKey map[string]interface{} `json:"primary_key,omitempty"` + LSN string `json:"lsn"` + EmittedAt time.Time `json:"emitted_at"` + ReplicationKey *ReplicationKey `json:"replication_key,omitempty"` + ColumnTypes map[string]string `json:"column_types,omitempty"` } -// MarshalBinary implements the encoding.BinaryMarshaler interface -func (m CDCMessage) MarshalBinary() ([]byte, error) { - return EncodeCDCMessage(m) +// Wal2JsonMessage represents the raw message from wal2json +type Wal2JsonMessage struct { + Action string `json:"action"` + Schema string `json:"schema"` + Table string `json:"table"` + Columns []wal2JsonCol `json:"columns"` + Identity []wal2JsonCol `json:"identity,omitempty"` + PK []wal2JsonCol `json:"pk"` + Timestamp string `json:"timestamp"` } -// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface -func (m *CDCMessage) UnmarshalBinary(data []byte) error { - decodedMessage, err := DecodeCDCMessage(data) - if err != nil { - return err - } - *m = *decodedMessage - return nil +// wal2JsonCol is used only for parsing the raw wal2json output +type wal2JsonCol struct { + Name string `json:"name"` + Type string `json:"type"` + Value interface{} `json:"value"` + Position int `json:"position,omitempty"` } -func (m *CDCMessage) GetColumnIndex(columnName string) int { - for i, col := range m.Columns { - if col.Name == columnName { - return i - } - } - return -1 -} - -// GetColumnValue gets a column value, optionally using old values for DELETE/UPDATE -func (m *CDCMessage) GetColumnValue(columnName string, useOldValues bool) (interface{}, error) { - colIndex := m.GetColumnIndex(columnName) - if colIndex == -1 { - return nil, fmt.Errorf("column %s not found", columnName) - } - - var data []byte - if useOldValues && m.OldTuple != nil { - data = m.OldTuple.Columns[colIndex].Data - } else if m.NewTuple != nil { - data = m.NewTuple.Columns[colIndex].Data - } else { - return nil, fmt.Errorf("no data available for column %s", columnName) - } - - return DecodeValue(data, m.Columns[colIndex].DataType) +// Wal2JsonChange represents a single change from wal2json +type Wal2JsonChange struct { + Kind string `json:"kind"` + Schema string `json:"schema"` + Table string `json:"table"` + ColumnNames []string `json:"columnnames"` + ColumnTypes []string `json:"columntypes"` + ColumnValues []interface{} `json:"columnvalues"` + OldKeys map[string]interface{} `json:"oldkeys,omitempty"` } -// SetColumnValue sets the value of a column, respecting its type -func (m *CDCMessage) SetColumnValue(columnName string, value interface{}) error { - colIndex := m.GetColumnIndex(columnName) - if colIndex == -1 { - return 
fmt.Errorf("column %s not found", columnName) +func (m *CDCMessage) GetColumnType(columnName string) string { + if m.ColumnTypes != nil { + return m.ColumnTypes[columnName] } - - column := m.Columns[colIndex] - encodedValue, err := EncodeValue(value, column.DataType) - if err != nil { - return err - } - - if m.Type == OperationDelete { - m.OldTuple.Columns[colIndex] = &pglogrepl.TupleDataColumn{Data: encodedValue} - } else { - m.NewTuple.Columns[colIndex] = &pglogrepl.TupleDataColumn{Data: encodedValue} - } - - return nil + return "" } -// EncodeCDCMessage encodes a CDCMessage into a byte slice -func EncodeCDCMessage(m CDCMessage) ([]byte, error) { - var buf bytes.Buffer - enc := gob.NewEncoder(&buf) - - if err := enc.Encode(m.Type); err != nil { - return nil, err - } - if err := enc.Encode(m.Schema); err != nil { - return nil, err - } - if err := enc.Encode(m.Table); err != nil { - return nil, err - } - if err := enc.Encode(m.Columns); err != nil { - return nil, err - } - - if err := enc.Encode(m.NewTuple != nil); err != nil { - return nil, err - } - if m.NewTuple != nil { - if err := enc.Encode(m.NewTuple); err != nil { - return nil, err - } - } - - if err := enc.Encode(m.OldTuple != nil); err != nil { - return nil, err - } - - if m.OldTuple != nil { - if err := enc.Encode(m.OldTuple); err != nil { - return nil, err - } - } - - if err := enc.Encode(m.ReplicationKey); err != nil { - return nil, err - } - - if err := enc.Encode(m.LSN); err != nil { - return nil, err - } - - if err := enc.Encode(m.EmittedAt); err != nil { - return nil, err - } - - if err := enc.Encode(m.ToastedColumns); err != nil { - return nil, err - } - - return buf.Bytes(), nil +// Binary marshaling for NATS compatibility +func (m CDCMessage) MarshalBinary() ([]byte, error) { + return json.Marshal(m) } -// DecodeCDCMessage decodes a byte slice into a CDCMessage -func DecodeCDCMessage(data []byte) (*CDCMessage, error) { - buf := bytes.NewBuffer(data) - dec := gob.NewDecoder(buf) - m := &CDCMessage{} - - if err := dec.Decode(&m.Type); err != nil { - return nil, err - } - if err := dec.Decode(&m.Schema); err != nil { - return nil, err - } - if err := dec.Decode(&m.Table); err != nil { - return nil, err - } - if err := dec.Decode(&m.Columns); err != nil { - return nil, err - } - - var newTupleExists bool - if err := dec.Decode(&newTupleExists); err != nil { - return nil, err - } - if newTupleExists { - m.NewTuple = &pglogrepl.TupleData{} - if err := dec.Decode(m.NewTuple); err != nil { - return nil, err - } - } - - var oldTupleExists bool - if err := dec.Decode(&oldTupleExists); err != nil { - return nil, err - } - if oldTupleExists { - m.OldTuple = &pglogrepl.TupleData{} - if err := dec.Decode(m.OldTuple); err != nil { - return nil, err - } - } - - if err := dec.Decode(&m.ReplicationKey); err != nil { - return nil, err - } - - if err := dec.Decode(&m.LSN); err != nil { - return nil, err - } - - if err := dec.Decode(&m.EmittedAt); err != nil { - return nil, err - } - - if err := dec.Decode(&m.ToastedColumns); err != nil { - return nil, err - } - - return m, nil +func (m *CDCMessage) UnmarshalBinary(data []byte) error { + return json.Unmarshal(data, m) } -// DecodeValue decodes a byte slice into a Go value based on the PostgreSQL data type -func DecodeValue(data []byte, dataType uint32) (interface{}, error) { - if data == nil { - return nil, nil - } - strData := string(data) - switch dataType { - case pgtype.BoolOID: - return strconv.ParseBool(string(data)) - case pgtype.Int2OID, pgtype.Int4OID, pgtype.Int8OID: - return 
-// DecodeValue decodes a byte slice into a Go value based on the PostgreSQL data type
-func DecodeValue(data []byte, dataType uint32) (interface{}, error) {
-	if data == nil {
-		return nil, nil
-	}
-	strData := string(data)
-	switch dataType {
-	case pgtype.BoolOID:
-		return strconv.ParseBool(string(data))
-	case pgtype.Int2OID, pgtype.Int4OID, pgtype.Int8OID:
-		return strconv.ParseInt(string(data), 10, 64)
-	case pgtype.Float4OID, pgtype.Float8OID:
-		if strings.EqualFold(strData, "NULL") {
-			return nil, nil
-		}
-		return strconv.ParseFloat(strData, 64)
-	case pgtype.NumericOID:
-		return string(data), nil
-	case pgtype.TextOID, pgtype.VarcharOID:
-		return string(data), nil
-	case pgtype.ByteaOID:
-		if strings.HasPrefix(strData, "\\x") {
-			hexString := strData[2:]
-			byteData, err := hex.DecodeString(hexString)
-			if err != nil {
-				return nil, fmt.Errorf("failed to decode bytea hex string: %v", err)
-			}
-			return byteData, nil
+// GetColumnValue retrieves a column value from either current or old data
+func (m *CDCMessage) GetColumnValue(columnName string, useOldValues bool) (interface{}, error) {
+	if useOldValues {
+		if value, ok := m.OldData[columnName]; ok {
+			return value, nil
 		}
-		return data, nil
-	case pgtype.TimestampOID, pgtype.TimestamptzOID:
-		return ParseTimestamp(string(data))
-	case pgtype.DateOID:
-		return time.Parse("2006-01-02", string(data))
-	case pgtype.JSONOID:
-		return string(data), nil
-	case pgtype.JSONBOID:
-		var result interface{}
-		err := json.Unmarshal(data, &result)
-		return result, err
-	case pgtype.TextArrayOID, pgtype.VarcharArrayOID:
-		return DecodeTextArray(data)
-	case pgtype.Int2ArrayOID, pgtype.Int4ArrayOID, pgtype.Int8ArrayOID, pgtype.Float4ArrayOID, pgtype.Float8ArrayOID, pgtype.BoolArrayOID:
-		return DecodeArray(data, dataType)
-	default:
-		return string(data), nil
+		return nil, fmt.Errorf("column %s not found in old data", columnName)
 	}
-}

-// DecodeTextArray decodes a PostgreSQL text array into a []string
-func DecodeTextArray(data []byte) ([]string, error) {
-	if len(data) < 2 || data[0] != '{' || data[len(data)-1] != '}' {
-		return nil, fmt.Errorf("invalid array format")
+	if value, ok := m.Data[columnName]; ok {
+		return value, nil
 	}
-	elements := strings.Split(string(data[1:len(data)-1]), ",")
-	for i, elem := range elements {
-		elements[i] = strings.Trim(elem, "\"")
-	}
-	return elements, nil
+	return nil, fmt.Errorf("column %s not found in data", columnName)
 }
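
The Wal2JsonChange struct maps the wal2json payload, where column names, types, and values arrive as parallel arrays. A hedged sketch of how a consumer might zip those arrays into the map shape that GetColumnValue reads; the zipChange helper is hypothetical, not part of this diff:

package main

import (
	"encoding/json"
	"fmt"
)

// Local mirror of the Wal2JsonChange struct from the diff.
type Wal2JsonChange struct {
	Kind         string        `json:"kind"`
	Schema       string        `json:"schema"`
	Table        string        `json:"table"`
	ColumnNames  []string      `json:"columnnames"`
	ColumnTypes  []string      `json:"columntypes"`
	ColumnValues []interface{} `json:"columnvalues"`
}

// zipChange is a hypothetical helper: pair the parallel name/value
// arrays into the map[string]interface{} shape used by CDCMessage.Data.
func zipChange(c Wal2JsonChange) map[string]interface{} {
	data := make(map[string]interface{}, len(c.ColumnNames))
	for i, name := range c.ColumnNames {
		if i < len(c.ColumnValues) {
			data[name] = c.ColumnValues[i]
		}
	}
	return data
}

func main() {
	raw := `{"kind":"insert","schema":"public","table":"users",
	         "columnnames":["id","data"],"columntypes":["integer","text"],
	         "columnvalues":[1,"Data 1"]}`

	var change Wal2JsonChange
	if err := json.Unmarshal([]byte(raw), &change); err != nil {
		panic(err)
	}
	fmt.Println(zipChange(change)) // map[data:Data 1 id:1]
}
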
-// DecodeArray decodes a PostgreSQL array into a slice of the appropriate type
-func DecodeArray(data []byte, dataType uint32) (interface{}, error) {
-	if len(data) < 2 || data[0] != '{' || data[len(data)-1] != '}' {
-		return nil, fmt.Errorf("invalid array format")
-	}
-	elements := strings.Split(string(data[1:len(data)-1]), ",")
-
-	switch dataType {
-	case pgtype.Int2ArrayOID, pgtype.Int4ArrayOID, pgtype.Int8ArrayOID:
-		result := make([]interface{}, len(elements))
-		for i, elem := range elements {
-			if elem == "NULL" {
-				result[i] = nil
-				continue
-			}
-			val, err := strconv.ParseInt(elem, 10, 64)
-			if err != nil {
-				return nil, err
-			}
-			result[i] = val
-		}
-		return result, nil
-	case pgtype.Float4ArrayOID, pgtype.Float8ArrayOID:
-		result := make([]interface{}, len(elements))
-		for i, elem := range elements {
-			if elem == "NULL" {
-				result[i] = nil
-				continue
-			}
-			val, err := strconv.ParseFloat(elem, 64)
-			if err != nil {
-				return nil, err
-			}
-			result[i] = val
-		}
-		return result, nil
-	case pgtype.BoolArrayOID:
-		result := make([]interface{}, len(elements))
-		for i, elem := range elements {
-			if elem == "NULL" {
-				result[i] = nil
-				continue
-			}
-			val, err := strconv.ParseBool(elem)
-			if err != nil {
-				return nil, err
-			}
-			result[i] = val
-		}
-		return result, nil
-	default:
-		return elements, nil
+// SetColumnValue sets a column value in the Data map
+func (m *CDCMessage) SetColumnValue(columnName string, value interface{}) error {
+	if m.Data == nil {
+		m.Data = make(map[string]interface{})
 	}
-}
-
-// EncodeValue encodes a Go value into a byte slice based on the PostgreSQL data type
-func EncodeValue(value interface{}, dataType uint32) ([]byte, error) {
-	return ConvertToPgCompatibleOutput(value, dataType)
-}
-
-// IsColumnToasted checks if a column was TOASTed
-func (m *CDCMessage) IsColumnToasted(columnName string) bool {
-	return m.ToastedColumns[columnName]
+	m.Data[columnName] = value
+	return nil
 }
diff --git a/pkg/utils/pg_types.go b/pkg/utils/pg_types.go
new file mode 100644
index 0000000..8576b52
--- /dev/null
+++ b/pkg/utils/pg_types.go
@@ -0,0 +1,14 @@
+package utils
+
+import "github.com/jackc/pgx/v5/pgtype"
+
+var typeMap = pgtype.NewMap()
+
+// GetOIDFromTypeName converts a PostgreSQL type name to its OID
+func GetOIDFromTypeName(typeName string) uint32 {
+	dt, ok := typeMap.TypeForName(typeName)
+	if !ok {
+		return pgtype.TextOID
+	}
+	return dt.OID
+}
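
GetOIDFromTypeName leans on pgx's built-in type map, which knows the standard catalog types by name; unknown or custom type names deliberately fall back to text's OID. A quick usage sketch, assuming github.com/jackc/pgx/v5 is on the module path:

package main

import (
	"fmt"

	"github.com/jackc/pgx/v5/pgtype"
)

var typeMap = pgtype.NewMap()

// oidFor repeats the logic of GetOIDFromTypeName from the diff.
func oidFor(typeName string) uint32 {
	dt, ok := typeMap.TypeForName(typeName)
	if !ok {
		return pgtype.TextOID // fallback for custom/unknown types
	}
	return dt.OID
}

func main() {
	fmt.Println(oidFor("int4"))         // 23
	fmt.Println(oidFor("text"))         // 25
	fmt.Println(oidFor("my_enum_type")) // not in the default map, falls back to 25 (TextOID)
}
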