From 3696bd453bd958b4ebbc63aa6da34674f3304fbf Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 29 Nov 2023 19:22:05 +0100 Subject: [PATCH] refactor: move complex conditional to a method The `f.SQLDefault == ""` part seems redundant before addReturningField, as RETURNING makes sense for a field that was given its default value. --- internal/dbtest/query_test.go | 6 +++--- query_insert.go | 17 +++++++++++------ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/internal/dbtest/query_test.go b/internal/dbtest/query_test.go index 7902089f7..8bcc06110 100644 --- a/internal/dbtest/query_test.go +++ b/internal/dbtest/query_test.go @@ -331,8 +331,8 @@ func TestQuery(t *testing.T) { Where("model.id = _data.id") }, func(db *bun.DB) schema.QueryAppender { - // "nullzero" marshalls zero values as DEFAULT or NULL (if DEFAULT placeholder is not supported) - // DB drivers that support DEFAULT placeholder resolve it to NULL for columns that do not have a DEFAULT value. + // "nullzero" marshals zero values as DEFAULT or NULL (if DEFAULT placeholder is not supported) + // DB drivers which support DEFAULT placeholder resolve it to NULL for columns that do not have a DEFAULT value. type Model struct { Int int64 `bun:",nullzero"` Uint uint64 `bun:",nullzero"` @@ -1026,7 +1026,7 @@ func TestQuery(t *testing.T) { }).Column("is_active").WherePK() }, func(db *bun.DB) schema.QueryAppender { - // "default" marshalls zero values as DEFAULT or specified default value + // "default" marshals zero values as DEFAULT or the specified default value type Model struct { Int int64 `bun:",default:42"` Uint uint64 `bun:",default:42"` diff --git a/query_insert.go b/query_insert.go index f64bff21b..6d38a4efe 100644 --- a/query_insert.go +++ b/query_insert.go @@ -332,7 +332,7 @@ func (q *InsertQuery) appendStructValues( switch { case isTemplate: b = append(b, '?') - case f.IsPtr && f.HasNilValue(strct), f.HasZeroValue(strct) && (f.NullZero || f.SQLDefault != ""): + case q.marshalsToDefault(f, strct): if q.db.HasFeature(feature.DefaultPlaceholder) { b = append(b, "DEFAULT"...) } else if f.SQLDefault != "" { @@ -410,11 +410,9 @@ func (q *InsertQuery) getFields() ([]*schema.Field, error) { q.addReturningField(f) continue } - if f.NotNull && f.SQLDefault == "" { - if (f.IsPtr && f.HasNilValue(strct)) || (f.NullZero && f.HasZeroValue(strct)) { - q.addReturningField(f) - continue - } + if f.NotNull && q.marshalsToDefault(f, strct) { + q.addReturningField(f) + continue } fields = append(fields, f) } @@ -422,6 +420,13 @@ func (q *InsertQuery) getFields() ([]*schema.Field, error) { return fields, nil } +// marshalsToDefault checks if the value will be marshaled as DEFAULT or NULL (if DEFAULT placeholder is not supported) +// when appending it to the VALUES clause in place of the given field. +func (q InsertQuery) marshalsToDefault(f *schema.Field, v reflect.Value) bool { + return (f.IsPtr && f.HasNilValue(v)) || + (f.HasZeroValue(v) && (f.NullZero || f.SQLDefault != "")) +} + func (q *InsertQuery) appendFields( fmter schema.Formatter, b []byte, fields []*schema.Field, ) []byte {