From 308b9577582cd92e5e0a2570b146753d37a64c9e Mon Sep 17 00:00:00 2001
From: Shayon Mukherjee
Date: Fri, 15 Nov 2024 13:43:48 -0500
Subject: [PATCH] Support syncing CDC changes by unique columns, falling back
 to replica identity when present

This way users can still sync a table that does not have a primary key.
Unique columns, composite primary keys, and replica identity (FULL) are used
as fallbacks.

---
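[Editorial note] A quick way to check which replication key pg_flo can use
for a table under this change; a simplified sketch of the catalog lookup,
not the exact query the replicator runs:

    SELECT c.relname,
           c.relreplident,      -- 'd' = default (PK), 'i' = index, 'f' = full
           i.indexrelid::regclass AS index_name,
           i.indisprimary,
           i.indisunique
    FROM pg_class c
    LEFT JOIN pg_index i
      ON i.indrelid = c.oid
     AND (i.indisprimary OR i.indisunique)
    WHERE c.relname = 'orders';  -- example table name

A table with no qualifying index and relreplident <> 'f' is skipped with a
warning (see validateTableReplicationKey below).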
 .github/workflows/integration.yml            |  38 ++
 README.md                                    |  32 ++
 internal/scripts/e2e_postgres_uniquness.rb   | 497 ++++++++++++++++++
 internal/scripts/e2e_test_local.sh           |   4 +-
 pkg/replicator/base_replicator.go            | 179 +++++--
 pkg/replicator/copy_and_stream_replicator.go |   2 +-
 pkg/replicator/ddl_replicator.go             |   2 +-
 pkg/replicator/tests/base_replicator_test.go |  67 ++-
 .../tests/copy_and_stream_replicator_test.go |   4 +-
 pkg/replicator/tests/json_encoder_test.go    |   4 +-
 pkg/routing/router.go                        |  19 +-
 pkg/routing/tests/routing_test.go            |  18 +-
 pkg/rules/engine.go                          |   2 +-
 pkg/rules/rules.go                           |  19 +-
 pkg/rules/tests/engine_test.go               | 125 +++--
 pkg/rules/tests/rules_test.go                |  98 ++--
 pkg/sinks/postgres.go                        | 202 +++++--
 pkg/utils/cdc_message.go                     | 109 +++-
 pkg/utils/shared.go                          |  19 +
 pkg/utils/shared_types.go                    |  16 +
 20 files changed, 1190 insertions(+), 266 deletions(-)
 create mode 100644 internal/scripts/e2e_postgres_uniquness.rb

diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index 9dea148..ec9fff6 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -250,3 +250,41 @@ jobs:
           sleep 10
           ruby ./internal/scripts/e2e_resume_test.rb
           docker-compose -f internal/docker-compose.yml down -v
+
+  postgres_uniqueness_test:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Go
+        uses: actions/setup-go@v4
+        with:
+          go-version: "1.21"
+      - name: Build pg_flo
+        run: make build
+      - name: Install dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y postgresql-client jq ruby ruby-dev libpq-dev build-essential
+          sudo gem install pg
+      - name: Set up Docker Compose
+        run: |
+          sudo curl -L "https://github.com/docker/compose/releases/download/v2.17.2/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
+          sudo chmod +x /usr/local/bin/docker-compose
+      - name: Run test
+        env:
+          PG_HOST: localhost
+          PG_PORT: 5433
+          PG_USER: myuser
+          PG_PASSWORD: mypassword!@#%1234
+          PG_DB: mydb
+          TARGET_PG_HOST: localhost
+          TARGET_PG_PORT: 5434
+          TARGET_PG_USER: targetuser
+          TARGET_PG_PASSWORD: targetpassword!@#1234
+          TARGET_PG_DB: targetdb
+        working-directory: .
+        run: |
+          docker-compose -f internal/docker-compose.yml up -d
+          sleep 10
+          ruby ./internal/scripts/e2e_postgres_uniquness.rb
+          docker-compose -f internal/docker-compose.yml down -v
diff --git a/README.md b/README.md
index c9a7f78..83c968d 100644
--- a/README.md
+++ b/README.md
@@ -232,6 +232,38 @@ pg_flo worker postgres --group inventory
 - NATS message size: 8MB (configurable)
 - One worker per group recommended
 - PostgreSQL logical replication prerequisites required
+- Tables must have one of the following for replication:
+  - Primary key
+  - Unique constraint with `NOT NULL` columns
+  - `REPLICA IDENTITY FULL` set
+
+Example table configurations:
+
+```sql
+-- Using a primary key (recommended)
+CREATE TABLE users (
+  id SERIAL PRIMARY KEY,
+  email TEXT,
+  name TEXT
+);
+
+-- Using a unique constraint
+CREATE TABLE orders (
+  order_id TEXT NOT NULL,
+  customer_id TEXT NOT NULL,
+  data JSONB,
+  CONSTRAINT orders_unique UNIQUE (order_id, customer_id)
+);
+ALTER TABLE orders REPLICA IDENTITY USING INDEX orders_unique;
+
+-- Using all columns (higher performance overhead)
+CREATE TABLE audit_logs (
+  id SERIAL,
+  action TEXT,
+  data JSONB
+);
+ALTER TABLE audit_logs REPLICA IDENTITY FULL;
+```

 ## Development
diff --git a/internal/scripts/e2e_postgres_uniquness.rb b/internal/scripts/e2e_postgres_uniquness.rb
new file mode 100644
index 0000000..f32d726
--- /dev/null
+++ b/internal/scripts/e2e_postgres_uniquness.rb
@@ -0,0 +1,497 @@
+#!/usr/bin/env ruby
+
+require 'pg'
+require 'logger'
+require 'securerandom'
+require 'json'
+
+class PostgresUniquenessTest
+  # Database configuration
+  PG_HOST = 'localhost'
+  PG_PORT = 5433
+  PG_USER = 'myuser'
+  PG_PASSWORD = 'mypassword!@#%1234'
+  PG_DB = 'mydb'
+
+  TARGET_PG_HOST = 'localhost'
+  TARGET_PG_PORT = 5434
+  TARGET_PG_USER = 'targetuser'
+  TARGET_PG_PASSWORD = 'targetpassword!@#1234'
+  TARGET_PG_DB = 'targetdb'
+
+  # NATS configuration
+  NATS_URL = 'nats://localhost:4222'
+
+  # Paths
+  PG_FLO_BIN = './bin/pg_flo'
+  PG_FLO_LOG = '/tmp/pg_flo.log'
+  PG_FLO_WORKER_LOG = '/tmp/pg_flo_worker.log'
+
+  def initialize
+    @logger = Logger.new(STDOUT)
+    @logger.formatter = proc { |_, _, _, msg| "#{msg}\n" }
+    sleep 5
+    connect_to_databases
+  end
+
+  def connect_to_databases
+    @source_db = PG.connect(
+      host: PG_HOST,
+      port: PG_PORT,
+      dbname: PG_DB,
+      user: PG_USER,
+      password: PG_PASSWORD
+    )
+
+    @target_db = PG.connect(
+      host: TARGET_PG_HOST,
+      port: TARGET_PG_PORT,
+      dbname: TARGET_PG_DB,
+      user: TARGET_PG_USER,
+      password: TARGET_PG_PASSWORD
+    )
+  end
+
+  def create_test_tables
+    @logger.info "Creating test tables..."
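+    # Editorial note (added): the tables below are meant to cover one
+    # replication-key shape each: a plain primary key, a UNIQUE NOT NULL
+    # column, a composite primary key, a multi-column UNIQUE constraint,
+    # and a table forced to REPLICA IDENTITY FULL, so that every key path
+    # introduced in this patch is exercised.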
+ + # Create extensions and types + @source_db.exec(<<-SQL) + CREATE EXTENSION IF NOT EXISTS hstore; + DROP TYPE IF EXISTS mood CASCADE; + CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy'); + SQL + + # Drop existing tables + @source_db.exec(<<-SQL) + DROP TABLE IF EXISTS public.pk_test CASCADE; + DROP TABLE IF EXISTS public.unique_test CASCADE; + DROP TABLE IF EXISTS public.composite_pk_test CASCADE; + DROP TABLE IF EXISTS public.multi_unique_test CASCADE; + DROP TABLE IF EXISTS public.mixed_constraints_test CASCADE; + SQL + + # Create test tables + @source_db.exec(<<-SQL) + CREATE TABLE public.pk_test ( + id serial PRIMARY KEY, + text_data text, + json_data jsonb, + array_data text[], + numeric_data numeric(10,2), + timestamp_data timestamp with time zone, + enum_data mood, + nullable_column text + ); + + CREATE TABLE public.unique_test ( + id serial, + unique_col uuid UNIQUE NOT NULL, + binary_data bytea, + interval_data interval, + data jsonb + ); + + CREATE TABLE public.composite_pk_test ( + id1 integer, + id2 uuid, + json_data jsonb, + array_int_data integer[], + data text, + PRIMARY KEY (id1, id2) + ); + + CREATE TABLE public.multi_unique_test ( + id serial, + unique_col1 timestamp with time zone NOT NULL, + unique_col2 uuid NOT NULL, + array_data text[], + hstore_data hstore, + data jsonb, + CONSTRAINT multi_unique_test_unique_cols UNIQUE (unique_col1, unique_col2) + ); + + CREATE TABLE public.mixed_constraints_test ( + id serial PRIMARY KEY, + unique_col uuid UNIQUE, + non_unique_col jsonb, + array_data integer[], + data text + ); + + ALTER TABLE public.unique_test REPLICA IDENTITY USING INDEX unique_test_unique_col_key; + ALTER TABLE public.multi_unique_test REPLICA IDENTITY USING INDEX multi_unique_test_unique_cols; + ALTER TABLE public.mixed_constraints_test REPLICA IDENTITY FULL; + SQL + + @logger.info "Test tables created successfully" + end + + def start_pg_flo + @logger.info "Starting pg_flo..." + + create_config_files + + replicator_cmd = "#{PG_FLO_BIN} replicator --config /tmp/pg_flo_replicator.yml" + worker_cmd = "#{PG_FLO_BIN} worker postgres --config /tmp/pg_flo_worker.yml" + + @replicator_pid = spawn(replicator_cmd, out: PG_FLO_LOG, err: PG_FLO_LOG) + sleep 2 + @worker_pid = spawn(worker_cmd, out: PG_FLO_WORKER_LOG, err: PG_FLO_WORKER_LOG) + sleep 2 + end + + def create_config_files + File.write('/tmp/pg_flo_replicator.yml', <<-YAML) +host: #{PG_HOST} +port: #{PG_PORT} +dbname: #{PG_DB} +user: #{PG_USER} +password: #{PG_PASSWORD} +group: group_unique +tables: + - pk_test + - unique_test + - composite_pk_test + - multi_unique_test + - mixed_constraints_test +schema: public +nats-url: #{NATS_URL} + YAML + + File.write('/tmp/pg_flo_worker.yml', <<-YAML) +group: group_unique +nats-url: #{NATS_URL} +source-host: #{PG_HOST} +source-port: #{PG_PORT} +source-dbname: #{PG_DB} +source-user: #{PG_USER} +source-password: #{PG_PASSWORD} +target-host: #{TARGET_PG_HOST} +target-port: #{TARGET_PG_PORT} +target-dbname: #{TARGET_PG_DB} +target-user: #{TARGET_PG_USER} +target-password: #{TARGET_PG_PASSWORD} +target-sync-schema: true + YAML + end + + def test_pk_operations + @logger.info "Testing operations on pk_test..." 
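+    # Editorial note (added): pk_test covers the simplest path, where the
+    # replicator keys every change off the primary key. On the target this
+    # is assumed to become INSERT ... ON CONFLICT (id) DO NOTHING and
+    # UPDATE/DELETE ... WHERE id = $1 (see pkg/sinks/postgres.go below).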
+
+    # Test INSERT
+    @source_db.exec_params(
+      "INSERT INTO public.pk_test (text_data, json_data, array_data, numeric_data, timestamp_data, enum_data, nullable_column)
+       VALUES ($1, $2, $3, $4, $5, $6, $7)",
+      ['test1', '{"key": "value"}', '{a,b,c}', 123.45, '2024-03-20 10:00:00+00', 'happy', 'value1']
+    )
+
+    sleep 1
+    success = verify_table_data("pk_test", "id = 1",
+      "1 | test1 | {\"key\": \"value\"} | a,b,c | 123.45 | 2024-03-20 10:00:00+00 | happy | value1")
+
+    @source_db.exec("DELETE FROM public.pk_test WHERE id = 1")
+    sleep 1
+
+    success
+  end
+
+  def test_unique_operations
+    @logger.info "Testing operations on unique_test..."
+    uuid = SecureRandom.uuid
+
+    # Test INSERT
+    @source_db.exec_params(
+      "INSERT INTO public.unique_test (unique_col, binary_data, interval_data, data)
+       VALUES ($1, $2, $3, $4)",
+      [uuid, '\xdeadbeef', '1 year 2 months 3 days', '{"value": "test"}']
+    )
+
+    sleep 1
+    success = verify_table_data("unique_test", "unique_col = '#{uuid}'",
+      "1 | #{uuid} | \\xdeadbeef | 1 year 2 mons 3 days | {\"value\": \"test\"}")
+
+    # Test UPDATE
+    @source_db.exec_params(
+      "UPDATE public.unique_test SET data = $1 WHERE unique_col = $2",
+      ['{"value": "updated_data"}', uuid]
+    )
+
+    sleep 1
+    success = success && verify_table_data("unique_test", "unique_col = '#{uuid}'",
+      "1 | #{uuid} | \\xdeadbeef | 1 year 2 mons 3 days | {\"value\": \"updated_data\"}")
+
+    # Test DELETE
+    @source_db.exec_params("DELETE FROM public.unique_test WHERE unique_col = $1", [uuid])
+
+    success
+  end
+
+  def test_composite_pk_operations
+    @logger.info "Testing operations on composite_pk_test..."
+    uuid = SecureRandom.uuid
+
+    # Test INSERT
+    @source_db.exec_params(
+      "INSERT INTO public.composite_pk_test (id1, id2, json_data, array_int_data, data)
+       VALUES ($1, $2, $3, $4, $5)",
+      [1, uuid, '{"key": "value"}', '{1,2,3}', 'test_composite']
+    )
+
+    sleep 1
+    success = verify_table_data("composite_pk_test", "id1 = 1 AND id2 = '#{uuid}'",
+      "1 | #{uuid} | {\"key\": \"value\"} | 1,2,3 | test_composite")
+
+    # Test UPDATE
+    @source_db.exec_params(
+      "UPDATE public.composite_pk_test SET data = $1 WHERE id1 = $2 AND id2 = $3",
+      ['updated_composite', 1, uuid]
+    )
+
+    sleep 1
+    success = success && verify_table_data("composite_pk_test", "id1 = 1 AND id2 = '#{uuid}'",
+      "1 | #{uuid} | {\"key\": \"value\"} | 1,2,3 | updated_composite")
+
+    # Test DELETE
+    @source_db.exec_params(
+      "DELETE FROM public.composite_pk_test WHERE id1 = $1 AND id2 = $2",
+      [1, uuid]
+    )
+
+    success
+  end
+
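+  # Editorial note (added): multi_unique_test relies on REPLICA IDENTITY
+  # USING INDEX, so PostgreSQL records the old values of the indexed columns
+  # (unique_col1, unique_col2) in the old tuple of UPDATE and DELETE events;
+  # that is what lets the worker match rows on the target without a primary
+  # key.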
+  def test_multi_unique_operations
+    @logger.info "Testing operations on multi_unique_test..."
+    uuid = SecureRandom.uuid
+    timestamp = '2024-03-20 10:00:00+00'
+
+    # Test INSERT
+    @source_db.exec_params(
+      "INSERT INTO public.multi_unique_test (unique_col1, unique_col2, array_data, hstore_data, data)
+       VALUES ($1, $2, $3, $4, $5)",
+      [timestamp, uuid, '{test1,test2}', 'key1=>value1,key2=>value2', '{"test": "multi"}']
+    )
+
+    sleep 1
+    success = verify_table_data("multi_unique_test",
+      "unique_col1 = '#{timestamp}' AND unique_col2 = '#{uuid}'",
+      "1 | #{timestamp} | #{uuid} | test1,test2 | \"key1\"=>\"value1\", \"key2\"=>\"value2\" | {\"test\": \"multi\"}")
+
+    # Test UPDATE
+    @source_db.exec_params(
+      "UPDATE public.multi_unique_test SET data = $1 WHERE unique_col1 = $2 AND unique_col2 = $3",
+      ['{"test": "updated_multi"}', timestamp, uuid]
+    )
+
+    sleep 1
+    success = success && verify_table_data("multi_unique_test",
+      "unique_col1 = '#{timestamp}' AND unique_col2 = '#{uuid}'",
+      "1 | #{timestamp} | #{uuid} | test1,test2 | \"key1\"=>\"value1\", \"key2\"=>\"value2\" | {\"test\": \"updated_multi\"}")
+
+    # Test DELETE
+    @source_db.exec_params(
+      "DELETE FROM public.multi_unique_test WHERE unique_col1 = $1 AND unique_col2 = $2",
+      [timestamp, uuid]
+    )
+
+    success
+  end
+
+  def test_mixed_constraints_operations
+    @logger.info "Testing operations on mixed_constraints_test..."
+    uuid = SecureRandom.uuid
+
+    # Test INSERT
+    @source_db.exec_params(
+      "INSERT INTO public.mixed_constraints_test (unique_col, non_unique_col, array_data, data)
+       VALUES ($1, $2, $3, $4)",
+      [uuid, '{"key": "non_unique1"}', '{1,2,3}', 'test_mixed']
+    )
+
+    sleep 1
+    success = verify_table_data("mixed_constraints_test", "id = 1",
+      "1 | #{uuid} | {\"key\": \"non_unique1\"} | 1,2,3 | test_mixed")
+
+    # Test UPDATE
+    @source_db.exec_params(
+      "UPDATE public.mixed_constraints_test SET data = $1 WHERE id = $2",
+      ['updated_by_pk', 1]
+    )
+
+    sleep 1
+    success = success && verify_table_data("mixed_constraints_test", "id = 1",
+      "1 | #{uuid} | {\"key\": \"non_unique1\"} | 1,2,3 | updated_by_pk")
+
+    # Test DELETE
+    @source_db.exec_params("DELETE FROM public.mixed_constraints_test WHERE id = $1", [1])
+
+    success
+  end
+
+  def build_verification_query(table, condition)
+    case table
+    when "pk_test"
+      <<~SQL
+        SELECT (
+          id::text || ' | ' ||
+          text_data || ' | ' ||
+          json_data::text || ' | ' ||
+          array_to_string(array_data, ',') || ' | ' ||
+          numeric_data::text || ' | ' ||
+          timestamp_data::text || ' | ' ||
+          enum_data::text || ' | ' ||
+          COALESCE(nullable_column, '')
+        ) AS row_data
+        FROM public.#{table}
+        WHERE #{condition}
+        ORDER BY id
+      SQL
+    when "unique_test"
+      <<~SQL
+        SELECT (
+          id::text || ' | ' ||
+          unique_col::text || ' | ' ||
+          '\\x' || encode(binary_data, 'hex') || ' | ' ||
+          interval_data::text || ' | ' ||
+          data::text
+        ) AS row_data
+        FROM public.#{table}
+        WHERE #{condition}
+        ORDER BY id
+      SQL
+    when "composite_pk_test"
+      <<~SQL
+        SELECT (
+          id1::text || ' | ' ||
+          id2::text || ' | ' ||
+          json_data::text || ' | ' ||
+          array_to_string(array_int_data, ',') || ' | ' ||
+          data
+        ) AS row_data
+        FROM public.#{table}
+        WHERE #{condition}
+        ORDER BY id1, id2
+      SQL
+    when "multi_unique_test"
+      <<~SQL
+        SELECT (
+          id::text || ' | ' ||
+          unique_col1::text || ' | ' ||
+          unique_col2::text || ' | ' ||
+          array_to_string(array_data, ',') || ' | ' ||
+          hstore_data::text || ' | ' ||
+          data::text
+        ) AS row_data
+        FROM public.#{table}
+        WHERE #{condition}
+        ORDER BY id
+      SQL
+    when "mixed_constraints_test"
+      <<~SQL
+        SELECT (
+          id::text || ' | ' ||
+          unique_col::text || ' | ' ||
+          non_unique_col::text || ' | ' ||
+          array_to_string(array_data, ',') || ' | ' ||
+          data
+        ) AS row_data
+        FROM public.#{table}
+        WHERE #{condition}
+        ORDER BY id
+      SQL
+    end
+  end
+
+  def verify_table_data(table, condition, expected_values)
+    max_retries = 3
+    retry_count = 0
+
+    while retry_count < max_retries
+      query = build_verification_query(table, condition)
+      result = @target_db.exec(query)
+
+      actual_values = result.ntuples > 0 ? result[0]['row_data'] : ""
+
+      if actual_values == expected_values
+        @logger.info "✓ Table #{table} data verified successfully"
+        return true
+      end
+
+      retry_count += 1
+      if retry_count < max_retries
+        @logger.info "Retrying verification (attempt #{retry_count + 1}/#{max_retries})..."
+        sleep 1
+      end
+    end
+
+    @logger.error "✗ Table #{table} data verification failed"
+    @logger.error "  Condition: [#{condition}]"
+    @logger.error "  Expected: #{expected_values}"
+    @logger.error "  Actual: #{actual_values}"
+    false
+  end
+
+  def verify_final_state
+    @logger.info "Verifying final state..."
+    sleep 2 # Give time for final DELETEs to replicate
+
+    tables = ["pk_test", "unique_test", "composite_pk_test", "multi_unique_test", "mixed_constraints_test"]
+
+    failed = false
+    tables.each do |table|
+      result = @target_db.exec("SELECT COUNT(*) FROM #{table}")
+      count = result[0]['count'].to_i
+      if count != 0
+        @logger.error "Table #{table} verification failed: expected 0 rows, got #{count}"
+        @logger.error "Current state of #{table}:"
+        @target_db.exec("SELECT * FROM #{table}").each do |row|
+          @logger.error "  #{row.values.join(' | ')}"
+        end
+        failed = true
+      end
+    end
+
+    if failed
+      false
+    else
+      @logger.info "Final state verified successfully"
+      true
+    end
+  end
+
+  def cleanup
+    Process.kill('TERM', @replicator_pid) if @replicator_pid
+    Process.kill('TERM', @worker_pid) if @worker_pid
+    Process.wait(@replicator_pid) if @replicator_pid
+    Process.wait(@worker_pid) if @worker_pid
+  end
+
+  def run_test
+    success = true
+    begin
+      create_test_tables
+      start_pg_flo
+
+      sleep 2
+
+      # Run all tests and track success
+      success = success && test_pk_operations
+      success = success && test_unique_operations
+      success = success && test_composite_pk_operations
+      success = success && test_multi_unique_operations
+      success = success && test_mixed_constraints_operations
+
+      sleep 5
+      success = success && verify_final_state
+    ensure
+      cleanup
+    end
+    success
+  end
+end
+
+# Simple main execution
+if __FILE__ == $0
+  begin
+    test = PostgresUniquenessTest.new
+    exit(test.run_test ? 0 : 1)
+  rescue => e
+    puts "Error: #{e.message}"
+    puts e.backtrace
+    exit 1
+  end
+end
diff --git a/internal/scripts/e2e_test_local.sh b/internal/scripts/e2e_test_local.sh
index 4f76805..1c7b415 100755
--- a/internal/scripts/e2e_test_local.sh
+++ b/internal/scripts/e2e_test_local.sh
@@ -33,8 +33,8 @@ make build

 setup_docker

-log "Running e2e ddl tests..."
-if CI=false ./internal/scripts/e2e_ddl.sh; then
+log "Running e2e postgres uniqueness tests..."
+if CI=false ruby ./internal/scripts/e2e_postgres_uniquness.rb; then success "e2e ddl tests completed successfully" else error "Original e2e tests failed" diff --git a/pkg/replicator/base_replicator.go b/pkg/replicator/base_replicator.go index b4dc83e..ebdefd0 100644 --- a/pkg/replicator/base_replicator.go +++ b/pkg/replicator/base_replicator.go @@ -30,14 +30,15 @@ func GeneratePublicationName(group string) string { // BaseReplicator provides core functionality for PostgreSQL logical replication type BaseReplicator struct { - Config Config - ReplicationConn ReplicationConnection - StandardConn StandardConnection - Relations map[uint32]*pglogrepl.RelationMessage - Logger zerolog.Logger - TableDetails map[string][]string - LastLSN pglogrepl.LSN - NATSClient NATSClient + Config Config + ReplicationConn ReplicationConnection + StandardConn StandardConnection + Relations map[uint32]*pglogrepl.RelationMessage + Logger zerolog.Logger + TableDetails map[string][]string + LastLSN pglogrepl.LSN + NATSClient NATSClient + TableReplicationKeys map[string]utils.ReplicationKey } // NewBaseReplicator creates a new BaseReplicator instance @@ -316,7 +317,7 @@ func (r *BaseReplicator) HandleInsertMessage(msg *pglogrepl.InsertMessage, lsn p } cdcMessage := utils.CDCMessage{ - Type: "INSERT", + Type: utils.OperationInsert, Schema: relation.Namespace, Table: relation.RelationName, Columns: relation.Columns, @@ -337,22 +338,36 @@ func (r *BaseReplicator) HandleUpdateMessage(msg *pglogrepl.UpdateMessage, lsn p } cdcMessage := utils.CDCMessage{ - Type: "UPDATE", + Type: utils.OperationUpdate, Schema: relation.Namespace, Table: relation.RelationName, Columns: relation.Columns, NewTuple: msg.NewTuple, OldTuple: msg.OldTuple, LSN: lsn.String(), + EmittedAt: time.Now(), ToastedColumns: make(map[string]bool), } + // Track toasted columns for i, col := range relation.Columns { - newVal := msg.NewTuple.Columns[i] - cdcMessage.ToastedColumns[col.Name] = newVal.DataType == 'u' + if msg.NewTuple != nil { + newVal := msg.NewTuple.Columns[i] + cdcMessage.ToastedColumns[col.Name] = newVal.DataType == 'u' + } } + // Add replication key information r.AddPrimaryKeyInfo(&cdcMessage, relation.RelationName) + + // Ensure we have valid column types + for i, col := range relation.Columns { + if msg.NewTuple != nil && msg.NewTuple.Columns[i].Data != nil { + // Ensure proper type information is preserved + col.DataType = uint32(msg.NewTuple.Columns[i].DataType) + } + } + return r.PublishToNATS(cdcMessage) } @@ -364,7 +379,7 @@ func (r *BaseReplicator) HandleDeleteMessage(msg *pglogrepl.DeleteMessage, lsn p } cdcMessage := utils.CDCMessage{ - Type: "DELETE", + Type: utils.OperationDelete, Schema: relation.Namespace, Table: relation.RelationName, Columns: relation.Columns, @@ -409,10 +424,14 @@ func (r *BaseReplicator) PublishToNATS(data utils.CDCMessage) error { return nil } -// AddPrimaryKeyInfo adds primary key information to the CDCMessage +// AddPrimaryKeyInfo adds replication key information to the CDCMessage func (r *BaseReplicator) AddPrimaryKeyInfo(message *utils.CDCMessage, table string) { - if pkColumns, ok := r.TableDetails[table]; ok && len(pkColumns) > 0 { - message.PrimaryKeyColumn = pkColumns[0] + if key, ok := r.TableReplicationKeys[table]; ok { + message.ReplicationKey = key + } else { + r.Logger.Error(). + Str("table", table). + Msg("No replication key information found for table. 
This should not happen as validation is done during initialization") } } @@ -504,37 +523,99 @@ func (r *BaseReplicator) closeConnections(ctx context.Context) error { // InitializePrimaryKeyInfo initializes primary key information for all tables func (r *BaseReplicator) InitializePrimaryKeyInfo() error { - for _, table := range r.Config.Tables { - column, err := r.getPrimaryKeyColumn(r.Config.Schema, table) - if err != nil { - return err - } - r.TableDetails[table] = []string{column} - } - return nil -} - -// getPrimaryKeyColumn retrieves the primary key column for a given table -func (r *BaseReplicator) getPrimaryKeyColumn(schema, table string) (string, error) { query := ` - SELECT pg_attribute.attname - FROM pg_index, pg_class, pg_attribute, pg_namespace - WHERE - pg_class.oid = $1::regclass AND - indrelid = pg_class.oid AND - nspname = $2 AND - pg_class.relnamespace = pg_namespace.oid AND - pg_attribute.attrelid = pg_class.oid AND - pg_attribute.attnum = any(pg_index.indkey) AND - indisprimary - LIMIT 1 + WITH table_info AS ( + SELECT + t.tablename, + c.relreplident, + ( + SELECT array_agg(a.attname ORDER BY array_position(i.indkey, a.attnum)) + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = c.oid AND i.indisprimary + ) as pk_columns, + ( + SELECT array_agg(a.attname ORDER BY array_position(i.indkey, a.attnum)) + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = c.oid AND i.indisunique AND NOT i.indisprimary + LIMIT 1 + ) as unique_columns + FROM pg_tables t + JOIN pg_class c ON t.tablename = c.relname + JOIN pg_namespace n ON c.relnamespace = n.oid + WHERE t.schemaname = $1 + ) + SELECT + tablename, + relreplident::text, + COALESCE(pk_columns, ARRAY[]::text[]) as pk_columns, + COALESCE(unique_columns, ARRAY[]::text[]) as unique_columns + FROM table_info; ` - var column string - err := r.StandardConn.QueryRow(context.Background(), query, fmt.Sprintf("%s.%s", schema, table), schema).Scan(&column) + + rows, err := r.StandardConn.Query(context.Background(), query, r.Config.Schema) if err != nil { - return "", fmt.Errorf("failed to query primary key column: %v", err) + return fmt.Errorf("failed to query replication key info: %v", err) } - return column, nil + defer rows.Close() + + r.TableReplicationKeys = make(map[string]utils.ReplicationKey) + + for rows.Next() { + var ( + tableName string + replicaIdentity string + pkColumns []string + uniqueColumns []string + ) + + if err := rows.Scan(&tableName, &replicaIdentity, &pkColumns, &uniqueColumns); err != nil { + return fmt.Errorf("failed to scan row: %v", err) + } + + key := utils.ReplicationKey{} + + switch { + case len(pkColumns) > 0: + key = utils.ReplicationKey{ + Type: utils.ReplicationKeyPK, + Columns: pkColumns, + } + case len(uniqueColumns) > 0: + key = utils.ReplicationKey{ + Type: utils.ReplicationKeyUnique, + Columns: uniqueColumns, + } + case replicaIdentity == "f": + key = utils.ReplicationKey{ + Type: utils.ReplicationKeyFull, + Columns: nil, + } + } + + if err := r.validateTableReplicationKey(tableName, key); err != nil { + r.Logger.Warn(). + Str("table", tableName). + Str("replica_identity", replicaIdentity). + Str("key_type", string(key.Type)). + Strs("columns", key.Columns). + Err(err). + Msg("Invalid replication key configuration") + continue + } + + r.TableReplicationKeys[tableName] = key + + r.Logger.Debug(). + Str("table", tableName). + Str("key_type", string(key.Type)). 
+ Strs("columns", key.Columns). + Str("replica_identity", replicaIdentity). + Msg("Initialized replication key configuration") + } + + return rows.Err() } // SaveState saves the current replication state @@ -603,3 +684,15 @@ func (r *BaseReplicator) GetConfiguredTables(ctx context.Context) ([]string, err return tables, nil } + +func (r *BaseReplicator) validateTableReplicationKey(tableName string, key utils.ReplicationKey) error { + if !key.IsValid() { + return fmt.Errorf( + "table %q requires one of the following:\n"+ + "\t1. A PRIMARY KEY constraint\n"+ + "\t2. A UNIQUE constraint\n"+ + "\t3. REPLICA IDENTITY FULL (ALTER TABLE %s REPLICA IDENTITY FULL)", + tableName, tableName) + } + return nil +} diff --git a/pkg/replicator/copy_and_stream_replicator.go b/pkg/replicator/copy_and_stream_replicator.go index 93e2ed6..b8a5e5d 100644 --- a/pkg/replicator/copy_and_stream_replicator.go +++ b/pkg/replicator/copy_and_stream_replicator.go @@ -370,7 +370,7 @@ func (r *CopyAndStreamReplicator) executeCopyQuery(ctx context.Context, tx pgx.T } cdcMessage := utils.CDCMessage{ - Type: "INSERT", + Type: utils.OperationInsert, Schema: schema, Table: tableName, Columns: columns, diff --git a/pkg/replicator/ddl_replicator.go b/pkg/replicator/ddl_replicator.go index d306956..ec7e58b 100644 --- a/pkg/replicator/ddl_replicator.go +++ b/pkg/replicator/ddl_replicator.go @@ -195,7 +195,7 @@ func (d *DDLReplicator) ProcessDDLEvents(ctx context.Context) error { } cdcMessage := utils.CDCMessage{ - Type: "DDL", + Type: utils.OperationDDL, Schema: schema, Table: table, EmittedAt: time.Now(), diff --git a/pkg/replicator/tests/base_replicator_test.go b/pkg/replicator/tests/base_replicator_test.go index f3ca005..ed7d14d 100644 --- a/pkg/replicator/tests/base_replicator_test.go +++ b/pkg/replicator/tests/base_replicator_test.go @@ -29,12 +29,26 @@ func TestBaseReplicator(t *testing.T) { mockStandardConn := new(MockStandardConnection) mockNATSClient := new(MockNATSClient) - mockRows := new(MockRows) - mockRows.On("Next").Return(false).Once() - mockRows.On("Err").Return(nil).Once() - mockRows.On("Close").Return().Once() - - mockStandardConn.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(mockRows, nil).Once() + // Mock for InitializeOIDMap query + mockOIDRows := new(MockRows) + mockOIDRows.On("Next").Return(false) + mockOIDRows.On("Err").Return(nil) + mockOIDRows.On("Close").Return() + + // Mock for InitializePrimaryKeyInfo query + mockPKRows := new(MockRows) + mockPKRows.On("Next").Return(false) + mockPKRows.On("Err").Return(nil) + mockPKRows.On("Close").Return() + + // Set up expectations for both queries + mockStandardConn.On("Query", mock.Anything, mock.MatchedBy(func(q string) bool { + return strings.Contains(q, "pg_type") + }), mock.Anything).Return(mockOIDRows, nil).Once() + + mockStandardConn.On("Query", mock.Anything, mock.MatchedBy(func(q string) bool { + return strings.Contains(q, "table_info") + }), mock.Anything).Return(mockPKRows, nil).Once() mockPoolConn := &MockPgxPoolConn{} mockStandardConn.On("Acquire", mock.Anything).Return(mockPoolConn, nil).Maybe() @@ -58,7 +72,8 @@ func TestBaseReplicator(t *testing.T) { assert.Equal(t, mockNATSClient, br.NATSClient) mockStandardConn.AssertExpectations(t) - mockRows.AssertExpectations(t) + mockOIDRows.AssertExpectations(t) + mockPKRows.AssertExpectations(t) }) t.Run("CreatePublication", func(t *testing.T) { @@ -363,7 +378,7 @@ func TestBaseReplicator(t *testing.T) { return false } - assert.Equal(t, "INSERT", decodedMsg.Type) + assert.Equal(t, 
utils.OperationInsert, decodedMsg.Type) assert.Equal(t, "public", decodedMsg.Schema) assert.Equal(t, "users", decodedMsg.Table) assert.Equal(t, msg.Tuple, decodedMsg.NewTuple) @@ -536,7 +551,7 @@ func TestBaseReplicator(t *testing.T) { return false } - assert.Equal(t, "INSERT", decodedMsg.Type) + assert.Equal(t, utils.OperationInsert, decodedMsg.Type) assert.Equal(t, "public", decodedMsg.Schema) assert.Equal(t, "test_table", decodedMsg.Table) @@ -629,7 +644,7 @@ func TestBaseReplicator(t *testing.T) { return false } - assert.Equal(t, "UPDATE", decodedMsg.Type) + assert.Equal(t, utils.OperationUpdate, decodedMsg.Type) assert.Equal(t, "public", decodedMsg.Schema) assert.Equal(t, "users", decodedMsg.Table) assert.Equal(t, msg.OldTuple, decodedMsg.OldTuple) @@ -680,8 +695,8 @@ func TestBaseReplicator(t *testing.T) { OldTuple: nil, NewTuple: &pglogrepl.TupleData{ Columns: []*pglogrepl.TupleDataColumn{ - {Data: []byte("1")}, - {Data: []byte("John Doe")}, + {Data: []byte("1"), DataType: pgtype.Int4OID}, + {Data: []byte("John Doe"), DataType: pgtype.TextOID}, }, }, } @@ -694,18 +709,16 @@ func TestBaseReplicator(t *testing.T) { return false } - assert.Equal(t, "UPDATE", decodedMsg.Type) + assert.Equal(t, utils.OperationUpdate, decodedMsg.Type) assert.Equal(t, "public", decodedMsg.Schema) assert.Equal(t, "users", decodedMsg.Table) assert.Nil(t, decodedMsg.OldTuple) assert.NotNil(t, decodedMsg.NewTuple) - assert.Equal(t, msg.NewTuple, decodedMsg.NewTuple) - assert.Len(t, decodedMsg.Columns, 2) - assert.Equal(t, "id", decodedMsg.Columns[0].Name) assert.Equal(t, uint32(pgtype.Int4OID), decodedMsg.Columns[0].DataType) - assert.Equal(t, "name", decodedMsg.Columns[1].Name) assert.Equal(t, uint32(pgtype.TextOID), decodedMsg.Columns[1].DataType) + assert.Equal(t, []byte("1"), decodedMsg.NewTuple.Columns[0].Data) + assert.Equal(t, []byte("John Doe"), decodedMsg.NewTuple.Columns[1].Data) return true })).Return(nil) @@ -754,7 +767,7 @@ func TestBaseReplicator(t *testing.T) { return false } - assert.Equal(t, "DELETE", decodedMsg.Type) + assert.Equal(t, utils.OperationDelete, decodedMsg.Type) assert.Equal(t, "public", decodedMsg.Schema) assert.Equal(t, "users", decodedMsg.Table) assert.Equal(t, msg.OldTuple, decodedMsg.OldTuple) @@ -819,7 +832,7 @@ func TestBaseReplicator(t *testing.T) { } data := utils.CDCMessage{ - Type: "INSERT", + Type: utils.OperationInsert, Schema: "public", Table: "users", Columns: []*pglogrepl.RelationMessageColumn{ @@ -842,7 +855,7 @@ func TestBaseReplicator(t *testing.T) { return false } - assert.Equal(t, "INSERT", decodedMsg.Type) + assert.Equal(t, utils.OperationInsert, decodedMsg.Type) assert.Equal(t, "public", decodedMsg.Schema) assert.Equal(t, "users", decodedMsg.Table) @@ -877,7 +890,7 @@ func TestBaseReplicator(t *testing.T) { } data := utils.CDCMessage{ - Type: "INSERT", + Type: utils.OperationInsert, Schema: "public", Table: "users", Columns: []*pglogrepl.RelationMessageColumn{ @@ -905,8 +918,11 @@ func TestBaseReplicator(t *testing.T) { t.Run("AddPrimaryKeyInfo", func(t *testing.T) { t.Run("Successful addition of primary key info", func(t *testing.T) { br := &replicator.BaseReplicator{ - TableDetails: map[string][]string{ - "public.users": {"id"}, + TableReplicationKeys: map[string]utils.ReplicationKey{ + "public.users": { + Type: utils.ReplicationKeyPK, + Columns: []string{"id"}, + }, }, } @@ -938,7 +954,10 @@ func TestBaseReplicator(t *testing.T) { {Data: []byte("John Doe")}, }, }, - PrimaryKeyColumn: "id", + ReplicationKey: utils.ReplicationKey{ + Type: 
utils.ReplicationKeyPK, + Columns: []string{"id"}, + }, } br.AddPrimaryKeyInfo(message, "public.users") diff --git a/pkg/replicator/tests/copy_and_stream_replicator_test.go b/pkg/replicator/tests/copy_and_stream_replicator_test.go index 98ca062..0e7fbc3 100644 --- a/pkg/replicator/tests/copy_and_stream_replicator_test.go +++ b/pkg/replicator/tests/copy_and_stream_replicator_test.go @@ -72,7 +72,7 @@ func TestCopyAndStreamReplicator(t *testing.T) { return false } - assert.Equal(t, "INSERT", decodedMsg.Type) + assert.Equal(t, utils.OperationInsert, decodedMsg.Type) assert.Equal(t, "public", decodedMsg.Schema) assert.Equal(t, "users", decodedMsg.Table) @@ -286,7 +286,7 @@ func TestCopyAndStreamReplicator(t *testing.T) { err := decodedMsg.UnmarshalBinary(data) assert.NoError(t, err, "Failed to unmarshal binary data") - assert.Equal(t, "INSERT", decodedMsg.Type) + assert.Equal(t, utils.OperationInsert, decodedMsg.Type) assert.Equal(t, "public", decodedMsg.Schema) assert.Equal(t, "test_table", decodedMsg.Table) assert.Equal(t, len(tc.expected), len(decodedMsg.NewTuple.Columns)) diff --git a/pkg/replicator/tests/json_encoder_test.go b/pkg/replicator/tests/json_encoder_test.go index b212eba..fe5ec52 100644 --- a/pkg/replicator/tests/json_encoder_test.go +++ b/pkg/replicator/tests/json_encoder_test.go @@ -21,7 +21,7 @@ func TestOIDToString(t *testing.T) { func TestCDCBinaryEncoding(t *testing.T) { t.Run("Encode and decode preserves CDC types", func(t *testing.T) { testData := utils.CDCMessage{ - Type: "INSERT", + Type: utils.OperationInsert, Schema: "public", Table: "users", Columns: []*pglogrepl.RelationMessageColumn{ @@ -62,7 +62,7 @@ func TestBinaryEncodingComplexTypes(t *testing.T) { textArrayValue := []byte("{hello,world}") testData := utils.CDCMessage{ - Type: "INSERT", + Type: utils.OperationInsert, Schema: "public", Table: "complex_types", Columns: []*pglogrepl.RelationMessageColumn{ diff --git a/pkg/routing/router.go b/pkg/routing/router.go index 688e525..313dc8f 100644 --- a/pkg/routing/router.go +++ b/pkg/routing/router.go @@ -48,7 +48,7 @@ func (r *Router) ApplyRouting(message *utils.CDCMessage) (*utils.CDCMessage, err return message, nil } - if !ContainsOperation(route.Operations, utils.OperationType(message.Type)) { + if !ContainsOperation(route.Operations, message.Type) { return nil, nil } @@ -66,13 +66,18 @@ func (r *Router) ApplyRouting(message *utils.CDCMessage) (*utils.CDCMessage, err newColumns[i] = &newCol } routedMessage.Columns = newColumns - } - routedMessage.MappedPrimaryKeyColumn = message.PrimaryKeyColumn - for _, mapping := range route.ColumnMappings { - if mapping.Source == message.PrimaryKeyColumn { - routedMessage.MappedPrimaryKeyColumn = mapping.Destination - break + if routedMessage.ReplicationKey.Type != utils.ReplicationKeyFull { + mappedColumns := make([]string, len(routedMessage.ReplicationKey.Columns)) + for i, keyCol := range routedMessage.ReplicationKey.Columns { + mappedName := GetMappedColumnName(route.ColumnMappings, keyCol) + if mappedName != "" { + mappedColumns[i] = mappedName + } else { + mappedColumns[i] = keyCol + } + } + routedMessage.ReplicationKey.Columns = mappedColumns } } diff --git a/pkg/routing/tests/routing_test.go b/pkg/routing/tests/routing_test.go index 7233761..aefd74b 100644 --- a/pkg/routing/tests/routing_test.go +++ b/pkg/routing/tests/routing_test.go @@ -27,7 +27,7 @@ func TestRouter_ApplyRouting(t *testing.T) { }, }, inputMessage: &utils.CDCMessage{ - Type: string(utils.OperationInsert), + Type: utils.OperationInsert, Table: 
"source_table", Columns: []*pglogrepl.RelationMessageColumn{ {Name: "id", DataType: 23}, @@ -35,7 +35,7 @@ func TestRouter_ApplyRouting(t *testing.T) { }, }, expectedOutput: &utils.CDCMessage{ - Type: string(utils.OperationInsert), + Type: utils.OperationInsert, Table: "dest_table", Columns: []*pglogrepl.RelationMessageColumn{ {Name: "id", DataType: 23}, @@ -57,7 +57,7 @@ func TestRouter_ApplyRouting(t *testing.T) { }, }, inputMessage: &utils.CDCMessage{ - Type: string(utils.OperationUpdate), + Type: utils.OperationUpdate, Table: "users", Columns: []*pglogrepl.RelationMessageColumn{ {Name: "user_id", DataType: 23}, @@ -66,7 +66,7 @@ func TestRouter_ApplyRouting(t *testing.T) { }, }, expectedOutput: &utils.CDCMessage{ - Type: string(utils.OperationUpdate), + Type: utils.OperationUpdate, Table: "customers", Columns: []*pglogrepl.RelationMessageColumn{ {Name: "customer_id", DataType: 23}, @@ -85,11 +85,11 @@ func TestRouter_ApplyRouting(t *testing.T) { }, }, inputMessage: &utils.CDCMessage{ - Type: string(utils.OperationUpdate), + Type: utils.OperationUpdate, Table: "orders", }, expectedOutput: &utils.CDCMessage{ - Type: string(utils.OperationUpdate), + Type: utils.OperationUpdate, Table: "processed_orders", }, }, @@ -103,7 +103,7 @@ func TestRouter_ApplyRouting(t *testing.T) { }, }, inputMessage: &utils.CDCMessage{ - Type: string(utils.OperationDelete), + Type: utils.OperationDelete, Table: "orders", }, expectNil: true, @@ -112,11 +112,11 @@ func TestRouter_ApplyRouting(t *testing.T) { name: "No route for table", routes: map[string]routing.TableRoute{}, inputMessage: &utils.CDCMessage{ - Type: string(utils.OperationInsert), + Type: utils.OperationInsert, Table: "unknown_table", }, expectedOutput: &utils.CDCMessage{ - Type: string(utils.OperationInsert), + Type: utils.OperationInsert, Table: "unknown_table", }, }, diff --git a/pkg/rules/engine.go b/pkg/rules/engine.go index f4e3027..b42bbf8 100644 --- a/pkg/rules/engine.go +++ b/pkg/rules/engine.go @@ -25,7 +25,7 @@ func (re *RuleEngine) ApplyRules(message *utils.CDCMessage) (*utils.CDCMessage, logger.Info(). Str("table", message.Table). - Str("operation", message.Type). + Str("operation", string(message.Type)). Int("ruleCount", len(rules)). 
Msg("Applying rules") diff --git a/pkg/rules/rules.go b/pkg/rules/rules.go index 49f99be..4136400 100644 --- a/pkg/rules/rules.go +++ b/pkg/rules/rules.go @@ -78,7 +78,7 @@ func NewRegexTransformRule(table, column string, params map[string]interface{}) } transform := func(m *utils.CDCMessage) (*utils.CDCMessage, error) { - value, err := m.GetColumnValue(column) + value, err := m.GetColumnValue(column, false) if err != nil { return m, nil } @@ -109,7 +109,7 @@ func NewMaskTransformRule(table, column string, params map[string]interface{}) ( } transform := func(m *utils.CDCMessage) (*utils.CDCMessage, error) { - value, err := m.GetColumnValue(column) + value, err := m.GetColumnValue(column, false) if err != nil { return m, nil } @@ -179,7 +179,8 @@ func NewFilterRule(table, column string, params map[string]interface{}) (Rule, e // NewComparisonCondition creates a new comparison condition function func NewComparisonCondition(column, operator string, value interface{}) func(*utils.CDCMessage) bool { return func(m *utils.CDCMessage) bool { - columnValue, err := m.GetColumnValue(column) + useOldValues := m.Type == utils.OperationDelete + columnValue, err := m.GetColumnValue(column, useOldValues) if err != nil { return false } @@ -261,7 +262,8 @@ func NewComparisonCondition(column, operator string, value interface{}) func(*ut // NewContainsCondition creates a new contains condition function func NewContainsCondition(column string, value interface{}) func(*utils.CDCMessage) bool { return func(m *utils.CDCMessage) bool { - columnValue, err := m.GetColumnValue(column) + useOldValues := m.Type == utils.OperationDelete + columnValue, err := m.GetColumnValue(column, useOldValues) if err != nil { return false } @@ -356,12 +358,12 @@ func compareNumericValues(a, b string, operator string) bool { // Apply applies the transform rule to the provided data func (r *TransformRule) Apply(message *utils.CDCMessage) (*utils.CDCMessage, error) { - if !containsOperation(r.Operations, utils.OperationType(message.Type)) { + if !containsOperation(r.Operations, message.Type) { return message, nil } // Don't apply rule if asked not to - if message.Type == "DELETE" && r.AllowEmptyDeletes { + if message.Type == utils.OperationDelete && r.AllowEmptyDeletes { return message, nil } @@ -370,16 +372,17 @@ func (r *TransformRule) Apply(message *utils.CDCMessage) (*utils.CDCMessage, err // Apply applies the filter rule to the provided data func (r *FilterRule) Apply(message *utils.CDCMessage) (*utils.CDCMessage, error) { - if !containsOperation(r.Operations, utils.OperationType(message.Type)) { + if !containsOperation(r.Operations, message.Type) { return message, nil } // Don't apply rule if asked not to - if message.Type == "DELETE" && r.AllowEmptyDeletes { + if message.Type == utils.OperationDelete && r.AllowEmptyDeletes { return message, nil } passes := r.Condition(message) + logger.Debug(). Str("column", r.ColumnName). Any("operation", message.Type). 
diff --git a/pkg/rules/tests/engine_test.go b/pkg/rules/tests/engine_test.go index d5b3a49..ad8ad02 100644 --- a/pkg/rules/tests/engine_test.go +++ b/pkg/rules/tests/engine_test.go @@ -29,7 +29,7 @@ func TestRuleEngine_AddRule(t *testing.T) { re.AddRule("users", rule) message := &utils.CDCMessage{ - Type: "INSERT", + Type: utils.OperationInsert, Schema: "public", Table: "users", Columns: []*pglogrepl.RelationMessageColumn{ @@ -60,7 +60,7 @@ func TestRuleEngine_ApplyRules(t *testing.T) { re.AddRule("users", rule) message := &utils.CDCMessage{ - Type: "INSERT", + Type: utils.OperationInsert, Schema: "public", Table: "users", Columns: []*pglogrepl.RelationMessageColumn{ @@ -76,7 +76,7 @@ func TestRuleEngine_ApplyRules(t *testing.T) { result, err := re.ApplyRules(message) assert.NoError(t, err) - value, err := result.GetColumnValue("test_column") + value, err := result.GetColumnValue("test_column", false) assert.NoError(t, err) assert.Equal(t, "transformed", value) } @@ -84,7 +84,7 @@ func TestRuleEngine_ApplyRules(t *testing.T) { func TestRuleEngine_ApplyRules_NoRules(t *testing.T) { re := rules.NewRuleEngine() message := &utils.CDCMessage{ - Type: "INSERT", + Type: utils.OperationInsert, Schema: "public", Table: "users", Columns: []*pglogrepl.RelationMessageColumn{ @@ -103,7 +103,7 @@ func TestRuleEngine_ApplyRules_NoRules(t *testing.T) { assert.Equal(t, message, result) } -func TestRuleEngine_LoadRules(t *testing.T) { +func TestRuleEngine_LoadRules_Transform(t *testing.T) { re := rules.NewRuleEngine() config := rules.Config{ Tables: map[string][]rules.RuleConfig{ @@ -117,12 +117,46 @@ func TestRuleEngine_LoadRules(t *testing.T) { }, Operations: []utils.OperationType{utils.OperationInsert, utils.OperationUpdate}, }, + }, + }, + } + + err := re.LoadRules(config) + assert.NoError(t, err) + + message := &utils.CDCMessage{ + Type: utils.OperationInsert, + Schema: "public", + Table: "users", + Columns: []*pglogrepl.RelationMessageColumn{ + {Name: "test_column", DataType: pgtype.TextOID}, + }, + NewTuple: &pglogrepl.TupleData{ + Columns: []*pglogrepl.TupleDataColumn{ + {Data: []byte("test")}, + }, + }, + } + + result, err := re.ApplyRules(message) + assert.NoError(t, err) + assert.NotNil(t, result) + value, err := result.GetColumnValue("test_column", false) + assert.NoError(t, err) + assert.Equal(t, "t**t", value) +} + +func TestRuleEngine_LoadRules_Filter(t *testing.T) { + re := rules.NewRuleEngine() + config := rules.Config{ + Tables: map[string][]rules.RuleConfig{ + "users": { { Type: "filter", Column: "id", Parameters: map[string]interface{}{ "operator": "gt", - "value": int64(100), // Change this to int64 + "value": int64(100), }, Operations: []utils.OperationType{utils.OperationDelete}, }, @@ -134,57 +168,74 @@ func TestRuleEngine_LoadRules(t *testing.T) { assert.NoError(t, err) message := &utils.CDCMessage{ - Type: "INSERT", + Type: utils.OperationDelete, Schema: "public", Table: "users", Columns: []*pglogrepl.RelationMessageColumn{ - {Name: "test_column", DataType: pgtype.TextOID}, {Name: "id", DataType: pgtype.Int8OID}, }, - NewTuple: &pglogrepl.TupleData{ + OldTuple: &pglogrepl.TupleData{ Columns: []*pglogrepl.TupleDataColumn{ - {Data: []byte("test")}, {Data: []byte("101")}, }, }, - OldTuple: nil, } result, err := re.ApplyRules(message) - assert.NoError(t, err) assert.NotNil(t, result) - value, err := result.GetColumnValue("test_column") + value, err := result.GetColumnValue("id", true) assert.NoError(t, err) - assert.Equal(t, "t**t", value) - idValue, err := result.GetColumnValue("id") 
+ assert.Equal(t, int64(101), value) + + message.OldTuple.Columns[0].Data = []byte("99") + result, err = re.ApplyRules(message) assert.NoError(t, err) - assert.Equal(t, int64(101), idValue) + assert.Nil(t, result) +} - message.Type = "DELETE" - message.OldTuple = &pglogrepl.TupleData{ - Columns: []*pglogrepl.TupleDataColumn{ - {Data: []byte("test")}, - {Data: []byte("101")}, +func TestRuleEngine_LoadRules_EmptyDeletes(t *testing.T) { + re := rules.NewRuleEngine() + config := rules.Config{ + Tables: map[string][]rules.RuleConfig{ + "users": { + { + Type: "filter", + Column: "id", + AllowEmptyDeletes: true, + Parameters: map[string]interface{}{ + "operator": "eq", + "value": int64(101), + }, + Operations: []utils.OperationType{utils.OperationDelete}, + }, + }, }, } - message.NewTuple = nil - result, err = re.ApplyRules(message) + err := re.LoadRules(config) assert.NoError(t, err) - assert.NotNil(t, result) - value, err = result.GetColumnValue("test_column") - assert.NoError(t, err) - assert.Equal(t, "t**t", value) - idValue, err = result.GetColumnValue("id") - assert.NoError(t, err) - assert.Equal(t, int64(101), idValue) - message.OldTuple.Columns[1].Data = []byte("99") - result, err = re.ApplyRules(message) + message := &utils.CDCMessage{ + Type: utils.OperationDelete, + Schema: "public", + Table: "users", + Columns: []*pglogrepl.RelationMessageColumn{ + {Name: "id", DataType: pgtype.Int8OID}, + }, + OldTuple: &pglogrepl.TupleData{ + Columns: []*pglogrepl.TupleDataColumn{ + {Data: []byte("101")}, + }, + }, + } + result, err := re.ApplyRules(message) assert.NoError(t, err) - assert.Nil(t, result) + assert.NotNil(t, result) + value, err := result.GetColumnValue("id", true) + assert.NoError(t, err) + assert.Equal(t, int64(101), value) } func TestRuleEngine_ApplyRules_FilterRule(t *testing.T) { @@ -209,7 +260,7 @@ func TestRuleEngine_ApplyRules_FilterRule(t *testing.T) { assert.NoError(t, err) message := &utils.CDCMessage{ - Type: "UPDATE", + Type: utils.OperationUpdate, Schema: "public", Table: "users", Columns: []*pglogrepl.RelationMessageColumn{ @@ -225,7 +276,7 @@ func TestRuleEngine_ApplyRules_FilterRule(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, result) - idValue, err := result.GetColumnValue("id") + idValue, err := result.GetColumnValue("id", false) assert.NoError(t, err) assert.Equal(t, int64(101), idValue) @@ -235,13 +286,13 @@ func TestRuleEngine_ApplyRules_FilterRule(t *testing.T) { assert.NoError(t, err) assert.Nil(t, result) - message.Type = "INSERT" + message.Type = utils.OperationInsert message.NewTuple.Columns[0].Data = []byte("101") result, err = re.ApplyRules(message) assert.NoError(t, err) assert.NotNil(t, result) - idValue, err = result.GetColumnValue("id") + idValue, err = result.GetColumnValue("id", false) assert.NoError(t, err) assert.Equal(t, int64(101), idValue) } diff --git a/pkg/rules/tests/rules_test.go b/pkg/rules/tests/rules_test.go index 81a8066..bfcb3d8 100644 --- a/pkg/rules/tests/rules_test.go +++ b/pkg/rules/tests/rules_test.go @@ -26,8 +26,8 @@ func TestTransformRules(t *testing.T) { "pattern": "@example\\.com$", "replace": "@masked.com", }), - input: createCDCMessage("INSERT", "email", pgtype.TextOID, "user@example.com"), - expectedOutput: createCDCMessage("INSERT", "email", pgtype.TextOID, "user@masked.com"), + input: createCDCMessage(utils.OperationInsert, "email", pgtype.TextOID, "user@example.com"), + expectedOutput: createCDCMessage(utils.OperationInsert, "email", pgtype.TextOID, "user@masked.com"), }, { name: "Mask Transform - Credit 
Card", @@ -35,8 +35,8 @@ func TestTransformRules(t *testing.T) { "type": "mask", "mask_char": "*", }), - input: createCDCMessage("INSERT", "credit_card", pgtype.TextOID, "1234567890123456"), - expectedOutput: createCDCMessage("INSERT", "credit_card", pgtype.TextOID, "1**************6"), + input: createCDCMessage(utils.OperationInsert, "credit_card", pgtype.TextOID, "1234567890123456"), + expectedOutput: createCDCMessage(utils.OperationInsert, "credit_card", pgtype.TextOID, "1**************6"), }, { name: "Regex Transform - Phone Number", @@ -45,8 +45,8 @@ func TestTransformRules(t *testing.T) { "pattern": "(\\d{3})(\\d{3})(\\d{4})", "replace": "($1) $2-$3", }), - input: createCDCMessage("UPDATE", "phone", pgtype.TextOID, "1234567890"), - expectedOutput: createCDCMessage("UPDATE", "phone", pgtype.TextOID, "(123) 456-7890"), + input: createCDCMessage(utils.OperationUpdate, "phone", pgtype.TextOID, "1234567890"), + expectedOutput: createCDCMessage(utils.OperationUpdate, "phone", pgtype.TextOID, "(123) 456-7890"), }, { name: "Regex Transform - No Match", @@ -55,8 +55,8 @@ func TestTransformRules(t *testing.T) { "pattern": "@example\\.com$", "replace": "@masked.com", }), - input: createCDCMessage("INSERT", "email", pgtype.TextOID, "user@otherdomain.com"), - expectedOutput: createCDCMessage("INSERT", "email", pgtype.TextOID, "user@otherdomain.com"), + input: createCDCMessage(utils.OperationInsert, "email", pgtype.TextOID, "user@otherdomain.com"), + expectedOutput: createCDCMessage(utils.OperationInsert, "email", pgtype.TextOID, "user@otherdomain.com"), }, { name: "Mask Transform - Short String", @@ -64,8 +64,8 @@ func TestTransformRules(t *testing.T) { "type": "mask", "mask_char": "*", }), - input: createCDCMessage("INSERT", "credit_card", pgtype.TextOID, "12"), - expectedOutput: createCDCMessage("INSERT", "credit_card", pgtype.TextOID, "12"), + input: createCDCMessage(utils.OperationInsert, "credit_card", pgtype.TextOID, "12"), + expectedOutput: createCDCMessage(utils.OperationInsert, "credit_card", pgtype.TextOID, "12"), }, { name: "Regex Transform - Non-String Input", @@ -74,8 +74,8 @@ func TestTransformRules(t *testing.T) { "pattern": "^\\d+\\.\\d{2}$", "replace": "$.$$", }), - input: createCDCMessage("UPDATE", "price", pgtype.Float8OID, 99.99), - expectedOutput: createCDCMessage("UPDATE", "price", pgtype.Float8OID, 99.99), + input: createCDCMessage(utils.OperationUpdate, "price", pgtype.Float8OID, 99.99), + expectedOutput: createCDCMessage(utils.OperationUpdate, "price", pgtype.Float8OID, 99.99), }, } @@ -101,8 +101,8 @@ func TestFilterRules(t *testing.T) { "operator": "eq", "value": "completed", }), - input: createCDCMessage("INSERT", "status", pgtype.TextOID, "completed"), - expectedOutput: createCDCMessage("INSERT", "status", pgtype.TextOID, "completed"), + input: createCDCMessage(utils.OperationInsert, "status", pgtype.TextOID, "completed"), + expectedOutput: createCDCMessage(utils.OperationInsert, "status", pgtype.TextOID, "completed"), }, { name: "Equal Filter (Text) - Fail", @@ -110,7 +110,7 @@ func TestFilterRules(t *testing.T) { "operator": "eq", "value": "completed", }), - input: createCDCMessage("INSERT", "status", pgtype.TextOID, "pending"), + input: createCDCMessage(utils.OperationInsert, "status", pgtype.TextOID, "pending"), expectedOutput: nil, }, { @@ -119,8 +119,8 @@ func TestFilterRules(t *testing.T) { "operator": "gt", "value": 10, }), - input: createCDCMessage("UPDATE", "stock", pgtype.Int4OID, 15), - expectedOutput: createCDCMessage("UPDATE", "stock", pgtype.Int4OID, 
15),
+			input:          createCDCMessage(utils.OperationUpdate, "stock", pgtype.Int4OID, 15),
+			expectedOutput: createCDCMessage(utils.OperationUpdate, "stock", pgtype.Int4OID, 15),
 		},
 		{
 			name: "Less Than Filter (Float) - Pass",
@@ -128,8 +128,8 @@ func TestFilterRules(t *testing.T) {
 				"operator": "lt",
 				"value":    100.0,
 			}),
-			input:          createCDCMessage("INSERT", "amount", pgtype.Float8OID, 99.99),
-			expectedOutput: createCDCMessage("INSERT", "amount", pgtype.Float8OID, 99.99),
+			input:          createCDCMessage(utils.OperationInsert, "amount", pgtype.Float8OID, 99.99),
+			expectedOutput: createCDCMessage(utils.OperationInsert, "amount", pgtype.Float8OID, 99.99),
 		},
 		{
 			name: "Contains Filter (Text) - Pass",
@@ -137,8 +137,8 @@ func TestFilterRules(t *testing.T) {
 				"operator": "contains",
 				"value":    "Premium",
 			}),
-			input:          createCDCMessage("INSERT", "name", pgtype.TextOID, "Premium Widget"),
-			expectedOutput: createCDCMessage("INSERT", "name", pgtype.TextOID, "Premium Widget"),
+			input:          createCDCMessage(utils.OperationInsert, "name", pgtype.TextOID, "Premium Widget"),
+			expectedOutput: createCDCMessage(utils.OperationInsert, "name", pgtype.TextOID, "Premium Widget"),
 		},
 		{
 			name: "Equal Filter (Case Insensitive) - Pass",
@@ -146,7 +146,7 @@ func TestFilterRules(t *testing.T) {
 				"operator": "eq",
 				"value":    "Electronics",
 			}),
-			input:          createCDCMessage("INSERT", "category", pgtype.TextOID, "electronics"),
+			input:          createCDCMessage(utils.OperationInsert, "category", pgtype.TextOID, "electronics"),
 			expectedOutput: nil,
 		},
 		{
@@ -155,8 +155,8 @@ func TestFilterRules(t *testing.T) {
 				"operator": "gt",
 				"value":    "5",
 			}),
-			input:          createCDCMessage("UPDATE", "quantity", pgtype.Int4OID, 10),
-			expectedOutput: createCDCMessage("UPDATE", "quantity", pgtype.Int4OID, 10),
+			input:          createCDCMessage(utils.OperationUpdate, "quantity", pgtype.Int4OID, 10),
+			expectedOutput: createCDCMessage(utils.OperationUpdate, "quantity", pgtype.Int4OID, 10),
 		},
 		{
 			name: "Less Than Filter (Float vs Integer) - Pass",
@@ -164,8 +164,8 @@ func TestFilterRules(t *testing.T) {
 				"operator": "lt",
 				"value":    100,
 			}),
-			input:          createCDCMessage("INSERT", "price", pgtype.Float8OID, 99.99),
-			expectedOutput: createCDCMessage("INSERT", "price", pgtype.Float8OID, 99.99),
+			input:          createCDCMessage(utils.OperationInsert, "price", pgtype.Float8OID, 99.99),
+			expectedOutput: createCDCMessage(utils.OperationInsert, "price", pgtype.Float8OID, 99.99),
 		},
 		{
 			name: "Contains Filter (Case Sensitive) - Fail",
@@ -173,7 +173,7 @@ func TestFilterRules(t *testing.T) {
 				"operator": "contains",
 				"value":    "John",
 			}),
-			input:          createCDCMessage("INSERT", "name", pgtype.TextOID, "john doe"),
+			input:          createCDCMessage(utils.OperationInsert, "name", pgtype.TextOID, "john doe"),
 			expectedOutput: nil,
 		},
 	}
@@ -203,8 +203,8 @@ func TestDateTimeFilters(t *testing.T) {
 				"operator": "gt",
 				"value":    pastDate.Format(time.RFC3339),
 			}),
-			input:          createCDCMessage("INSERT", "date", pgtype.TimestamptzOID, now),
-			expectedOutput: createCDCMessage("INSERT", "date", pgtype.TimestamptzOID, now),
+			input:          createCDCMessage(utils.OperationInsert, "date", pgtype.TimestamptzOID, now),
+			expectedOutput: createCDCMessage(utils.OperationInsert, "date", pgtype.TimestamptzOID, now),
 		},
 		{
 			name: "Less Than or Equal Date Filter - Pass",
@@ -212,8 +212,8 @@ func TestDateTimeFilters(t *testing.T) {
 				"operator": "lte",
 				"value":    now.Format(time.RFC3339),
 			}),
-			input:          createCDCMessage("INSERT", "date", pgtype.TimestamptzOID, pastDate),
-			expectedOutput: createCDCMessage("INSERT", "date", pgtype.TimestamptzOID, pastDate),
+			input:          createCDCMessage(utils.OperationInsert, "date", pgtype.TimestamptzOID, pastDate),
+			expectedOutput: createCDCMessage(utils.OperationInsert, "date", pgtype.TimestamptzOID, pastDate),
 		},
 		{
 			name: "Equal Date Filter (Different Timezone) - Pass",
@@ -221,8 +221,8 @@ func TestDateTimeFilters(t *testing.T) {
 				"operator": "eq",
 				"value":    now.UTC().Format(time.RFC3339),
 			}),
-			input:          createCDCMessage("INSERT", "date", pgtype.TimestamptzOID, now.In(time.FixedZone("EST", -5*60*60))),
-			expectedOutput: createCDCMessage("INSERT", "date", pgtype.TimestamptzOID, now.In(time.FixedZone("EST", -5*60*60))),
+			input:          createCDCMessage(utils.OperationInsert, "date", pgtype.TimestamptzOID, now.In(time.FixedZone("EST", -5*60*60))),
+			expectedOutput: createCDCMessage(utils.OperationInsert, "date", pgtype.TimestamptzOID, now.In(time.FixedZone("EST", -5*60*60))),
 		},
 		{
 			name: "Greater Than Date Filter (String Input) - Fail",
@@ -230,8 +230,8 @@ func TestDateTimeFilters(t *testing.T) {
 				"operator": "gt",
 				"value":    pastDate.Format(time.RFC3339),
 			}),
-			input:          createCDCMessage("INSERT", "date", pgtype.TextOID, now.Format(time.RFC3339)),
-			expectedOutput: createCDCMessage("INSERT", "date", pgtype.TextOID, now.Format(time.RFC3339)),
+			input:          createCDCMessage(utils.OperationInsert, "date", pgtype.TextOID, now.Format(time.RFC3339)),
+			expectedOutput: createCDCMessage(utils.OperationInsert, "date", pgtype.TextOID, now.Format(time.RFC3339)),
 		},
 	}
@@ -257,8 +257,8 @@ func TestBooleanFilters(t *testing.T) {
 				"operator": "eq",
 				"value":    true,
 			}),
-			input:          createCDCMessage("INSERT", "is_active", pgtype.BoolOID, true),
-			expectedOutput: createCDCMessage("INSERT", "is_active", pgtype.BoolOID, true),
+			input:          createCDCMessage(utils.OperationInsert, "is_active", pgtype.BoolOID, true),
+			expectedOutput: createCDCMessage(utils.OperationInsert, "is_active", pgtype.BoolOID, true),
 		},
 		{
 			name: "Not Equal Boolean Filter - Pass",
@@ -266,8 +266,8 @@ func TestBooleanFilters(t *testing.T) {
 				"operator": "ne",
 				"value":    true,
 			}),
-			input:          createCDCMessage("UPDATE", "is_deleted", pgtype.BoolOID, false),
-			expectedOutput: createCDCMessage("UPDATE", "is_deleted", pgtype.BoolOID, false),
+			input:          createCDCMessage(utils.OperationUpdate, "is_deleted", pgtype.BoolOID, false),
+			expectedOutput: createCDCMessage(utils.OperationUpdate, "is_deleted", pgtype.BoolOID, false),
 		},
 		{
 			name: "Equal Boolean Filter (String Input) - Pass",
@@ -275,7 +275,7 @@ func TestBooleanFilters(t *testing.T) {
 				"operator": "eq",
 				"value":    true,
 			}),
-			input:          createCDCMessage("INSERT", "is_active", pgtype.TextOID, "true"),
+			input:          createCDCMessage(utils.OperationInsert, "is_active", pgtype.TextOID, "true"),
 			expectedOutput: nil,
 		},
 		{
@@ -284,7 +284,7 @@ func TestBooleanFilters(t *testing.T) {
 				"operator": "ne",
 				"value":    false,
 			}),
-			input:          createCDCMessage("UPDATE", "is_deleted", pgtype.Int4OID, 0),
+			input:          createCDCMessage(utils.OperationUpdate, "is_deleted", pgtype.Int4OID, 0),
 			expectedOutput: nil,
 		},
 	}
@@ -311,8 +311,8 @@ func TestNumericFilters(t *testing.T) {
 				"operator": "gte",
 				"value":    "99.99",
 			}),
-			input:          createCDCMessage("INSERT", "price", pgtype.NumericOID, "100.00"),
-			expectedOutput: createCDCMessage("INSERT", "price", pgtype.NumericOID, "100.00"),
+			input:          createCDCMessage(utils.OperationInsert, "price", pgtype.NumericOID, "100.00"),
+			expectedOutput: createCDCMessage(utils.OperationInsert, "price", pgtype.NumericOID, "100.00"),
 		},
 		{
 			name: "Less Than Numeric Filter - Pass",
@@ -320,8 +320,8 @@ func TestNumericFilters(t *testing.T) {
 				"operator": "lt",
 				"value":    "1000.00",
 			}),
-			input:          createCDCMessage("UPDATE", "total", pgtype.NumericOID, "999.99"),
-			expectedOutput: createCDCMessage("UPDATE", "total", pgtype.NumericOID, "999.99"),
+			input:          createCDCMessage(utils.OperationUpdate, "total", pgtype.NumericOID, "999.99"),
+			expectedOutput: createCDCMessage(utils.OperationUpdate, "total", pgtype.NumericOID, "999.99"),
 		},
 		{
 			name: "Less Than Numeric Filter (String Input) - Fail",
@@ -329,7 +329,7 @@ func TestNumericFilters(t *testing.T) {
 				"operator": "lt",
 				"value":    1000.00,
 			}),
-			input:          createCDCMessage("UPDATE", "total", pgtype.TextOID, "999.99"),
+			input:          createCDCMessage(utils.OperationUpdate, "total", pgtype.TextOID, "999.99"),
 			expectedOutput: nil,
 		},
 		{
@@ -338,8 +338,8 @@ func TestNumericFilters(t *testing.T) {
 				"operator": "eq",
 				"value":    1.23,
 			}),
-			input:          createCDCMessage("INSERT", "weight", pgtype.Float8OID, 1.2300000001),
-			expectedOutput: createCDCMessage("INSERT", "weight", pgtype.Float8OID, 1.2300000001),
+			input:          createCDCMessage(utils.OperationInsert, "weight", pgtype.Float8OID, 1.2300000001),
+			expectedOutput: createCDCMessage(utils.OperationInsert, "weight", pgtype.Float8OID, 1.2300000001),
 		},
 	}
@@ -372,7 +372,7 @@ func createRule(t *testing.T, ruleType, table, column string, params map[string]
 	return rule
 }

-func createCDCMessage(opType, columnName string, dataType uint32, value interface{}) *utils.CDCMessage {
+func createCDCMessage(opType utils.OperationType, columnName string, dataType uint32, value interface{}) *utils.CDCMessage {
 	return &utils.CDCMessage{
 		Type: opType,
 		Columns: []*pglogrepl.RelationMessageColumn{
diff --git a/pkg/sinks/postgres.go b/pkg/sinks/postgres.go
index dd29741..5abe550 100644
--- a/pkg/sinks/postgres.go
+++ b/pkg/sinks/postgres.go
@@ -120,61 +120,157 @@ func (s *PostgresSink) handleInsert(tx pgx.Tx, message *utils.CDCMessage) error
 	values := make([]interface{}, 0, len(message.Columns))

 	for i, col := range message.Columns {
-		columns = append(columns, col.Name)
-		placeholders = append(placeholders, fmt.Sprintf("$%d", i+1))
-		value, err := message.GetDecodedColumnValue(col.Name)
+		value, err := message.GetColumnValue(col.Name, false)
 		if err != nil {
-			return fmt.Errorf("failed to get column value: %v", err)
+			return fmt.Errorf("failed to get column value: %w", err)
 		}
+		columns = append(columns, col.Name)
+		placeholders = append(placeholders, fmt.Sprintf("$%d", i+1))
 		values = append(values, value)
 	}

-	query := fmt.Sprintf("INSERT INTO %s.%s (%s) VALUES (%s) ON CONFLICT (%s) DO NOTHING",
+	query := fmt.Sprintf(
+		"INSERT INTO %s.%s (%s) VALUES (%s)",
 		message.Schema,
 		message.Table,
-		strings.Join(columns, ","),
-		strings.Join(placeholders, ","),
-		message.PrimaryKeyColumn,
+		strings.Join(columns, ", "),
+		strings.Join(placeholders, ", "),
 	)

+	// Add ON CONFLICT clause for PK/UNIQUE keys
+	if message.ReplicationKey.Type != utils.ReplicationKeyFull && len(message.ReplicationKey.Columns) > 0 {
+		query += fmt.Sprintf(" ON CONFLICT (%s) DO NOTHING",
+			strings.Join(message.ReplicationKey.Columns, ", "))
+	}
+
 	_, err := tx.Exec(context.Background(), query, values...)
-	return err
+	if err != nil {
+		return fmt.Errorf("insert failed: %w", err)
+	}
+	return nil
+}
+
+// getWhereConditions builds WHERE clause conditions based on the replication key type
+func getWhereConditions(message *utils.CDCMessage, useOldValues bool, startingIndex int) ([]string, []interface{}, error) {
+	var conditions []string
+	var values []interface{}
+	valueIndex := startingIndex
+
+	switch message.ReplicationKey.Type {
+	case utils.ReplicationKeyFull:
+		// For FULL, use all non-null values
+		for _, col := range message.Columns {
+			value, err := message.GetColumnValue(col.Name, useOldValues)
+			if err != nil {
+				continue // Skip columns with errors
+			}
+			if value == nil {
+				conditions = append(conditions, fmt.Sprintf("%s IS NULL", col.Name))
+			} else {
+				conditions = append(conditions, fmt.Sprintf("%s = $%d", col.Name, valueIndex))
+				values = append(values, value)
+				valueIndex++
+			}
+		}
+	case utils.ReplicationKeyPK, utils.ReplicationKeyUnique:
+		// For PK/UNIQUE, use only the key columns
+		for _, colName := range message.ReplicationKey.Columns {
+			value, err := message.GetColumnValue(colName, useOldValues)
+			if err != nil {
+				return nil, nil, fmt.Errorf("failed to get value for key column %s: %w", colName, err)
+			}
+			if value == nil {
+				conditions = append(conditions, fmt.Sprintf("%s IS NULL", colName))
+			} else {
+				conditions = append(conditions, fmt.Sprintf("%s = $%d", colName, valueIndex))
+				values = append(values, value)
+				valueIndex++
+			}
+		}
+	}
+
+	if len(conditions) == 0 {
+		return nil, nil, fmt.Errorf("no valid conditions generated for WHERE clause")
+	}
+
+	return conditions, values, nil
+}
+
+// handleDelete processes a delete operation
+func (s *PostgresSink) handleDelete(tx pgx.Tx, message *utils.CDCMessage) error {
+	if !message.ReplicationKey.IsValid() {
+		return fmt.Errorf("invalid replication key configuration")
+	}
+
+	startingIndex := 1
+	whereConditions, whereValues, err := getWhereConditions(message, true, startingIndex)
+	if err != nil {
+		return fmt.Errorf("failed to build WHERE conditions: %w", err)
+	}
+
+	query := fmt.Sprintf("DELETE FROM %s.%s WHERE %s",
+		message.Schema,
+		message.Table,
+		strings.Join(whereConditions, " AND "),
+	)
+
+	result, err := tx.Exec(context.Background(), query, whereValues...)
+	if err != nil {
+		return fmt.Errorf("delete failed: %w", err)
+	}
+
+	if result.RowsAffected() == 0 {
+		log.Warn().
+			Str("table", message.Table).
+			Str("query", query).
+			Interface("values", whereValues).
+ Msg("Delete affected 0 rows") + } + + return nil } // handleUpdate processes an update operation func (s *PostgresSink) handleUpdate(tx pgx.Tx, message *utils.CDCMessage) error { + if !message.ReplicationKey.IsValid() { + return fmt.Errorf("invalid replication key configuration") + } + setClauses := make([]string, 0, len(message.Columns)) - values := make([]interface{}, 0, len(message.Columns)) - whereConditions := make([]string, 0) + setValues := make([]interface{}, 0, len(message.Columns)) valueIndex := 1 - for _, col := range message.Columns { - if message.IsColumnToasted(col.Name) { - // Skip TOAST columns that haven't changed + for _, column := range message.Columns { + // Skip toasted columns + if message.IsColumnToasted(column.Name) { continue } - newValue, err := message.GetColumnValue(col.Name) + // Get the new value for the column + value, err := message.GetColumnValue(column.Name, false) if err != nil { - return fmt.Errorf("failed to get column value: %v", err) + return fmt.Errorf("failed to get column value: %w", err) } - setClauses = append(setClauses, fmt.Sprintf("%s = $%d", col.Name, valueIndex)) - values = append(values, newValue) - valueIndex++ - - if col.Name == message.PrimaryKeyColumn { - whereConditions = append(whereConditions, fmt.Sprintf("%s = $%d", col.Name, valueIndex)) - values = append(values, newValue) // Use the same value for the WHERE clause + if value == nil { + setClauses = append(setClauses, fmt.Sprintf("%s = NULL", column.Name)) + } else { + setClauses = append(setClauses, fmt.Sprintf("%s = $%d", column.Name, valueIndex)) + setValues = append(setValues, value) valueIndex++ } } if len(setClauses) == 0 { - // If there are no columns to update (all were TOAST and unchanged), we can skip this update + log.Debug().Msg("No columns to update, skipping") return nil } + whereConditions, whereValues, err := getWhereConditions(message, true, valueIndex) + if err != nil { + return fmt.Errorf("failed to build WHERE conditions: %w", err) + } + query := fmt.Sprintf( "UPDATE %s.%s SET %s WHERE %s", message.Schema, @@ -183,30 +279,22 @@ func (s *PostgresSink) handleUpdate(tx pgx.Tx, message *utils.CDCMessage) error strings.Join(whereConditions, " AND "), ) - _, err := tx.Exec(context.Background(), query, values...) + values := append(setValues, whereValues...) + result, err := tx.Exec(context.Background(), query, values...) if err != nil { - return fmt.Errorf("failed to execute update query: %v", err) - } - - return nil -} - -// handleDelete processes a delete operation -func (s *PostgresSink) handleDelete(tx pgx.Tx, message *utils.CDCMessage) error { - if message.PrimaryKeyColumn == "" { - return fmt.Errorf("primary key column not specified in the message") + return fmt.Errorf("update failed for table %s.%s: %w (query: %s, values: %v)", + message.Schema, message.Table, err, query, values) } - pkValue, err := message.GetColumnValue(message.PrimaryKeyColumn) - if err != nil { - return fmt.Errorf("failed to get primary key value: %v", err) + if result.RowsAffected() == 0 { + log.Warn(). + Str("table", message.Table). + Str("query", query). + Interface("values", values). 
+ Msg("Update affected 0 rows") } - query := fmt.Sprintf("DELETE FROM %s.%s WHERE %s = $1", - message.Schema, message.Table, message.PrimaryKeyColumn) - - _, err = tx.Exec(context.Background(), query, pkValue) - return err + return nil } // handleDDL processes a DDL operation @@ -222,7 +310,7 @@ func (s *PostgresSink) handleDDL(tx pgx.Tx, message *utils.CDCMessage) (pgx.Tx, tx = newTx } - ddlCommand, err := message.GetColumnValue("ddl_command") + ddlCommand, err := message.GetColumnValue("ddl_command", false) if err != nil { return tx, fmt.Errorf("failed to get DDL command: %v", err) } @@ -304,9 +392,6 @@ func (s *PostgresSink) writeBatchInternal(ctx context.Context, messages []*utils if s.disableForeignKeyChecks { if err := s.disableForeignKeys(ctx); err != nil { - if rollbackErr := tx.Rollback(ctx); rollbackErr != nil { - log.Error().Err(rollbackErr).Msg("failed to rollback transaction") - } return fmt.Errorf("failed to disable foreign key checks: %v", err) } defer func() { @@ -317,10 +402,6 @@ func (s *PostgresSink) writeBatchInternal(ctx context.Context, messages []*utils } for _, message := range messages { - primaryKeyColumn := message.MappedPrimaryKeyColumn - if primaryKeyColumn != "" { - message.PrimaryKeyColumn = message.MappedPrimaryKeyColumn - } var operationErr error err := utils.WithRetry(ctx, s.retryConfig, func() error { @@ -328,7 +409,6 @@ func (s *PostgresSink) writeBatchInternal(ctx context.Context, messages []*utils if err := s.connect(ctx); err != nil { return fmt.Errorf("failed to reconnect to database: %v", err) } - // Start a new transaction if needed if tx == nil { newTx, err := s.conn.Begin(ctx) if err != nil { @@ -339,18 +419,18 @@ func (s *PostgresSink) writeBatchInternal(ctx context.Context, messages []*utils } switch message.Type { - case "INSERT": + case utils.OperationInsert: operationErr = s.handleInsert(tx, message) - case "UPDATE": + case utils.OperationUpdate: operationErr = s.handleUpdate(tx, message) - case "DELETE": + case utils.OperationDelete: operationErr = s.handleDelete(tx, message) - case "DDL": + case utils.OperationDDL: var newTx pgx.Tx newTx, operationErr = s.handleDDL(tx, message) tx = newTx default: - operationErr = fmt.Errorf("unknown event type: %s", message.Type) + operationErr = fmt.Errorf("unknown operation type: %s", message.Type) } if operationErr != nil && isConnectionError(operationErr) { @@ -366,7 +446,13 @@ func (s *PostgresSink) writeBatchInternal(ctx context.Context, messages []*utils } } tx = nil - return fmt.Errorf("failed to handle %s: %v-%v", message.Type, err, operationErr) + return fmt.Errorf("failed to handle %s for table %s.%s: %v-%v", + message.Type, + message.Schema, + message.Table, + err, + operationErr, + ) } } diff --git a/pkg/utils/cdc_message.go b/pkg/utils/cdc_message.go index 56207da..05fea7b 100644 --- a/pkg/utils/cdc_message.go +++ b/pkg/utils/cdc_message.go @@ -29,17 +29,16 @@ func init() { // CDCMessage represents a full message for Change Data Capture type CDCMessage struct { - Type string - Schema string - Table string - Columns []*pglogrepl.RelationMessageColumn - NewTuple *pglogrepl.TupleData - OldTuple *pglogrepl.TupleData - PrimaryKeyColumn string - LSN string - EmittedAt time.Time - ToastedColumns map[string]bool - MappedPrimaryKeyColumn string + Type OperationType + Schema string + Table string + Columns []*pglogrepl.RelationMessageColumn + NewTuple *pglogrepl.TupleData + OldTuple *pglogrepl.TupleData + ReplicationKey ReplicationKey + LSN string + EmittedAt time.Time + ToastedColumns map[string]bool 
 }

 // MarshalBinary implements the encoding.BinaryMarshaler interface
@@ -66,23 +65,23 @@ func (m *CDCMessage) GetColumnIndex(columnName string) int {
 	return -1
 }

-// GetColumnValue returns the typed value of a column
-func (m *CDCMessage) GetColumnValue(columnName string) (interface{}, error) {
+// GetColumnValue gets a column value, optionally using old values for DELETE/UPDATE
+func (m *CDCMessage) GetColumnValue(columnName string, useOldValues bool) (interface{}, error) {
 	colIndex := m.GetColumnIndex(columnName)
 	if colIndex == -1 {
 		return nil, fmt.Errorf("column %s not found", columnName)
 	}

-	column := m.Columns[colIndex]
 	var data []byte
-
-	if m.Type == "DELETE" {
+	if useOldValues && m.OldTuple != nil {
 		data = m.OldTuple.Columns[colIndex].Data
-	} else {
+	} else if m.NewTuple != nil {
 		data = m.NewTuple.Columns[colIndex].Data
+	} else {
+		return nil, fmt.Errorf("no data available for column %s", columnName)
 	}

-	return DecodeValue(data, column.DataType)
+	return DecodeValue(data, m.Columns[colIndex].DataType)
 }

 // SetColumnValue sets the value of a column, respecting its type
@@ -98,7 +97,7 @@ func (m *CDCMessage) SetColumnValue(columnName string, value interface{}) error
 		return err
 	}

-	if m.Type == "DELETE" {
+	if m.Type == OperationDelete {
 		m.OldTuple.Columns[colIndex] = &pglogrepl.TupleDataColumn{Data: encodedValue}
 	} else {
 		m.NewTuple.Columns[colIndex] = &pglogrepl.TupleDataColumn{Data: encodedValue}
@@ -144,7 +143,7 @@ func EncodeCDCMessage(m CDCMessage) ([]byte, error) {
 		}
 	}

-	if err := enc.Encode(m.PrimaryKeyColumn); err != nil {
+	if err := enc.Encode(m.ReplicationKey); err != nil {
 		return nil, err
 	}

@@ -204,7 +203,7 @@ func DecodeCDCMessage(data []byte) (*CDCMessage, error) {
 		}
 	}

-	if err := dec.Decode(&m.PrimaryKeyColumn); err != nil {
+	if err := dec.Decode(&m.ReplicationKey); err != nil {
 		return nil, err
 	}

@@ -370,7 +369,7 @@ func (m *CDCMessage) GetDecodedMessage() (map[string]interface{}, error) {
 	decodedMessage["Type"] = m.Type
 	decodedMessage["Schema"] = m.Schema
 	decodedMessage["Table"] = m.Table
-	decodedMessage["PrimaryKeyColumn"] = m.PrimaryKeyColumn
+	decodedMessage["ReplicationKey"] = m.ReplicationKey
 	decodedMessage["LSN"] = m.LSN
 	decodedMessage["EmittedAt"] = m.EmittedAt

@@ -405,3 +404,69 @@ func (m *CDCMessage) IsColumnToasted(columnName string) bool {
 	return m.ToastedColumns[columnName]
 }
+
+// GetOldColumnValue returns the value from OldTuple
+func (m *CDCMessage) GetOldColumnValue(columnName string) (interface{}, error) {
+	colIndex := m.GetColumnIndex(columnName)
+	if colIndex == -1 {
+		return nil, fmt.Errorf("column %s not found", columnName)
+	}
+
+	if m.OldTuple == nil || colIndex >= len(m.OldTuple.Columns) {
+		return nil, fmt.Errorf("no old value available for column %s", columnName)
+	}
+
+	column := m.Columns[colIndex]
+	data := m.OldTuple.Columns[colIndex].Data
+
+	// Decode the raw bytes using the column's declared data type
+	decodedValue, err := DecodeValue(data, column.DataType)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode old value for column %s: %v", columnName, err)
+	}
+
+	return decodedValue, nil
+}
+
+// IsPrimaryKeyColumn returns true if the column is part of the primary key
+func (m *CDCMessage) IsPrimaryKeyColumn(columnName string) bool {
+	if m.ReplicationKey.Type != ReplicationKeyPK {
+		return false
+	}
+	for _, col := range m.ReplicationKey.Columns {
+		if col == columnName {
+			return true
+		}
+	}
+	return false
+}
+
+// IsReplicationKeyColumn returns true if the column is part of the replication key
+func (m *CDCMessage) IsReplicationKeyColumn(columnName string) bool {
+	if m.ReplicationKey.Type == ReplicationKeyFull {
+		return true // All columns are part of replication key for FULL
+	}
+	for _, col := range m.ReplicationKey.Columns {
+		if col == columnName {
+			return true
+		}
+	}
+	return false
+}
+
+// GetChangedColumns returns columns that have different values between old and new tuples
+func (m *CDCMessage) GetChangedColumns() []string {
+	if m.Type != OperationUpdate || m.OldTuple == nil || m.NewTuple == nil {
+		return nil
+	}
+
+	var changed []string
+	for i, col := range m.Columns {
+		oldVal := m.OldTuple.Columns[i].Data
+		newVal := m.NewTuple.Columns[i].Data
+		if !bytes.Equal(oldVal, newVal) {
+			changed = append(changed, col.Name)
+		}
+	}
+	return changed
+}
diff --git a/pkg/utils/shared.go b/pkg/utils/shared.go
index 5be3da5..6395802 100644
--- a/pkg/utils/shared.go
+++ b/pkg/utils/shared.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"reflect"
 	"strconv"
+	"strings"
 	"time"

 	"github.com/jackc/pgtype"
@@ -140,3 +141,21 @@ func ToBool(v interface{}) (bool, bool) {
 	}
 	return false, false
 }
+
+// IsValid checks if the replication key is properly configured
+func (rk *ReplicationKey) IsValid() bool {
+	if rk.Type == ReplicationKeyFull {
+		return true // FULL doesn't require specific columns
+	}
+
+	return len(rk.Columns) > 0 &&
+		(rk.Type == ReplicationKeyPK || rk.Type == ReplicationKeyUnique)
+}
+
+// String returns a string representation of the replication key
+func (rk ReplicationKey) String() string {
+	if rk.Type == ReplicationKeyFull {
+		return "FULL"
+	}
+	return fmt.Sprintf("%s (%s)", strings.Join(rk.Columns, ", "), rk.Type)
+}
diff --git a/pkg/utils/shared_types.go b/pkg/utils/shared_types.go
index d1e30c4..71caba8 100644
--- a/pkg/utils/shared_types.go
+++ b/pkg/utils/shared_types.go
@@ -7,4 +7,20 @@ const (
 	OperationInsert OperationType = "INSERT"
 	OperationUpdate OperationType = "UPDATE"
 	OperationDelete OperationType = "DELETE"
+	OperationDDL    OperationType = "DDL"
 )
+
+// ReplicationKeyType represents the type of replication key
+type ReplicationKeyType string
+
+const (
+	ReplicationKeyPK     ReplicationKeyType = "PRIMARY KEY"
+	ReplicationKeyUnique ReplicationKeyType = "UNIQUE"
+	ReplicationKeyFull   ReplicationKeyType = "FULL" // Replica identity full
+)
+
+// ReplicationKey represents the key used to identify rows during replication
+// (primary key, unique constraint, or the full row for REPLICA IDENTITY FULL)
+type ReplicationKey struct {
+	Type    ReplicationKeyType
+	Columns []string
+}
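
Reviewer note (illustrative, not part of the patch): the sketch below shows the WHERE-clause shapes that `getWhereConditions` in `pkg/sinks/postgres.go` produces for each `ReplicationKeyType`. The standalone `buildWhere` helper, the `public.orders` table, and the row values are hypothetical stand-ins chosen only to show the placeholder numbering and the `IS NULL` handling; the real implementation reads values out of the `CDCMessage` old/new tuples rather than a map.

```go
package main

import (
	"fmt"
	"strings"
)

// Mirrors utils.ReplicationKeyType from the patch.
type replicationKeyType string

const (
	keyPK     replicationKeyType = "PRIMARY KEY"
	keyUnique replicationKeyType = "UNIQUE"
	keyFull   replicationKeyType = "FULL"
)

// buildWhere mimics getWhereConditions: each key column becomes a "$n"
// placeholder, and NULL values become "col IS NULL" so a row whose identity
// contains NULLs can still be matched.
func buildWhere(kt replicationKeyType, keyCols []string, row map[string]interface{}) (string, []interface{}) {
	var conds []string
	var args []interface{}
	n := 1
	for _, c := range keyCols {
		if row[c] == nil {
			conds = append(conds, fmt.Sprintf("%s IS NULL", c))
			continue
		}
		conds = append(conds, fmt.Sprintf("%s = $%d", c, n))
		args = append(args, row[c])
		n++
	}
	return strings.Join(conds, " AND "), args
}

func main() {
	row := map[string]interface{}{"order_id": "o-1", "customer_id": "c-9", "data": nil}

	// UNIQUE (or PK) key on (order_id, customer_id): only the key columns
	// appear in the WHERE clause, matching the ReplicationKeyPK/Unique branch.
	where, args := buildWhere(keyUnique, []string{"order_id", "customer_id"}, row)
	fmt.Printf("DELETE FROM public.orders WHERE %s -- args: %v\n", where, args)
	// DELETE FROM public.orders WHERE order_id = $1 AND customer_id = $2 -- args: [o-1 c-9]

	// REPLICA IDENTITY FULL: every column participates, including NULLs.
	where, args = buildWhere(keyFull, []string{"order_id", "customer_id", "data"}, row)
	fmt.Printf("DELETE FROM public.orders WHERE %s -- args: %v\n", where, args)
	// DELETE FROM public.orders WHERE order_id = $1 AND customer_id = $2 AND data IS NULL -- args: [o-1 c-9]
}
```

The design choice this illustrates: with a `PRIMARY KEY` or `UNIQUE` replication key only the key columns identify the row, while `FULL` matches on every column, and `IS NULL` conditions stand in for NULL values because `col = NULL` never matches in SQL.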