Skip to content

Commit

Permalink
Only track DDLs for tables that are part of the publication
Browse files Browse the repository at this point in the history
This way, we don't accidentally end up tracking DDLs from other tables, including temporary tables
  • Loading branch information
shayonj committed Nov 12, 2024
1 parent cff3f38 commit a936549
Show file tree
Hide file tree
Showing 8 changed files with 365 additions and 156 deletions.
197 changes: 156 additions & 41 deletions internal/scripts/e2e_ddl.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@ set -euo pipefail

source "$(dirname "$0")/e2e_common.sh"

create_users() {
log "Creating initial test table..."
run_sql "DROP TABLE IF EXISTS public.users;"
run_sql "CREATE TABLE public.users (id serial PRIMARY KEY, data text);"
success "Initial test table created"
create_test_tables() {
  # Build a fresh two-schema fixture: the tracked tables live in "app",
  # while app.comments and public.metrics exist only so the test can prove
  # that DDL on untracked tables is ignored.
  log "Creating test schemas and tables..."

  local ddl
  for ddl in \
    "DROP SCHEMA IF EXISTS app CASCADE; CREATE SCHEMA app;" \
    "DROP SCHEMA IF EXISTS public CASCADE; CREATE SCHEMA public;" \
    "CREATE TABLE app.users (id serial PRIMARY KEY, data text);" \
    "CREATE TABLE app.posts (id serial PRIMARY KEY, content text);" \
    "CREATE TABLE app.comments (id serial PRIMARY KEY, text text);" \
    "CREATE TABLE public.metrics (id serial PRIMARY KEY, value numeric);"; do
    run_sql "$ddl"
  done

  success "Test tables created"
}

start_pg_flo_replication() {
Expand All @@ -23,8 +29,8 @@ start_pg_flo_replication() {
--user "$PG_USER" \
--password "$PG_PASSWORD" \
--group "group_ddl" \
--tables "users" \
--schema "public" \
--schema "app" \
--tables "users,posts" \
--nats-url "$NATS_URL" \
--track-ddl \
>"$pg_flo_LOG" 2>&1 &
Expand Down Expand Up @@ -61,60 +67,169 @@ start_pg_flo_worker() {

perform_ddl_operations() {
log "Performing DDL operations..."
run_sql "ALTER TABLE users ADD COLUMN new_column int;"
run_sql "CREATE INDEX CONCURRENTLY idx_users_data ON users (data);"
run_sql "ALTER TABLE users RENAME COLUMN data TO old_data;"
run_sql "DROP INDEX idx_users_data;"
run_sql "ALTER TABLE users ADD COLUMN new_column_one int;"
run_sql "ALTER TABLE users ALTER COLUMN old_data TYPE varchar(255);"

# Column operations on tracked tables
run_sql "ALTER TABLE app.users ADD COLUMN email text;"
run_sql "ALTER TABLE app.users ADD COLUMN status varchar(50) DEFAULT 'active';"
run_sql "ALTER TABLE app.posts ADD COLUMN category text;"

# Index operations on tracked tables
run_sql "CREATE INDEX CONCURRENTLY idx_users_email ON app.users (email);"
run_sql "CREATE UNIQUE INDEX idx_posts_unique ON app.posts (content) WHERE content IS NOT NULL;"

# Column modifications on tracked tables
run_sql "ALTER TABLE app.users ALTER COLUMN status SET DEFAULT 'pending';"
run_sql "ALTER TABLE app.posts ALTER COLUMN category TYPE varchar(100);"

# Rename operations on tracked tables
run_sql "ALTER TABLE app.users RENAME COLUMN data TO profile;"

# Drop operations on tracked tables
run_sql "DROP INDEX CONCURRENTLY IF EXISTS idx_users_email;"
run_sql "ALTER TABLE app.posts DROP COLUMN IF EXISTS category;"

# Operations on non-tracked tables (should be ignored)
run_sql "ALTER TABLE app.comments ADD COLUMN author text;"
run_sql "CREATE INDEX idx_comments_text ON app.comments (text);"
run_sql "ALTER TABLE public.metrics ADD COLUMN timestamp timestamptz;"

success "DDL operations performed"
}

verify_ddl_changes() {
log "Verifying DDL changes..."
log "Verifying DDL changes in target database..."
local failures=0

# Check table structure in target database
local new_column_exists=$(run_sql_target "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'users' AND column_name = 'new_column';")
local new_column_one_exists=$(run_sql_target "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'users' AND column_name = 'new_column_one';")
local old_data_type=$(run_sql_target "SELECT data_type FROM information_schema.columns WHERE table_name = 'users' AND column_name = 'old_data';")
old_data_type=$(echo "$old_data_type" | xargs)
check_column() {
local table=$1
local column=$2
local expected_exists=$3
local expected_type=${4:-""}
local expected_default=${5:-""}
local query="
SELECT COUNT(*),
data_type,
character_maximum_length,
column_default
FROM information_schema.columns
WHERE table_schema='app'
AND table_name='$table'
AND column_name='$column'
GROUP BY data_type, character_maximum_length, column_default;"

if [ "$new_column_exists" -eq 1 ]; then
success "new_column exists in target database"
else
error "new_column does not exist in target database"
return 1
fi
local result
result=$(run_sql_target "$query")

if [ "$new_column_one_exists" -eq 1 ]; then
success "new_column_one exists in target database"
else
error "new_column_one does not exist in target database"
return 1
fi
if [ -z "$result" ]; then
exists=0
data_type=""
char_length=""
default_value=""
else
read exists data_type char_length default_value < <(echo "$result" | tr '|' ' ')
fi

if [ "$old_data_type" = "character varying" ]; then
success "old_data column type is character varying"
else
error "old_data column type is not character varying (got: '$old_data_type')"
return 1
fi
exists=${exists:-0}

if [ "$exists" -eq "$expected_exists" ]; then
if [ "$expected_exists" -eq 1 ]; then
local type_ok=true
local default_ok=true

if [ -n "$expected_type" ]; then
# Handle character varying type specifically
if [ "$expected_type" = "character varying" ]; then
if [ "$data_type" = "character varying" ] || [ "$data_type" = "varchar" ] || [ "$data_type" = "character" ]; then
type_ok=true
else
type_ok=false
fi
elif [ "$data_type" != "$expected_type" ]; then
type_ok=false
fi
fi

if [ -n "$expected_default" ]; then
if [[ "$default_value" == *"$expected_default"* ]]; then
default_ok=true
else
default_ok=false
fi
fi

if [ "$type_ok" = true ] && [ "$default_ok" = true ]; then
if [[ "$expected_type" == "character varying" && -n "$char_length" ]]; then
success "Column app.$table.$column verification passed (type: $data_type($char_length), default: $default_value)"
else
success "Column app.$table.$column verification passed (type: $data_type, default: $default_value)"
fi
else
if [ "$type_ok" = false ]; then
error "Column app.$table.$column type mismatch (expected: $expected_type, got: $data_type)"
failures=$((failures + 1))
fi
if [ "$default_ok" = false ]; then
error "Column app.$table.$column default value mismatch (expected: $expected_default, got: $default_value)"
failures=$((failures + 1))
fi
fi
else
success "Column app.$table.$column verification passed (not exists)"
fi
else
error "Column app.$table.$column verification failed (expected: $expected_exists, got: $exists)"
failures=$((failures + 1))
fi
}

check_index() {
  # Verify that an index does (expected=1) or does not (expected=0) exist in
  # the target database's "app" schema. Increments the surrounding
  # verify_ddl_changes counter `failures` on mismatch or query error.
  # $1 - index name, $2 - expected count (0 or 1); extra args are ignored.
  local index=$1
  local expected=$2
  local exists

  # Separate declaration from assignment so a failing query is not masked
  # by `local` (SC2155) and can be reported instead of tripping set -e.
  if ! exists=$(run_sql_target "SELECT COUNT(*) FROM pg_indexes WHERE schemaname='app' AND indexname='$index';"); then
    error "Index app.$index verification failed (could not query target database)"
    failures=$((failures + 1))
    return
  fi
  # psql output may carry leading/trailing whitespace; strip it before the
  # numeric comparison.
  exists=${exists//[[:space:]]/}

  if [ "${exists:-0}" -eq "$expected" ]; then
    success "Index app.$index verification passed (expected: $expected)"
  else
    error "Index app.$index verification failed (expected: $expected, got: $exists)"
    failures=$((failures + 1))
  fi
}

# Verify app.users changes
check_column "users" "email" 1 "text"
check_column "users" "status" 1 "character varying" "'pending'"
check_column "users" "data" 0
check_column "users" "profile" 1 "text"

# Verify app.posts changes
check_column "posts" "category" 0
check_column "posts" "content" 1 "text"
check_index "idx_posts_unique" 1 "unique"

# Verify non-tracked tables
check_column "comments" "author" 0
check_index "idx_comments_text" 0

# Check if internal table is empty
local remaining_rows=$(run_sql "SELECT COUNT(*) FROM internal_pg_flo.ddl_log;")
if [ "$remaining_rows" -eq 0 ]; then
success "internal_pg_flo.ddl_log table is empty"
else
error "internal_pg_flo.ddl_log table is not empty. Remaining rows: $remaining_rows"
return 1
failures=$((failures + 1))
fi

return 0
if [ "$failures" -eq 0 ]; then
success "All DDL changes verified successfully"
return 0
else
error "DDL verification failed with $failures errors"
return 1
fi
}

test_pg_flo_ddl() {
setup_postgres
create_users
create_test_tables
start_pg_flo_worker
sleep 5
start_pg_flo_replication
Expand Down
6 changes: 3 additions & 3 deletions internal/scripts/e2e_test_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ make build

setup_docker

log "Running e2e routing tests..."
if CI=false ./internal/scripts/e2e_routing.sh; then
success "Original e2e tests completed successfully"
log "Running e2e ddl tests..."
if CI=false ./internal/scripts/e2e_ddl.sh; then
success "e2e ddl tests completed successfully"
else
error "Original e2e tests failed"
exit 1
Expand Down
82 changes: 63 additions & 19 deletions pkg/replicator/base_replicator.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ type BaseReplicator struct {

// NewBaseReplicator creates a new BaseReplicator instance
func NewBaseReplicator(config Config, replicationConn ReplicationConnection, standardConn StandardConnection, natsClient NATSClient) *BaseReplicator {
if config.Schema == "" {
config.Schema = "public"
}

logger := log.With().Str("component", "replicator").Logger()

br := &BaseReplicator{
Expand Down Expand Up @@ -78,6 +82,26 @@ func NewBaseReplicator(config Config, replicationConn ReplicationConnection, sta
return br
}

// buildCreatePublicationQuery constructs the SQL query for creating a
// publication limited to the configured tables, so DDL from unrelated
// tables (including temporary ones) is never tracked.
//
// It returns an error when table resolution fails or resolves to an empty
// list, because "CREATE PUBLICATION ... FOR TABLE" with no table names is
// invalid SQL and would previously have been emitted silently.
func (r *BaseReplicator) buildCreatePublicationQuery() (string, error) {
	publicationName := GeneratePublicationName(r.Config.Group)

	tables, err := r.GetConfiguredTables(context.Background())
	if err != nil {
		return "", fmt.Errorf("failed to get configured tables: %w", err)
	}
	if len(tables) == 0 {
		return "", fmt.Errorf("no tables resolved for publication %q", publicationName)
	}

	sanitizedTables := make([]string, len(tables))
	for i, table := range tables {
		// Tables are expected schema-qualified ("schema.table"). SplitN keeps
		// any further dots inside the table name, and the bare-name fallback
		// avoids indexing past the end of the split result.
		parts := strings.SplitN(table, ".", 2)
		if len(parts) == 1 {
			parts = []string{r.Config.Schema, parts[0]}
		}
		sanitizedTables[i] = pgx.Identifier{parts[0], parts[1]}.Sanitize()
	}

	return fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s",
		pgx.Identifier{publicationName}.Sanitize(),
		strings.Join(sanitizedTables, ", ")), nil
}

// CreatePublication creates a new publication if it doesn't exist
func (r *BaseReplicator) CreatePublication() error {
publicationName := GeneratePublicationName(r.Config.Group)
Expand All @@ -91,7 +115,11 @@ func (r *BaseReplicator) CreatePublication() error {
return nil
}

query := r.buildCreatePublicationQuery()
query, err := r.buildCreatePublicationQuery()
if err != nil {
return fmt.Errorf("failed to build publication query: %w", err)
}

_, err = r.StandardConn.Exec(context.Background(), query)
if err != nil {
return fmt.Errorf("failed to create publication: %w", err)
Expand All @@ -101,24 +129,6 @@ func (r *BaseReplicator) CreatePublication() error {
return nil
}

// buildCreatePublicationQuery constructs the SQL query for creating a publication
func (r *BaseReplicator) buildCreatePublicationQuery() string {
publicationName := GeneratePublicationName(r.Config.Group)
if len(r.Config.Tables) == 0 {
return fmt.Sprintf("CREATE PUBLICATION %s FOR ALL TABLES",
pgx.Identifier{publicationName}.Sanitize())
}

fullyQualifiedTables := make([]string, len(r.Config.Tables))
for i, table := range r.Config.Tables {
fullyQualifiedTables[i] = pgx.Identifier{r.Config.Schema, table}.Sanitize()
}

return fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s",
pgx.Identifier{publicationName}.Sanitize(),
strings.Join(fullyQualifiedTables, ", "))
}

// checkPublicationExists checks if a publication with the given name exists
func (r *BaseReplicator) checkPublicationExists(publicationName string) (bool, error) {
var exists bool
Expand Down Expand Up @@ -559,3 +569,37 @@ func (r *BaseReplicator) CheckReplicationSlotStatus(ctx context.Context) error {
r.Logger.Info().Str("slotName", publicationName).Str("restartLSN", restartLSN).Msg("Replication slot status")
return nil
}

// GetConfiguredTables returns the fully qualified ("schema.table") names of
// the tables this replicator should operate on.
//
// If specific tables are configured, each is qualified with the configured
// schema. Otherwise every table in the configured schema is discovered from
// pg_tables; system and internal schemas are excluded defensively even
// though the schema equality filter already constrains the result.
func (r *BaseReplicator) GetConfiguredTables(ctx context.Context) ([]string, error) {
	if len(r.Config.Tables) > 0 {
		fullyQualifiedTables := make([]string, len(r.Config.Tables))
		for i, table := range r.Config.Tables {
			fullyQualifiedTables[i] = fmt.Sprintf("%s.%s", r.Config.Schema, table)
		}
		return fullyQualifiedTables, nil
	}

	rows, err := r.StandardConn.Query(ctx, `
		SELECT schemaname || '.' || tablename
		FROM pg_tables
		WHERE schemaname = $1
		AND schemaname NOT IN ('pg_catalog', 'information_schema', 'internal_pg_flo')
	`, r.Config.Schema)
	if err != nil {
		return nil, fmt.Errorf("failed to query tables: %w", err)
	}
	defer rows.Close()

	var tables []string
	for rows.Next() {
		var tableName string
		if err := rows.Scan(&tableName); err != nil {
			return nil, fmt.Errorf("failed to scan table name: %w", err)
		}
		tables = append(tables, tableName)
	}
	// rows.Err surfaces errors that terminated iteration early (e.g. a lost
	// connection), which rows.Next alone silently swallows.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate tables: %w", err)
	}

	return tables, nil
}
Loading

0 comments on commit a936549

Please sign in to comment.