diff --git a/internal/dbtest/query_test.go b/internal/dbtest/query_test.go index 81f9abe99..3abe2df59 100644 --- a/internal/dbtest/query_test.go +++ b/internal/dbtest/query_test.go @@ -1008,6 +1008,14 @@ func TestQuery(t *testing.T) { Join("JOIN user ON user.id = story.user_id"). Where("user.id = ?", 1) }, + func(db *bun.DB) schema.QueryAppender { + q := db.NewCreateTable().Model(new(Story)).WithForeignKeys() + + // Check that building the query with .AppendQuery() multiple times does not add redundant FK constraints: + // https://github.com/uptrace/bun/pull/941#discussion_r1443647857 + _ = q.String() + return q + }, } timeRE := regexp.MustCompile(`'2\d{3}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)?(\+\d{2}:\d{2})?'`) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mariadb-164 b/internal/dbtest/testdata/snapshots/TestQuery-mariadb-164 new file mode 100644 index 000000000..8df726dcb --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mariadb-164 @@ -0,0 +1 @@ +CREATE TABLE `stories` (`id` BIGINT NOT NULL AUTO_INCREMENT, `name` VARCHAR(255), `user_id` BIGINT, PRIMARY KEY (`id`), FOREIGN KEY (`user_id`) REFERENCES `users` (`id`) ON UPDATE NO ACTION ON DELETE NO ACTION) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mssql2019-164 b/internal/dbtest/testdata/snapshots/TestQuery-mssql2019-164 new file mode 100644 index 000000000..967afa9ca --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mssql2019-164 @@ -0,0 +1 @@ +CREATE TABLE "stories" ("id" BIGINT NOT NULL IDENTITY, "name" VARCHAR(255), "user_id" BIGINT, PRIMARY KEY ("id"), FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql5-164 b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-164 new file mode 100644 index 000000000..8df726dcb --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-164 @@ -0,0 +1 @@ +CREATE TABLE `stories` (`id` BIGINT NOT NULL AUTO_INCREMENT, `name` VARCHAR(255), `user_id` BIGINT, PRIMARY KEY (`id`), FOREIGN KEY (`user_id`) REFERENCES `users` (`id`) ON UPDATE NO ACTION ON DELETE NO ACTION) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql8-164 b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-164 new file mode 100644 index 000000000..8df726dcb --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-164 @@ -0,0 +1 @@ +CREATE TABLE `stories` (`id` BIGINT NOT NULL AUTO_INCREMENT, `name` VARCHAR(255), `user_id` BIGINT, PRIMARY KEY (`id`), FOREIGN KEY (`user_id`) REFERENCES `users` (`id`) ON UPDATE NO ACTION ON DELETE NO ACTION) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pg-164 b/internal/dbtest/testdata/snapshots/TestQuery-pg-164 new file mode 100644 index 000000000..ccb26f415 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pg-164 @@ -0,0 +1 @@ +CREATE TABLE "stories" ("id" BIGSERIAL NOT NULL, "name" VARCHAR, "user_id" BIGINT, PRIMARY KEY ("id"), FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pgx-164 b/internal/dbtest/testdata/snapshots/TestQuery-pgx-164 new file mode 100644 index 000000000..ccb26f415 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pgx-164 @@ -0,0 +1 @@ +CREATE TABLE "stories" ("id" BIGSERIAL NOT NULL, "name" VARCHAR, "user_id" BIGINT, PRIMARY KEY ("id"), FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-sqlite-164 b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-164 new file mode 100644 index 000000000..5e4baec58 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-164 @@ -0,0 +1 @@ +CREATE TABLE "stories" ("id" INTEGER NOT NULL, "name" VARCHAR, "user_id" INTEGER, PRIMARY KEY ("id"), FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION) diff --git a/query_table_create.go b/query_table_create.go index 518dbfd1c..8b9cfd967 100644 --- a/query_table_create.go +++ b/query_table_create.go @@ -19,6 +19,7 @@ type CreateTableQuery struct { temp bool ifNotExists bool + fksFromRel bool // Create foreign keys captured in table's relations. // varchar changes the default length for VARCHAR columns. // Because some dialects require that length is always specified for VARCHAR type, @@ -120,21 +121,9 @@ func (q *CreateTableQuery) TableSpace(tablespace string) *CreateTableQuery { return q } +// WithForeignKeys adds a FOREIGN KEY clause for each of the model's existing relations. func (q *CreateTableQuery) WithForeignKeys() *CreateTableQuery { - for _, relation := range q.tableModel.Table().Relations { - if relation.Type == schema.ManyToManyRelation || - relation.Type == schema.HasManyRelation { - continue - } - - q = q.ForeignKey("(?) REFERENCES ? (?) ? ?", - Safe(appendColumns(nil, "", relation.BaseFields)), - relation.JoinTable.SQLName, - Safe(appendColumns(nil, "", relation.JoinFields)), - Safe(relation.OnUpdate), - Safe(relation.OnDelete), - ) - } + q.fksFromRel = true return q } @@ -212,6 +201,13 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by b = q.appendPKConstraint(b, q.table.PKs) b = q.appendUniqueConstraints(fmter, b) + + if q.fksFromRel { + b, err = q.appendFKConstraintsRel(fmter, b) + if err != nil { + return nil, err + } + } b, err = q.appendFKConstraints(fmter, b) if err != nil { return nil, err @@ -295,13 +291,38 @@ func (q *CreateTableQuery) appendUniqueConstraint( return b } +// appendFKConstraintsRel appends a FOREIGN KEY clause for each of the model's existing relations. +func (q *CreateTableQuery) appendFKConstraintsRel(fmter schema.Formatter, b []byte) (_ []byte, err error) { + for _, rel := range q.tableModel.Table().Relations { + if rel.References() { + b, err = q.appendFK(fmter, b, schema.QueryWithArgs{ + Query: "(?) REFERENCES ? (?) ? ?", + Args: []interface{}{ + Safe(appendColumns(nil, "", rel.BaseFields)), + rel.JoinTable.SQLName, + Safe(appendColumns(nil, "", rel.JoinFields)), + Safe(rel.OnUpdate), + Safe(rel.OnDelete), + }, + }) + if err != nil { + return nil, err + } + } + } + return b, nil +} + +func (q *CreateTableQuery) appendFK(fmter schema.Formatter, b []byte, fk schema.QueryWithArgs) (_ []byte, err error) { + b = append(b, ", FOREIGN KEY "...) + return fk.AppendQuery(fmter, b) +} + func (q *CreateTableQuery) appendFKConstraints( fmter schema.Formatter, b []byte, ) (_ []byte, err error) { for _, fk := range q.fks { - b = append(b, ", FOREIGN KEY "...) - b, err = fk.AppendQuery(fmter, b) - if err != nil { + if b, err = q.appendFK(fmter, b, fk); err != nil { return nil, err } } @@ -364,3 +385,12 @@ func (q *CreateTableQuery) afterCreateTableHook(ctx context.Context) error { } return nil } + +func (q *CreateTableQuery) String() string { + buf, err := q.AppendQuery(q.db.Formatter(), nil) + if err != nil { + panic(err) + } + + return string(buf) +} diff --git a/schema/relation.go b/schema/relation.go index 6636e26a6..9eb74f7e9 100644 --- a/schema/relation.go +++ b/schema/relation.go @@ -30,6 +30,12 @@ type Relation struct { M2MJoinFields []*Field } +// References returns true if the table to which the Relation belongs needs to declare a foreign key constraint to create the relation. +// For other relations, the constraint is created in either the referencing table (1:N, 'has-many' relations) or a mapping table (N:N, 'm2m' relations). +func (r *Relation) References() bool { + return r.Type == HasOneRelation || r.Type == BelongsToRelation +} + func (r *Relation) String() string { return fmt.Sprintf("relation=%s", r.Field.GoName) } diff --git a/schema/table.go b/schema/table.go index 9eb7d1bfe..c36198950 100644 --- a/schema/table.go +++ b/schema/table.go @@ -477,7 +477,7 @@ func (t *Table) belongsToRelation(field *Field) *Relation { } rel := &Relation{ - Type: HasOneRelation, + Type: BelongsToRelation, Field: field, JoinTable: joinTable, } @@ -571,7 +571,7 @@ func (t *Table) hasOneRelation(field *Field) *Relation { joinTable := t.dialect.Tables().Ref(field.IndirectType) rel := &Relation{ - Type: BelongsToRelation, + Type: HasOneRelation, Field: field, JoinTable: joinTable, }