diff --git a/CHANGELOG.md b/CHANGELOG.md index a9526ac97..edcb74d96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Added support of custom types to row.ScanStruct using sql.Scanner interface + ## v3.99.5 * Fixed error `Empty query text` using prepared statements and `ydb.WithExecuteDataQueryOverQueryClient(true)` option * Prepared statements always send query text on Execute call from now (previous behaviour - send query ID) diff --git a/internal/value/cast.go b/internal/value/cast.go index 7106b62fd..6d382a151 100644 --- a/internal/value/cast.go +++ b/internal/value/cast.go @@ -1,5 +1,12 @@ package value +import ( + "database/sql" + "database/sql/driver" + + "github.com/google/uuid" +) + func CastTo(v Value, dst interface{}) error { if dst == nil { return errNilDestination @@ -10,5 +17,20 @@ func CastTo(v Value, dst interface{}) error { return nil } + if _, ok := dst.(*uuid.UUID); ok { + return v.castTo(dst) + } + + if scanner, has := dst.(sql.Scanner); has { + dv := new(driver.Value) + + err := v.castTo(dv) + if err != nil { + return err + } + + return scanner.Scan(*dv) + } + return v.castTo(dst) } diff --git a/internal/value/cast_test.go b/internal/value/cast_test.go index 863587f59..00b519f3c 100644 --- a/internal/value/cast_test.go +++ b/internal/value/cast_test.go @@ -2,6 +2,7 @@ package value import ( "database/sql/driver" + "errors" "reflect" "testing" "time" @@ -32,6 +33,19 @@ func loadLocation(t *testing.T, name string) *time.Location { return loc } +type testStringSQLScanner string + +func (s *testStringSQLScanner) Scan(value any) error { + ts, ok := value.(string) + if !ok { + return errors.New("can't cast from " + reflect.TypeOf(value).String() + " to string") + } + + *s = testStringSQLScanner(ts) + + return nil +} + func TestCastTo(t *testing.T) { testsCases := []struct { name string @@ -428,6 +442,13 @@ func TestCastTo(t *testing.T) { exp: DateValueFromTime(time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)), err: nil, }, + { + name: xtest.CurrentFileLine(), + value: TextValue("text-string"), + dst: ptr[testStringSQLScanner](), + exp: testStringSQLScanner("text-string"), + err: nil, + }, } for _, tt := range testsCases { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/value/value.go b/internal/value/value.go index 0b8331874..866fa37ad 100644 --- a/internal/value/value.go +++ b/internal/value/value.go @@ -1303,7 +1303,7 @@ func (v *listValue) castTo(dst any) error { inner.Set(newSlice) for i, item := range v.ListItems() { - if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil { + if err := CastTo(item, inner.Index(i).Addr().Interface()); err != nil { return xerrors.WithStackTrace(fmt.Errorf( "%w '%s(%+v)' to '%T' destination", ErrCannotCast, v.Type().Yql(), v, dstValue, @@ -1437,7 +1437,7 @@ func (v *setValue) castTo(dst any) error { inner.Set(newSlice) for i, item := range v.items { - if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil { + if err := CastTo(item, inner.Index(i).Addr().Interface()); err != nil { return xerrors.WithStackTrace(fmt.Errorf( "%w '%s(%+v)' to '%T' destination", ErrCannotCast, v.Type().Yql(), v, dstValue, @@ -1545,7 +1545,7 @@ func (v *optionalValue) castTo(dst any) error { return nil } - if err := v.value.castTo(ptr.Interface()); err != nil { + if err := CastTo(v.value, (ptr.Interface())); err != nil { return xerrors.WithStackTrace(err) } @@ -1560,7 +1560,7 @@ func (v *optionalValue) castTo(dst any) error { inner.Set(reflect.New(inner.Type().Elem())) - if err := v.value.castTo(inner.Interface()); err != nil { + if err := CastTo(v.value, inner.Interface()); err != nil { return xerrors.WithStackTrace(err) } @@ -1641,7 +1641,7 @@ func (v *structValue) castTo(dst any) error { } for i, field := range v.fields { - if err := field.V.castTo(inner.Field(i).Addr().Interface()); err != nil { + if err := CastTo(field.V, inner.Field(i).Addr().Interface()); err != nil { return xerrors.WithStackTrace(fmt.Errorf( "scan error on struct field name '%s': %w", field.Name, err, @@ -1768,7 +1768,7 @@ func (v *tupleValue) TupleItems() []Value { func (v *tupleValue) castTo(dst any) error { if len(v.items) == 1 { - return v.items[0].castTo(dst) + return CastTo(v.items[0], dst) } switch dstValue := dst.(type) { diff --git a/tests/integration/query_range_test.go b/tests/integration/query_range_test.go index 03505e30e..3dfc904ca 100644 --- a/tests/integration/query_range_test.go +++ b/tests/integration/query_range_test.go @@ -19,6 +19,19 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/query" ) +type testStringScanner struct { + field string +} + +func (v *testStringScanner) Scan(value any) error { + vs, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type of value %v", value) + } + v.field = vs + return nil +} + func TestQueryRange(t *testing.T) { ctx, cancel := context.WithCancel(xtest.Context(t)) defer cancel() @@ -84,19 +97,22 @@ func TestQueryRange(t *testing.T) { p1 string p2 uint64 p3 time.Duration + p4 testStringScanner ) err := db.Query().Do(ctx, func(ctx context.Context, s query.Session) error { r, err := s.Query(ctx, ` DECLARE $p1 AS Text; DECLARE $p2 AS Uint64; DECLARE $p3 AS Interval; - SELECT $p1, $p2, $p3; + DECLARE $p4 AS Text; + SELECT $p1, $p2, $p3, $p4; `, query.WithParameters( ydb.ParamsBuilder(). Param("$p1").Text("test"). Param("$p2").Uint64(100500000000). Param("$p3").Interval(time.Duration(100500000000)). + Param("$p4").Text("test2"). Build(), ), query.WithSyntax(query.SyntaxYQL), @@ -112,7 +128,7 @@ func TestQueryRange(t *testing.T) { if err != nil { return err } - err = row.Scan(&p1, &p2, &p3) + err = row.Scan(&p1, &p2, &p3, &p4) if err != nil { return err } @@ -126,6 +142,9 @@ func TestQueryRange(t *testing.T) { if p3 != time.Duration(100500000000) { return fmt.Errorf("unexpected p3 value: %v", p3) } + if p4.field != "test2" { + return fmt.Errorf("unexpected p4 value: %v", p4) + } } } return nil