diff --git a/internal/dbtest/query_test.go b/internal/dbtest/query_test.go index 016763a78..2143bb6de 100644 --- a/internal/dbtest/query_test.go +++ b/internal/dbtest/query_test.go @@ -331,8 +331,8 @@ func TestQuery(t *testing.T) { Where("model.id = _data.id") }, func(db *bun.DB) schema.QueryAppender { - // "nullzero" marshalls zero values as DEFAULT or NULL (if DEFAULT placeholder is not supported) - // DB drivers that support DEFAULT placeholder resolve it to NULL for columns that do not have a DEFAULT value. + // "nullzero" marshals zero values as DEFAULT or NULL (if DEFAULT placeholder is not supported) + // DB drivers which support DEFAULT placeholder resolve it to NULL for columns that do not have a DEFAULT value. type Model struct { Int int64 `bun:",nullzero"` Uint uint64 `bun:",nullzero"` @@ -1034,7 +1034,7 @@ func TestQuery(t *testing.T) { }).Column("is_active").WherePK() }, func(db *bun.DB) schema.QueryAppender { - // "default" marshalls zero values as DEFAULT or specified default value + // "default" marshals zero values as DEFAULT or the specified default value type Model struct { Int int64 `bun:",default:42"` Uint uint64 `bun:",default:42"` diff --git a/query_insert.go b/query_insert.go index f64bff21b..6d38a4efe 100644 --- a/query_insert.go +++ b/query_insert.go @@ -332,7 +332,7 @@ func (q *InsertQuery) appendStructValues( switch { case isTemplate: b = append(b, '?') - case f.IsPtr && f.HasNilValue(strct), f.HasZeroValue(strct) && (f.NullZero || f.SQLDefault != ""): + case q.marshalsToDefault(f, strct): if q.db.HasFeature(feature.DefaultPlaceholder) { b = append(b, "DEFAULT"...) } else if f.SQLDefault != "" { @@ -410,11 +410,9 @@ func (q *InsertQuery) getFields() ([]*schema.Field, error) { q.addReturningField(f) continue } - if f.NotNull && f.SQLDefault == "" { - if (f.IsPtr && f.HasNilValue(strct)) || (f.NullZero && f.HasZeroValue(strct)) { - q.addReturningField(f) - continue - } + if f.NotNull && q.marshalsToDefault(f, strct) { + q.addReturningField(f) + continue } fields = append(fields, f) } @@ -422,6 +420,13 @@ func (q *InsertQuery) getFields() ([]*schema.Field, error) { return fields, nil } +// marshalsToDefault checks if the value will be marshaled as DEFAULT or NULL (if DEFAULT placeholder is not supported) +// when appending it to the VALUES clause in place of the given field. +func (q InsertQuery) marshalsToDefault(f *schema.Field, v reflect.Value) bool { + return (f.IsPtr && f.HasNilValue(v)) || + (f.HasZeroValue(v) && (f.NullZero || f.SQLDefault != "")) +} + func (q *InsertQuery) appendFields( fmter schema.Formatter, b []byte, fields []*schema.Field, ) []byte {