Skip to content

Commit

Permalink
CreateTableQuery: Small cleanups for foreign key handling (#941)
Browse files Browse the repository at this point in the history
* chore: set correct relation types

HasOneRelation and BelongsToRelation aren't checked anywhere at the moment,
but the change is necessary to avoid any hidden bugs in the future.

* refactor: check for FOREIGN KEY clause in a dedicated method

+ Document the logic for omitting the clause for some Relations

The name References() is derived from the REFERENCES keyword
used to declare an FK constraint

* fix: make WithForeignKeys() idempotent

Subsequent calls to WithForeignKeys() should not create additional
FOREIGN KEY clauses in the query.

DBs probably have no problem dealing with duplicated FKs, but
this behaviour is hardly expected / relied on by anyone.

* test: capture bug

#941 (comment)

* fix: append relation FKs directly to query bytes
  • Loading branch information
bevzzz authored Jan 10, 2024
1 parent ff944ac commit 4785291
Show file tree
Hide file tree
Showing 11 changed files with 70 additions and 19 deletions.
8 changes: 8 additions & 0 deletions internal/dbtest/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,14 @@ func TestQuery(t *testing.T) {
Join("JOIN user ON user.id = story.user_id").
Where("user.id = ?", 1)
},
func(db *bun.DB) schema.QueryAppender {
q := db.NewCreateTable().Model(new(Story)).WithForeignKeys()

// Check that building the query with .AppendQuery() multiple times does not add redundant FK constraints:
// https://github.com/uptrace/bun/pull/941#discussion_r1443647857
_ = q.String()
return q
},
}

timeRE := regexp.MustCompile(`'2\d{3}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)?(\+\d{2}:\d{2})?'`)
Expand Down
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mariadb-164
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE `stories` (`id` BIGINT NOT NULL AUTO_INCREMENT, `name` VARCHAR(255), `user_id` BIGINT, PRIMARY KEY (`id`), FOREIGN KEY (`user_id`) REFERENCES `users` (`id`) ON UPDATE NO ACTION ON DELETE NO ACTION)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mssql2019-164
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE "stories" ("id" BIGINT NOT NULL IDENTITY, "name" VARCHAR(255), "user_id" BIGINT, PRIMARY KEY ("id"), FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql5-164
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE `stories` (`id` BIGINT NOT NULL AUTO_INCREMENT, `name` VARCHAR(255), `user_id` BIGINT, PRIMARY KEY (`id`), FOREIGN KEY (`user_id`) REFERENCES `users` (`id`) ON UPDATE NO ACTION ON DELETE NO ACTION)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql8-164
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE `stories` (`id` BIGINT NOT NULL AUTO_INCREMENT, `name` VARCHAR(255), `user_id` BIGINT, PRIMARY KEY (`id`), FOREIGN KEY (`user_id`) REFERENCES `users` (`id`) ON UPDATE NO ACTION ON DELETE NO ACTION)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pg-164
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE "stories" ("id" BIGSERIAL NOT NULL, "name" VARCHAR, "user_id" BIGINT, PRIMARY KEY ("id"), FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pgx-164
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE "stories" ("id" BIGSERIAL NOT NULL, "name" VARCHAR, "user_id" BIGINT, PRIMARY KEY ("id"), FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-sqlite-164
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE "stories" ("id" INTEGER NOT NULL, "name" VARCHAR, "user_id" INTEGER, PRIMARY KEY ("id"), FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION)
64 changes: 47 additions & 17 deletions query_table_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type CreateTableQuery struct {

temp bool
ifNotExists bool
fksFromRel bool // Create foreign keys captured in table's relations.

// varchar changes the default length for VARCHAR columns.
// Because some dialects require that length is always specified for VARCHAR type,
Expand Down Expand Up @@ -120,21 +121,9 @@ func (q *CreateTableQuery) TableSpace(tablespace string) *CreateTableQuery {
return q
}

// WithForeignKeys adds a FOREIGN KEY clause for each of the model's existing relations.
func (q *CreateTableQuery) WithForeignKeys() *CreateTableQuery {
for _, relation := range q.tableModel.Table().Relations {
if relation.Type == schema.ManyToManyRelation ||
relation.Type == schema.HasManyRelation {
continue
}

q = q.ForeignKey("(?) REFERENCES ? (?) ? ?",
Safe(appendColumns(nil, "", relation.BaseFields)),
relation.JoinTable.SQLName,
Safe(appendColumns(nil, "", relation.JoinFields)),
Safe(relation.OnUpdate),
Safe(relation.OnDelete),
)
}
q.fksFromRel = true
return q
}

Expand Down Expand Up @@ -212,6 +201,13 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by

b = q.appendPKConstraint(b, q.table.PKs)
b = q.appendUniqueConstraints(fmter, b)

if q.fksFromRel {
b, err = q.appendFKConstraintsRel(fmter, b)
if err != nil {
return nil, err
}
}
b, err = q.appendFKConstraints(fmter, b)
if err != nil {
return nil, err
Expand Down Expand Up @@ -295,13 +291,38 @@ func (q *CreateTableQuery) appendUniqueConstraint(
return b
}

// appendFKConstraintsRel appends a FOREIGN KEY clause for each of the model's existing relations.
func (q *CreateTableQuery) appendFKConstraintsRel(fmter schema.Formatter, b []byte) (_ []byte, err error) {
for _, rel := range q.tableModel.Table().Relations {
if rel.References() {
b, err = q.appendFK(fmter, b, schema.QueryWithArgs{
Query: "(?) REFERENCES ? (?) ? ?",
Args: []interface{}{
Safe(appendColumns(nil, "", rel.BaseFields)),
rel.JoinTable.SQLName,
Safe(appendColumns(nil, "", rel.JoinFields)),
Safe(rel.OnUpdate),
Safe(rel.OnDelete),
},
})
if err != nil {
return nil, err
}
}
}
return b, nil
}

func (q *CreateTableQuery) appendFK(fmter schema.Formatter, b []byte, fk schema.QueryWithArgs) (_ []byte, err error) {
b = append(b, ", FOREIGN KEY "...)
return fk.AppendQuery(fmter, b)
}

func (q *CreateTableQuery) appendFKConstraints(
fmter schema.Formatter, b []byte,
) (_ []byte, err error) {
for _, fk := range q.fks {
b = append(b, ", FOREIGN KEY "...)
b, err = fk.AppendQuery(fmter, b)
if err != nil {
if b, err = q.appendFK(fmter, b, fk); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -364,3 +385,12 @@ func (q *CreateTableQuery) afterCreateTableHook(ctx context.Context) error {
}
return nil
}

func (q *CreateTableQuery) String() string {
buf, err := q.AppendQuery(q.db.Formatter(), nil)
if err != nil {
panic(err)
}

return string(buf)
}
6 changes: 6 additions & 0 deletions schema/relation.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ type Relation struct {
M2MJoinFields []*Field
}

// References returns true if the table to which the Relation belongs needs to declare a foreign key constraint to create the relation.
// For other relations, the constraint is created in either the referencing table (1:N, 'has-many' relations) or a mapping table (N:N, 'm2m' relations).
func (r *Relation) References() bool {
return r.Type == HasOneRelation || r.Type == BelongsToRelation
}

func (r *Relation) String() string {
return fmt.Sprintf("relation=%s", r.Field.GoName)
}
4 changes: 2 additions & 2 deletions schema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ func (t *Table) belongsToRelation(field *Field) *Relation {
}

rel := &Relation{
Type: HasOneRelation,
Type: BelongsToRelation,
Field: field,
JoinTable: joinTable,
}
Expand Down Expand Up @@ -571,7 +571,7 @@ func (t *Table) hasOneRelation(field *Field) *Relation {

joinTable := t.dialect.Tables().Ref(field.IndirectType)
rel := &Relation{
Type: BelongsToRelation,
Type: HasOneRelation,
Field: field,
JoinTable: joinTable,
}
Expand Down

0 comments on commit 4785291

Please sign in to comment.