Skip to content

Commit

Permalink
Merge pull request #20 from manicar2093/fix/check-type-implement-sql-…
Browse files Browse the repository at this point in the history
…scann

fix: check type implement sql scan
  • Loading branch information
manicar2093 authored Nov 6, 2024
2 parents b984888 + ea29564 commit 7b752ed
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 6 deletions.
33 changes: 27 additions & 6 deletions sql.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package goption

import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
Expand All @@ -27,22 +28,42 @@ import (
func (c *Optional[T]) Scan(src any) error {
srcValue, isSrcValid := isValidData(src)
c.isValidValue = isSrcValid
if isSrcValid {
destType := reflect.TypeOf(c.value)
if !srcValue.Type().ConvertibleTo(destType) {
return fmt.Errorf("interface conversion: interface {} is %s, not %s", srcValue.Type(), destType)
}
if !isSrcValid {
return nil
}

destType := reflect.TypeOf(c.value)
if srcValue.Type().ConvertibleTo(destType) {
c.value = srcValue.Convert(destType).Interface().(T)
return nil
}
return nil

destTypeP := reflect.TypeOf(&c.value)
scannerType := reflect.TypeOf((*sql.Scanner)(nil))
if destTypeP.Implements(scannerType.Elem()) {
var s T
if asScanner, ok := interface{}(&s).(sql.Scanner); ok {
if err := asScanner.Scan(src); err != nil {
return err
}
c.value = s
return nil
}
}

return fmt.Errorf("interface conversion: interface {} is %s, not %s nor implements sql.Scanner", srcValue.Type(), destType)
}

// Value returns a driver Value.
// Value must not panic.
func (c Optional[T]) Value() (driver.Value, error) {
if !c.isValidValue {
return nil, nil
}

if asValuer, ok := interface{}(&c.value).(driver.Valuer); ok {
return asValuer.Value()
}
return c.value, nil

}
35 changes: 35 additions & 0 deletions sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package goption_test

import (
"database/sql"
"database/sql/driver"
"fmt"
"log"
"os"
Expand All @@ -20,6 +21,19 @@ type Activity struct {
Description goption.Optional[string] `json:"description"`
}

type WithScan struct {
data int
}

func (c *WithScan) Scan(src any) error {
c.data = src.(int)
return nil
}

func (c WithScan) Value() (driver.Value, error) {
return c.data, nil
}

var _ = Describe("Sql", func() {

It("works on db", Label("Integration"), func() {
Expand Down Expand Up @@ -144,6 +158,16 @@ var _ = Describe("Sql", func() {
Expect(opt.IsPresent()).To(BeTrue())
})
})

When("data implements its own scan method", func() {
It("calls it to do transform", func() {
opt := goption.Empty[WithScan]()

Expect(opt.Scan(400)).To(Succeed())
Expect(opt.IsPresent()).To(BeTrue())
Expect(opt.MustGet().data).To(Equal(400))
})
})
})

Describe("Value", func() {
Expand All @@ -168,6 +192,17 @@ var _ = Describe("Sql", func() {
Expect(got).ToNot(BeNil())
})
})

When("data implements its own value method", func() {
It("calls it to do transform", func() {
var opt = goption.Of(WithScan{data: 300})

got, err := opt.Value()

Expect(err).ToNot(HaveOccurred())
Expect(got.(int)).To(Equal(opt.MustGet().data))
})
})
})

})

0 comments on commit 7b752ed

Please sign in to comment.