From 84661eb0390a557ec502a2df40fcf87ed1f5a0f6 Mon Sep 17 00:00:00 2001
From: Andrew Farries <andyrb@gmail.com>
Date: Tue, 3 Dec 2024 11:43:49 +0000
Subject: [PATCH] Convert `DROP NOT NULL` SQL to pgroll operation (#505)

Update the `sql2pgroll` package to convert SQL like:

```sql
ALTER TABLE foo ALTER COLUMN a DROP NOT NULL
```

to the equivalent `pgroll` operation:

```json
[
  {
    "alter_column": {
      "column": "a",
      "down": "TODO: Implement SQL data migration",
      "nullable": true,
      "table": "foo",
      "up": "TODO: Implement SQL data migration"
    }
  }
]
```
---
 pkg/sql2pgroll/alter_table.go        | 9 +++++----
 pkg/sql2pgroll/alter_table_test.go   | 4 ++++
 pkg/sql2pgroll/expect/alter_table.go | 8 ++++++++
 3 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/pkg/sql2pgroll/alter_table.go b/pkg/sql2pgroll/alter_table.go
index 3351cd5d..4a51c18e 100644
--- a/pkg/sql2pgroll/alter_table.go
+++ b/pkg/sql2pgroll/alter_table.go
@@ -22,21 +22,22 @@ func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, err
 			continue
 		}
 
-		//nolint:gocritic
 		switch alterTableCmd.Subtype {
 		case pgq.AlterTableType_AT_SetNotNull:
-			ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd))
+			ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd, true))
+		case pgq.AlterTableType_AT_DropNotNull:
+			ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd, false))
 		}
 	}
 
 	return ops, nil
 }
 
-func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) migrations.Operation {
+func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd, notNull bool) migrations.Operation {
 	return &migrations.OpAlterColumn{
 		Table:    stmt.GetRelation().GetRelname(),
 		Column:   cmd.GetName(),
-		Nullable: ptr(false),
+		Nullable: ptr(!notNull),
 		Up:       PlaceHolderSQL,
 		Down:     PlaceHolderSQL,
 	}
diff --git a/pkg/sql2pgroll/alter_table_test.go b/pkg/sql2pgroll/alter_table_test.go
index b074d9f5..3def8c8f 100644
--- a/pkg/sql2pgroll/alter_table_test.go
+++ b/pkg/sql2pgroll/alter_table_test.go
@@ -23,6 +23,10 @@ func TestConvertAlterTableStatements(t *testing.T) {
 			sql:        "ALTER TABLE foo ALTER COLUMN a SET NOT NULL",
 			expectedOp: expect.AlterTableOp1,
 		},
+		{
+			sql:        "ALTER TABLE foo ALTER COLUMN a DROP NOT NULL",
+			expectedOp: expect.AlterTableOp2,
+		},
 	}
 
 	for _, tc := range tests {
diff --git a/pkg/sql2pgroll/expect/alter_table.go b/pkg/sql2pgroll/expect/alter_table.go
index f1da4a85..a235c2db 100644
--- a/pkg/sql2pgroll/expect/alter_table.go
+++ b/pkg/sql2pgroll/expect/alter_table.go
@@ -15,6 +15,14 @@ var AlterTableOp1 = &migrations.OpAlterColumn{
 	Down:     sql2pgroll.PlaceHolderSQL,
 }
 
+var AlterTableOp2 = &migrations.OpAlterColumn{
+	Table:    "foo",
+	Column:   "a",
+	Nullable: ptr(true),
+	Up:       sql2pgroll.PlaceHolderSQL,
+	Down:     sql2pgroll.PlaceHolderSQL,
+}
+
 func ptr[T any](v T) *T {
 	return &v
 }