Skip to content

Commit

Permalink
rebase & support sql.Scanner iface
Browse files Browse the repository at this point in the history
  • Loading branch information
4el0ve4ek committed Feb 1, 2025
1 parent 48ccba7 commit 98b74f9
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 10 deletions.
16 changes: 16 additions & 0 deletions internal/value/cast.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package value

import (
"database/sql"
"database/sql/driver"
)

func CastTo(v Value, dst interface{}) error {
if dst == nil {
return errNilDestination
Expand All @@ -10,6 +15,17 @@ func CastTo(v Value, dst interface{}) error {
return nil
}

if scanner, has := dst.(sql.Scanner); has {
dv := new(driver.Value)

err := v.castTo(dv)
if err != nil {
return err
}

return scanner.Scan(*dv)
}

if scanner, has := dst.(Scanner); has {
return scanner.UnmarshalYDBValue(v)
}
Expand Down
38 changes: 33 additions & 5 deletions internal/value/cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package value

import (
"database/sql/driver"
"errors"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -32,12 +33,32 @@ func loadLocation(t *testing.T, name string) *time.Location {
return loc
}

type testStringValueScanner struct {
field string
}
type testStringValueScanner string

func (s *testStringValueScanner) UnmarshalYDBValue(v Value) error {
return CastTo(v, &s.field)
var tmp string

err := CastTo(v, &tmp)
if err != nil {
return err
}

*s = testStringValueScanner(tmp)

return nil
}

type testStringSQLScanner string

func (s *testStringSQLScanner) Scan(value any) error {
ts, ok := value.(string)
if !ok {
return errors.New("can't cast from " + reflect.TypeOf(value).String() + " to string")
}

*s = testStringSQLScanner(ts)

return nil
}

func TestCastTo(t *testing.T) {
Expand Down Expand Up @@ -440,7 +461,14 @@ func TestCastTo(t *testing.T) {
name: xtest.CurrentFileLine(),
value: TextValue("text-string"),
dst: ptr[testStringValueScanner](),
exp: testStringValueScanner{field: "text-string"},
exp: testStringValueScanner("text-string"),
err: nil,
},
{
name: xtest.CurrentFileLine(),
value: TextValue("text-string"),
dst: ptr[testStringSQLScanner](),
exp: testStringSQLScanner("text-string"),
err: nil,
},
}
Expand Down
6 changes: 1 addition & 5 deletions internal/value/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -2331,12 +2331,8 @@ type uuidValue struct {

func (v *uuidValue) castTo(dst any) error {
switch vv := dst.(type) {
case *uuid.UUID:
*vv = v.value

return nil
case *driver.Value:
*vv = v.value
*vv = v.value[:]

return nil
case *string:
Expand Down

0 comments on commit 98b74f9

Please sign in to comment.