diff --git a/pkg/sql2pgroll/alter_table.go b/pkg/sql2pgroll/alter_table.go index 4a51c18e..33a7df8f 100644 --- a/pkg/sql2pgroll/alter_table.go +++ b/pkg/sql2pgroll/alter_table.go @@ -3,6 +3,8 @@ package sql2pgroll import ( + "fmt" + pgq "github.com/pganalyze/pg_query_go/v6" "github.com/xataio/pgroll/pkg/migrations" ) @@ -22,25 +24,50 @@ func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, err continue } - switch alterTableCmd.Subtype { + var op migrations.Operation + var err error + switch alterTableCmd.GetSubtype() { case pgq.AlterTableType_AT_SetNotNull: - ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd, true)) + op, err = convertAlterTableSetNotNull(stmt, alterTableCmd, true) case pgq.AlterTableType_AT_DropNotNull: - ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd, false)) + op, err = convertAlterTableSetNotNull(stmt, alterTableCmd, false) + case pgq.AlterTableType_AT_AlterColumnType: + op, err = convertAlterTableAlterColumnType(stmt, alterTableCmd) + } + + if err != nil { + return nil, err } + + ops = append(ops, op) } return ops, nil } -func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd, notNull bool) migrations.Operation { +func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd, notNull bool) (migrations.Operation, error) { return &migrations.OpAlterColumn{ Table: stmt.GetRelation().GetRelname(), Column: cmd.GetName(), Nullable: ptr(!notNull), Up: PlaceHolderSQL, Down: PlaceHolderSQL, + }, nil +} + +func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) { + node, ok := cmd.GetDef().Node.(*pgq.Node_ColumnDef) + if !ok { + return nil, fmt.Errorf("expected column definition, got %T", cmd.GetDef().Node) } + + return &migrations.OpAlterColumn{ + Table: stmt.GetRelation().GetRelname(), + Column: cmd.GetName(), + Type: ptr(convertTypeName(node.ColumnDef.GetTypeName())), + Up: PlaceHolderSQL, + Down: PlaceHolderSQL, + }, nil } func ptr[T any](x T) *T { diff --git a/pkg/sql2pgroll/alter_table_test.go b/pkg/sql2pgroll/alter_table_test.go index 3def8c8f..1cdb6521 100644 --- a/pkg/sql2pgroll/alter_table_test.go +++ b/pkg/sql2pgroll/alter_table_test.go @@ -27,6 +27,14 @@ func TestConvertAlterTableStatements(t *testing.T) { sql: "ALTER TABLE foo ALTER COLUMN a DROP NOT NULL", expectedOp: expect.AlterTableOp2, }, + { + sql: "ALTER TABLE foo ALTER COLUMN a SET DATA TYPE text", + expectedOp: expect.AlterTableOp3, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN a TYPE text", + expectedOp: expect.AlterTableOp3, + }, } for _, tc := range tests { diff --git a/pkg/sql2pgroll/create_table.go b/pkg/sql2pgroll/create_table.go index 63604895..dbaeb504 100644 --- a/pkg/sql2pgroll/create_table.go +++ b/pkg/sql2pgroll/create_table.go @@ -3,9 +3,6 @@ package sql2pgroll import ( - "fmt" - "strings" - pgq "github.com/pganalyze/pg_query_go/v6" "github.com/xataio/pgroll/pkg/migrations" ) @@ -26,42 +23,8 @@ func convertCreateStmt(stmt *pgq.CreateStmt) ([]migrations.Operation, error) { } func convertColumnDef(col *pgq.ColumnDef) migrations.Column { - ignoredTypeParts := map[string]bool{ - "pg_catalog": true, - } - - // Build the type name, including any schema qualifiers - typeParts := make([]string, 0, len(col.GetTypeName().Names)) - for _, node := range col.GetTypeName().Names { - typePart := node.GetString_().GetSval() - if _, ok := ignoredTypeParts[typePart]; ok { - continue - } - typeParts = append(typeParts, typePart) - } - - // Build the type modifiers, such as precision and scale for numeric types - var typeMods []string - for _, node := range col.GetTypeName().Typmods { - if x, ok := node.GetAConst().Val.(*pgq.A_Const_Ival); ok { - typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval())) - } - } - var typeModifier string - if len(typeMods) > 0 { - typeModifier = fmt.Sprintf("(%s)", strings.Join(typeMods, ",")) - } - - // Build the array bounds for array types - var arrayBounds string - for _, node := range col.GetTypeName().ArrayBounds { - bound := node.GetInteger().GetIval() - if bound == -1 { - arrayBounds = "[]" - } else { - arrayBounds = fmt.Sprintf("%s[%d]", arrayBounds, bound) - } - } + // Convert the column type + typeString := convertTypeName(col.TypeName) // Determine column nullability, uniqueness, and primary key status var notNull, unique, pk bool @@ -81,7 +44,7 @@ func convertColumnDef(col *pgq.ColumnDef) migrations.Column { return migrations.Column{ Name: col.Colname, - Type: strings.Join(typeParts, ".") + typeModifier + arrayBounds, + Type: typeString, Nullable: !notNull, Unique: unique, Default: defaultValue, diff --git a/pkg/sql2pgroll/expect/alter_table.go b/pkg/sql2pgroll/expect/alter_table.go index a235c2db..4f7aceea 100644 --- a/pkg/sql2pgroll/expect/alter_table.go +++ b/pkg/sql2pgroll/expect/alter_table.go @@ -23,6 +23,14 @@ var AlterTableOp2 = &migrations.OpAlterColumn{ Down: sql2pgroll.PlaceHolderSQL, } +var AlterTableOp3 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "a", + Type: ptr("text"), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + func ptr[T any](v T) *T { return &v } diff --git a/pkg/sql2pgroll/typename.go b/pkg/sql2pgroll/typename.go new file mode 100644 index 00000000..f2c5a11d --- /dev/null +++ b/pkg/sql2pgroll/typename.go @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 + +package sql2pgroll + +import ( + "fmt" + "strings" + + pgq "github.com/pganalyze/pg_query_go/v6" +) + +// convertTypeName converts a TypeName node to a string. +func convertTypeName(typeName *pgq.TypeName) string { + ignoredTypeParts := map[string]bool{ + "pg_catalog": true, + } + + // Build the type name, including any schema qualifiers + typeParts := make([]string, 0, len(typeName.Names)) + for _, node := range typeName.Names { + typePart := node.GetString_().GetSval() + if _, ok := ignoredTypeParts[typePart]; ok { + continue + } + typeParts = append(typeParts, typePart) + } + + // Build the type modifiers, such as precision and scale for numeric types + var typeMods []string + for _, node := range typeName.Typmods { + if x, ok := node.GetAConst().Val.(*pgq.A_Const_Ival); ok { + typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval())) + } + } + var typeModifier string + if len(typeMods) > 0 { + typeModifier = fmt.Sprintf("(%s)", strings.Join(typeMods, ",")) + } + + // Build the array bounds for array types + var arrayBounds string + for _, node := range typeName.ArrayBounds { + bound := node.GetInteger().GetIval() + if bound == -1 { + arrayBounds = "[]" + } else { + arrayBounds = fmt.Sprintf("%s[%d]", arrayBounds, bound) + } + } + + return strings.Join(typeParts, ".") + typeModifier + arrayBounds +}