diff --git a/dialect/append.go b/dialect/append.go index 9fe4c00a6..5d10ef608 100644 --- a/dialect/append.go +++ b/dialect/append.go @@ -77,16 +77,16 @@ func AppendString(b []byte, s string) []byte { return b } -func AppendBytes(b []byte, bytes []byte) []byte { - if bytes == nil { +func AppendBytes(b []byte, bs []byte) []byte { + if bs == nil { return AppendNull(b) } b = append(b, `'\x`...) s := len(b) - b = append(b, make([]byte, hex.EncodedLen(len(bytes)))...) - hex.Encode(b[s:], bytes) + b = append(b, make([]byte, hex.EncodedLen(len(bs)))...) + hex.Encode(b[s:], bs) b = append(b, '\'') diff --git a/dialect/pgdialect/append.go b/dialect/pgdialect/append.go index 430522b09..90fac2a10 100644 --- a/dialect/pgdialect/append.go +++ b/dialect/pgdialect/append.go @@ -2,6 +2,7 @@ package pgdialect import ( "database/sql/driver" + "encoding/hex" "fmt" "reflect" "strconv" @@ -64,7 +65,7 @@ func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte { case bool: return dialect.AppendBool(b, v) case []byte: - return dialect.AppendBytes(b, v) + return arrayAppendBytes(b, v) case string: return arrayAppendString(b, v) case time.Time: @@ -76,12 +77,17 @@ func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte { } func arrayElemAppender(typ reflect.Type) schema.AppenderFunc { - if typ.Kind() == reflect.String { - return arrayAppendStringValue - } if typ.Implements(driverValuerType) { return arrayAppendDriverValue } + switch typ.Kind() { + case reflect.String: + return arrayAppendStringValue + case reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { + return arrayAppendBytesValue + } + } return schema.Appender(typ, customAppender) } @@ -89,6 +95,10 @@ func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) [ return arrayAppendString(b, v.String()) } +func arrayAppendBytesValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return arrayAppendBytes(b, v.Bytes()) +} + func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { iface, err := v.Interface().(driver.Valuer).Value() if err != nil { @@ -280,6 +290,22 @@ func appendFloat64Slice(b []byte, floats []float64) []byte { //------------------------------------------------------------------------------ +func arrayAppendBytes(b []byte, bs []byte) []byte { + if bs == nil { + return dialect.AppendNull(b) + } + + b = append(b, `"\\x`...) + + s := len(b) + b = append(b, make([]byte, hex.EncodedLen(len(bs)))...) + hex.Encode(b[s:], bs) + + b = append(b, '"') + + return b +} + func arrayAppendString(b []byte, s string) []byte { b = append(b, '"') for _, r := range s { diff --git a/dialect/pgdialect/array_parser.go b/dialect/pgdialect/array_parser.go index 1c927fca0..d3b6035ce 100644 --- a/dialect/pgdialect/array_parser.go +++ b/dialect/pgdialect/array_parser.go @@ -2,6 +2,7 @@ package pgdialect import ( "bytes" + "encoding/hex" "fmt" "io" ) @@ -114,6 +115,16 @@ func (p *arrayParser) readSubstring() ([]byte, error) { c = next } + if bytes.HasPrefix(p.buf, []byte("\\x")) && len(p.buf)%2 == 0 { + data := p.buf[2:] + buf := make([]byte, hex.DecodedLen(len(data))) + n, err := hex.Decode(buf, data) + if err != nil { + return nil, err + } + return buf[:n], nil + } + return p.buf, nil } diff --git a/internal/dbtest/pg_test.go b/internal/dbtest/pg_test.go index 523c5a031..c9d3f3fd3 100644 --- a/internal/dbtest/pg_test.go +++ b/internal/dbtest/pg_test.go @@ -2,6 +2,8 @@ package dbtest_test import ( "database/sql" + "database/sql/driver" + "fmt" "net" "reflect" "testing" @@ -16,13 +18,14 @@ import ( func TestPGArray(t *testing.T) { type Model struct { - ID int + ID int64 Array1 []string `bun:",array"` Array2 *[]string `bun:",array"` Array3 *[]string `bun:",array"` } db := pg(t) + defer db.Close() _, err := db.NewDropTable().Model((*Model)(nil)).IfExists().Exec(ctx) require.NoError(t, err) @@ -57,6 +60,49 @@ func TestPGArray(t *testing.T) { require.Nil(t, strs) } +type Hash [32]byte + +func (h *Hash) Scan(src interface{}) error { + srcB, ok := src.([]byte) + if !ok { + return fmt.Errorf("can't scan %T into Hash", src) + } + if len(srcB) != len(h) { + return fmt.Errorf("can't scan []byte of len %d into Hash, want %d", len(srcB), len(h)) + } + copy(h[:], srcB) + return nil +} + +func (h Hash) Value() (driver.Value, error) { + return h[:], nil +} + +func TestPGArrayValuer(t *testing.T) { + type Model struct { + ID int64 + Array []Hash `bun:",array"` + } + + db := pg(t) + defer db.Close() + + err := db.ResetModel(ctx, (*Model)(nil)) + require.NoError(t, err) + + model1 := &Model{ + ID: 123, + Array: []Hash{Hash{}}, + } + _, err = db.NewInsert().Model(model1).Exec(ctx) + require.NoError(t, err) + + model2 := new(Model) + err = db.NewSelect().Model(model2).Scan(ctx) + require.NoError(t, err) + require.Equal(t, model1, model2) +} + type Recipe struct { bun.BaseModel `bun:"?tenant.recipes"` diff --git a/migrate/migration.go b/migrate/migration.go index 87881d6a2..7d2d318eb 100644 --- a/migrate/migration.go +++ b/migrate/migration.go @@ -126,7 +126,11 @@ func init() { } ` -const sqlTemplate = `SELECT 1 +const sqlTemplate = `SET statement_timeout = 0; + +--bun:split + +SELECT 1 --bun:split