Skip to content

Commit

Permalink
rebase & support sql.Scanner iface
Browse files Browse the repository at this point in the history
  • Loading branch information
4el0ve4ek committed Feb 2, 2025
1 parent 48ccba7 commit cc63007
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 11 deletions.
22 changes: 22 additions & 0 deletions internal/value/cast.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package value

import (
"database/sql"
"database/sql/driver"

"github.com/google/uuid"
)

func CastTo(v Value, dst interface{}) error {
if dst == nil {
return errNilDestination
Expand All @@ -10,6 +17,21 @@ func CastTo(v Value, dst interface{}) error {
return nil
}

if _, ok := dst.(*uuid.UUID); ok {
return v.castTo(dst)
}

if scanner, has := dst.(sql.Scanner); has {
dv := new(driver.Value)

err := v.castTo(dv)
if err != nil {
return err
}

return scanner.Scan(*dv)
}

if scanner, has := dst.(Scanner); has {
return scanner.UnmarshalYDBValue(v)
}
Expand Down
38 changes: 33 additions & 5 deletions internal/value/cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package value

import (
"database/sql/driver"
"errors"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -32,12 +33,32 @@ func loadLocation(t *testing.T, name string) *time.Location {
return loc
}

type testStringValueScanner struct {
field string
}
type testStringValueScanner string

func (s *testStringValueScanner) UnmarshalYDBValue(v Value) error {
return CastTo(v, &s.field)
var tmp string

err := CastTo(v, &tmp)
if err != nil {
return err
}

*s = testStringValueScanner(tmp)

return nil
}

type testStringSQLScanner string

func (s *testStringSQLScanner) Scan(value any) error {
ts, ok := value.(string)
if !ok {
return errors.New("can't cast from " + reflect.TypeOf(value).String() + " to string")
}

*s = testStringSQLScanner(ts)

return nil
}

func TestCastTo(t *testing.T) {
Expand Down Expand Up @@ -440,7 +461,14 @@ func TestCastTo(t *testing.T) {
name: xtest.CurrentFileLine(),
value: TextValue("text-string"),
dst: ptr[testStringValueScanner](),
exp: testStringValueScanner{field: "text-string"},
exp: testStringValueScanner("text-string"),
err: nil,
},
{
name: xtest.CurrentFileLine(),
value: TextValue("text-string"),
dst: ptr[testStringSQLScanner](),
exp: testStringSQLScanner("text-string"),
err: nil,
},
}
Expand Down
12 changes: 6 additions & 6 deletions internal/value/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -1303,7 +1303,7 @@ func (v *listValue) castTo(dst any) error {
inner.Set(newSlice)

for i, item := range v.ListItems() {
if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil {
if err := CastTo(item, inner.Index(i).Addr().Interface()); err != nil {
return xerrors.WithStackTrace(fmt.Errorf(
"%w '%s(%+v)' to '%T' destination",
ErrCannotCast, v.Type().Yql(), v, dstValue,
Expand Down Expand Up @@ -1437,7 +1437,7 @@ func (v *setValue) castTo(dst any) error {
inner.Set(newSlice)

for i, item := range v.items {
if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil {
if err := CastTo(item, inner.Index(i).Addr().Interface()); err != nil {
return xerrors.WithStackTrace(fmt.Errorf(
"%w '%s(%+v)' to '%T' destination",
ErrCannotCast, v.Type().Yql(), v, dstValue,
Expand Down Expand Up @@ -1545,7 +1545,7 @@ func (v *optionalValue) castTo(dst any) error {
return nil
}

if err := v.value.castTo(ptr.Interface()); err != nil {
if err := CastTo(v.value, (ptr.Interface())); err != nil {
return xerrors.WithStackTrace(err)
}

Expand All @@ -1560,7 +1560,7 @@ func (v *optionalValue) castTo(dst any) error {

inner.Set(reflect.New(inner.Type().Elem()))

if err := v.value.castTo(inner.Interface()); err != nil {
if err := CastTo(v.value, inner.Interface()); err != nil {
return xerrors.WithStackTrace(err)
}

Expand Down Expand Up @@ -1641,7 +1641,7 @@ func (v *structValue) castTo(dst any) error {
}

for i, field := range v.fields {
if err := field.V.castTo(inner.Field(i).Addr().Interface()); err != nil {
if err := CastTo(field.V, inner.Field(i).Addr().Interface()); err != nil {
return xerrors.WithStackTrace(fmt.Errorf(
"scan error on struct field name '%s': %w",
field.Name, err,
Expand Down Expand Up @@ -1768,7 +1768,7 @@ func (v *tupleValue) TupleItems() []Value {

func (v *tupleValue) castTo(dst any) error {
if len(v.items) == 1 {
return v.items[0].castTo(dst)
return CastTo(v.items[0], dst)
}

switch dstValue := dst.(type) {
Expand Down

0 comments on commit cc63007

Please sign in to comment.