Skip to content

Commit

Permalink
refactor: move complex conditional to a method
Browse files Browse the repository at this point in the history
The `f.SQLDefault == ""` part seems redundant before addReturningField,
as RETURNING makes sense for a field that was given its default value.
  • Loading branch information
bevzzz committed Jan 7, 2024
1 parent 75e1760 commit 3696bd4
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
6 changes: 3 additions & 3 deletions internal/dbtest/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,8 @@ func TestQuery(t *testing.T) {
Where("model.id = _data.id")
},
func(db *bun.DB) schema.QueryAppender {
// "nullzero" marshalls zero values as DEFAULT or NULL (if DEFAULT placeholder is not supported)
// DB drivers that support DEFAULT placeholder resolve it to NULL for columns that do not have a DEFAULT value.
// "nullzero" marshals zero values as DEFAULT or NULL (if DEFAULT placeholder is not supported)
// DB drivers which support DEFAULT placeholder resolve it to NULL for columns that do not have a DEFAULT value.
type Model struct {
Int int64 `bun:",nullzero"`
Uint uint64 `bun:",nullzero"`
Expand Down Expand Up @@ -1026,7 +1026,7 @@ func TestQuery(t *testing.T) {
}).Column("is_active").WherePK()
},
func(db *bun.DB) schema.QueryAppender {
// "default" marshalls zero values as DEFAULT or specified default value
// "default" marshals zero values as DEFAULT or the specified default value
type Model struct {
Int int64 `bun:",default:42"`
Uint uint64 `bun:",default:42"`
Expand Down
17 changes: 11 additions & 6 deletions query_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func (q *InsertQuery) appendStructValues(
switch {
case isTemplate:
b = append(b, '?')
case f.IsPtr && f.HasNilValue(strct), f.HasZeroValue(strct) && (f.NullZero || f.SQLDefault != ""):
case q.marshalsToDefault(f, strct):
if q.db.HasFeature(feature.DefaultPlaceholder) {
b = append(b, "DEFAULT"...)
} else if f.SQLDefault != "" {
Expand Down Expand Up @@ -410,18 +410,23 @@ func (q *InsertQuery) getFields() ([]*schema.Field, error) {
q.addReturningField(f)
continue
}
if f.NotNull && f.SQLDefault == "" {
if (f.IsPtr && f.HasNilValue(strct)) || (f.NullZero && f.HasZeroValue(strct)) {
q.addReturningField(f)
continue
}
if f.NotNull && q.marshalsToDefault(f, strct) {
q.addReturningField(f)
continue
}
fields = append(fields, f)
}

return fields, nil
}

// marshalsToDefault checks if the value will be marshaled as DEFAULT or NULL (if DEFAULT placeholder is not supported)
// when appending it to the VALUES clause in place of the given field.
func (q InsertQuery) marshalsToDefault(f *schema.Field, v reflect.Value) bool {
return (f.IsPtr && f.HasNilValue(v)) ||
(f.HasZeroValue(v) && (f.NullZero || f.SQLDefault != ""))
}

func (q *InsertQuery) appendFields(
fmter schema.Formatter, b []byte, fields []*schema.Field,
) []byte {
Expand Down

0 comments on commit 3696bd4

Please sign in to comment.