Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) SQL to pgroll operation #507

Merged
merged 4 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions pkg/sql2pgroll/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,30 @@ func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, err
op, err = convertAlterTableSetNotNull(stmt, alterTableCmd, false)
case pgq.AlterTableType_AT_AlterColumnType:
op, err = convertAlterTableAlterColumnType(stmt, alterTableCmd)
case pgq.AlterTableType_AT_AddConstraint:
op, err = convertAlterTableAddConstraint(stmt, alterTableCmd)
}

if err != nil {
return nil, err
}

if op == nil {
return nil, nil
}

ops = append(ops, op)
}

return ops, nil
}

// convertAlterTableSetNotNull converts SQL statements like:
//
// `ALTER TABLE foo ALTER COLUMN a SET NOT NULL`
// `ALTER TABLE foo ALTER COLUMN a DROP NOT NULL`
//
// to an OpAlterColumn operation.
func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd, notNull bool) (migrations.Operation, error) {
return &migrations.OpAlterColumn{
Table: stmt.GetRelation().GetRelname(),
Expand All @@ -55,6 +67,11 @@ func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCm
}, nil
}

// convertAlterTableAlterColumnType converts a SQL statement like:
//
// `ALTER TABLE foo ALTER COLUMN a SET DATA TYPE text`
//
// to an OpAlterColumn operation.
func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) {
node, ok := cmd.GetDef().Node.(*pgq.Node_ColumnDef)
if !ok {
Expand All @@ -70,6 +87,89 @@ func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTa
}, nil
}

// convertAlterTableAddConstraint converts SQL statements like:
//
// `ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)`
//
// To an OpCreateConstraint operation.
func convertAlterTableAddConstraint(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) {
node, ok := cmd.GetDef().Node.(*pgq.Node_Constraint)
if !ok {
return nil, fmt.Errorf("expected constraint definition, got %T", cmd.GetDef().Node)
}

var op migrations.Operation
var err error
switch node.Constraint.GetContype() {
case pgq.ConstrType_CONSTR_UNIQUE:
op, err = convertAlterTableAddUniqueConstraint(stmt, node.Constraint)
default:
return nil, nil
}

if err != nil {
return nil, err
}

return op, nil
}

// convertAlterTableAddUniqueConstraint converts SQL statements like:
//
// `ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)`
//
// to an OpCreateConstraint operation.
func convertAlterTableAddUniqueConstraint(stmt *pgq.AlterTableStmt, constraint *pgq.Constraint) (migrations.Operation, error) {
if !canConvertUniqueConstraint(constraint) {
return nil, nil
}

// Extract the columns covered by the unique constraint
columns := make([]string, 0, len(constraint.GetKeys()))
for _, keyNode := range constraint.GetKeys() {
key, ok := keyNode.Node.(*pgq.Node_String_)
if !ok {
return nil, fmt.Errorf("expected string key, got %T", keyNode)
}
columns = append(columns, key.String_.GetSval())
}

// Build the up and down SQL placeholders for each column covered by the
// constraint
upDown := make(map[string]string, len(columns))
for _, column := range columns {
upDown[column] = PlaceHolderSQL
}

return &migrations.OpCreateConstraint{
Type: migrations.OpCreateConstraintTypeUnique,
Name: constraint.GetConname(),
Table: stmt.GetRelation().GetRelname(),
Columns: columns,
Down: upDown,
Up: upDown,
}, nil
}

// canConvertUniqueConstraint checks if the unique constraint `constraint` can
// be faithfully converted to an OpCreateConstraint operation without losing
// information.
func canConvertUniqueConstraint(constraint *pgq.Constraint) bool {
if constraint.GetNullsNotDistinct() {
return false
}
if len(constraint.GetIncluding()) > 0 {
return false
}
if len(constraint.GetOptions()) > 0 {
return false
}
if constraint.GetIndexspace() != "" {
return false
}
return true
}

func ptr[T any](x T) *T {
return &x
}
35 changes: 32 additions & 3 deletions pkg/sql2pgroll/alter_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ func TestConvertAlterTableStatements(t *testing.T) {
sql: "ALTER TABLE foo ALTER COLUMN a TYPE text",
expectedOp: expect.AlterTableOp3,
},
{
sql: "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)",
expectedOp: expect.AlterTableOp4,
},
{
sql: "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a, b)",
expectedOp: expect.AlterTableOp5,
},
}

for _, tc := range tests {
Expand All @@ -44,10 +52,31 @@ func TestConvertAlterTableStatements(t *testing.T) {

require.Len(t, ops, 1)

alterColumnOps, ok := ops[0].(*migrations.OpAlterColumn)
require.True(t, ok)
assert.Equal(t, tc.expectedOp, ops[0])
})
}
}

func TestUnconvertableAlterTableAddConstraintStatements(t *testing.T) {
t.Parallel()

tests := []string{
// UNIQUE constraints with various options that are not representable by
// `OpCreateConstraint` operations
"ALTER TABLE foo ADD CONSTRAINT bar UNIQUE NULLS NOT DISTINCT (a)",
"ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) INCLUDE (b)",
"ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) WITH (fillfactor=70)",
"ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) USING INDEX TABLESPACE baz",
}

for _, sql := range tests {
t.Run(sql, func(t *testing.T) {
ops, err := sql2pgroll.Convert(sql)
require.NoError(t, err)

require.Len(t, ops, 1)

assert.Equal(t, tc.expectedOp, alterColumnOps)
assert.Equal(t, expect.RawSQLOp(sql), ops[0])
})
}
}
2 changes: 1 addition & 1 deletion pkg/sql2pgroll/create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

// convertCreateStmt converts a CREATE TABLE statement to a pgroll operation.
func convertCreateStmt(stmt *pgq.CreateStmt) ([]migrations.Operation, error) {
func convertCreateStmt(stmt *pgq.CreateStmt) (migrations.Operations, error) {
columns := make([]migrations.Column, 0, len(stmt.TableElts))
for _, elt := range stmt.TableElts {
columns = append(columns, convertColumnDef(elt.GetColumnDef()))
Expand Down
24 changes: 24 additions & 0 deletions pkg/sql2pgroll/expect/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,30 @@ var AlterTableOp3 = &migrations.OpAlterColumn{
Down: sql2pgroll.PlaceHolderSQL,
}

var AlterTableOp4 = &migrations.OpCreateConstraint{
Type: migrations.OpCreateConstraintTypeUnique,
Name: "bar",
Table: "foo",
Columns: []string{"a"},
Down: map[string]string{"a": sql2pgroll.PlaceHolderSQL},
Up: map[string]string{"a": sql2pgroll.PlaceHolderSQL},
}

var AlterTableOp5 = &migrations.OpCreateConstraint{
Type: migrations.OpCreateConstraintTypeUnique,
Name: "bar",
Table: "foo",
Columns: []string{"a", "b"},
Down: map[string]string{
"a": sql2pgroll.PlaceHolderSQL,
"b": sql2pgroll.PlaceHolderSQL,
},
Up: map[string]string{
"a": sql2pgroll.PlaceHolderSQL,
"b": sql2pgroll.PlaceHolderSQL,
},
}

func ptr[T any](v T) *T {
return &v
}
9 changes: 9 additions & 0 deletions pkg/sql2pgroll/expect/raw_sql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// SPDX-License-Identifier: Apache-2.0

package expect

import "github.com/xataio/pgroll/pkg/migrations"

func RawSQLOp(sql string) *migrations.OpRawSQL {
return &migrations.OpRawSQL{Up: sql}
}
Loading