diff --git a/pkg/sql2pgroll/alter_table.go b/pkg/sql2pgroll/alter_table.go index 33a7df8f..c5b12b5b 100644 --- a/pkg/sql2pgroll/alter_table.go +++ b/pkg/sql2pgroll/alter_table.go @@ -33,18 +33,30 @@ func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, err op, err = convertAlterTableSetNotNull(stmt, alterTableCmd, false) case pgq.AlterTableType_AT_AlterColumnType: op, err = convertAlterTableAlterColumnType(stmt, alterTableCmd) + case pgq.AlterTableType_AT_AddConstraint: + op, err = convertAlterTableAddConstraint(stmt, alterTableCmd) } if err != nil { return nil, err } + if op == nil { + return nil, nil + } + ops = append(ops, op) } return ops, nil } +// convertAlterTableSetNotNull converts SQL statements like: +// +// `ALTER TABLE foo ALTER COLUMN a SET NOT NULL` +// `ALTER TABLE foo ALTER COLUMN a DROP NOT NULL` +// +// to an OpAlterColumn operation. func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd, notNull bool) (migrations.Operation, error) { return &migrations.OpAlterColumn{ Table: stmt.GetRelation().GetRelname(), @@ -55,6 +67,11 @@ func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCm }, nil } +// convertAlterTableAlterColumnType converts a SQL statement like: +// +// `ALTER TABLE foo ALTER COLUMN a SET DATA TYPE text` +// +// to an OpAlterColumn operation. func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) { node, ok := cmd.GetDef().Node.(*pgq.Node_ColumnDef) if !ok { @@ -70,6 +87,89 @@ func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTa }, nil } +// convertAlterTableAddConstraint converts SQL statements like: +// +// `ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)` +// +// To an OpCreateConstraint operation. +func convertAlterTableAddConstraint(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) { + node, ok := cmd.GetDef().Node.(*pgq.Node_Constraint) + if !ok { + return nil, fmt.Errorf("expected constraint definition, got %T", cmd.GetDef().Node) + } + + var op migrations.Operation + var err error + switch node.Constraint.GetContype() { + case pgq.ConstrType_CONSTR_UNIQUE: + op, err = convertAlterTableAddUniqueConstraint(stmt, node.Constraint) + default: + return nil, nil + } + + if err != nil { + return nil, err + } + + return op, nil +} + +// convertAlterTableAddUniqueConstraint converts SQL statements like: +// +// `ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)` +// +// to an OpCreateConstraint operation. +func convertAlterTableAddUniqueConstraint(stmt *pgq.AlterTableStmt, constraint *pgq.Constraint) (migrations.Operation, error) { + if !canConvertUniqueConstraint(constraint) { + return nil, nil + } + + // Extract the columns covered by the unique constraint + columns := make([]string, 0, len(constraint.GetKeys())) + for _, keyNode := range constraint.GetKeys() { + key, ok := keyNode.Node.(*pgq.Node_String_) + if !ok { + return nil, fmt.Errorf("expected string key, got %T", keyNode) + } + columns = append(columns, key.String_.GetSval()) + } + + // Build the up and down SQL placeholders for each column covered by the + // constraint + upDown := make(map[string]string, len(columns)) + for _, column := range columns { + upDown[column] = PlaceHolderSQL + } + + return &migrations.OpCreateConstraint{ + Type: migrations.OpCreateConstraintTypeUnique, + Name: constraint.GetConname(), + Table: stmt.GetRelation().GetRelname(), + Columns: columns, + Down: upDown, + Up: upDown, + }, nil +} + +// canConvertUniqueConstraint checks if the unique constraint `constraint` can +// be faithfully converted to an OpCreateConstraint operation without losing +// information. +func canConvertUniqueConstraint(constraint *pgq.Constraint) bool { + if constraint.GetNullsNotDistinct() { + return false + } + if len(constraint.GetIncluding()) > 0 { + return false + } + if len(constraint.GetOptions()) > 0 { + return false + } + if constraint.GetIndexspace() != "" { + return false + } + return true +} + func ptr[T any](x T) *T { return &x } diff --git a/pkg/sql2pgroll/alter_table_test.go b/pkg/sql2pgroll/alter_table_test.go index 1cdb6521..666661db 100644 --- a/pkg/sql2pgroll/alter_table_test.go +++ b/pkg/sql2pgroll/alter_table_test.go @@ -35,6 +35,14 @@ func TestConvertAlterTableStatements(t *testing.T) { sql: "ALTER TABLE foo ALTER COLUMN a TYPE text", expectedOp: expect.AlterTableOp3, }, + { + sql: "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)", + expectedOp: expect.AlterTableOp4, + }, + { + sql: "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a, b)", + expectedOp: expect.AlterTableOp5, + }, } for _, tc := range tests { @@ -44,10 +52,31 @@ func TestConvertAlterTableStatements(t *testing.T) { require.Len(t, ops, 1) - alterColumnOps, ok := ops[0].(*migrations.OpAlterColumn) - require.True(t, ok) + assert.Equal(t, tc.expectedOp, ops[0]) + }) + } +} + +func TestUnconvertableAlterTableAddConstraintStatements(t *testing.T) { + t.Parallel() + + tests := []string{ + // UNIQUE constraints with various options that are not representable by + // `OpCreateConstraint` operations + "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE NULLS NOT DISTINCT (a)", + "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) INCLUDE (b)", + "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) WITH (fillfactor=70)", + "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) USING INDEX TABLESPACE baz", + } + + for _, sql := range tests { + t.Run(sql, func(t *testing.T) { + ops, err := sql2pgroll.Convert(sql) + require.NoError(t, err) + + require.Len(t, ops, 1) - assert.Equal(t, tc.expectedOp, alterColumnOps) + assert.Equal(t, expect.RawSQLOp(sql), ops[0]) }) } } diff --git a/pkg/sql2pgroll/create_table.go b/pkg/sql2pgroll/create_table.go index dbaeb504..089b8d4a 100644 --- a/pkg/sql2pgroll/create_table.go +++ b/pkg/sql2pgroll/create_table.go @@ -8,7 +8,7 @@ import ( ) // convertCreateStmt converts a CREATE TABLE statement to a pgroll operation. -func convertCreateStmt(stmt *pgq.CreateStmt) ([]migrations.Operation, error) { +func convertCreateStmt(stmt *pgq.CreateStmt) (migrations.Operations, error) { columns := make([]migrations.Column, 0, len(stmt.TableElts)) for _, elt := range stmt.TableElts { columns = append(columns, convertColumnDef(elt.GetColumnDef())) diff --git a/pkg/sql2pgroll/expect/alter_table.go b/pkg/sql2pgroll/expect/alter_table.go index 4f7aceea..d2cccff6 100644 --- a/pkg/sql2pgroll/expect/alter_table.go +++ b/pkg/sql2pgroll/expect/alter_table.go @@ -31,6 +31,30 @@ var AlterTableOp3 = &migrations.OpAlterColumn{ Down: sql2pgroll.PlaceHolderSQL, } +var AlterTableOp4 = &migrations.OpCreateConstraint{ + Type: migrations.OpCreateConstraintTypeUnique, + Name: "bar", + Table: "foo", + Columns: []string{"a"}, + Down: map[string]string{"a": sql2pgroll.PlaceHolderSQL}, + Up: map[string]string{"a": sql2pgroll.PlaceHolderSQL}, +} + +var AlterTableOp5 = &migrations.OpCreateConstraint{ + Type: migrations.OpCreateConstraintTypeUnique, + Name: "bar", + Table: "foo", + Columns: []string{"a", "b"}, + Down: map[string]string{ + "a": sql2pgroll.PlaceHolderSQL, + "b": sql2pgroll.PlaceHolderSQL, + }, + Up: map[string]string{ + "a": sql2pgroll.PlaceHolderSQL, + "b": sql2pgroll.PlaceHolderSQL, + }, +} + func ptr[T any](v T) *T { return &v } diff --git a/pkg/sql2pgroll/expect/raw_sql.go b/pkg/sql2pgroll/expect/raw_sql.go new file mode 100644 index 00000000..715eaddb --- /dev/null +++ b/pkg/sql2pgroll/expect/raw_sql.go @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 + +package expect + +import "github.com/xataio/pgroll/pkg/migrations" + +func RawSQLOp(sql string) *migrations.OpRawSQL { + return &migrations.OpRawSQL{Up: sql} +}