Skip to content

Commit

Permalink
Support non-db fields
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander Matveev committed Sep 27, 2023
1 parent c0506c7 commit 3cc29ed
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 15 deletions.
31 changes: 22 additions & 9 deletions dbmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,42 @@ import (
"reflect"
)

const tagName = "db"

func Columns(t reflect.Type) []string {
vs := make([]string, t.NumField())
var vs []string
for i := 0; i < t.NumField(); i++ {
vs[i] = t.Field(i).Tag.Get("db")
tag := t.Field(i).Tag.Get(tagName)
if tag == "" {
continue
}
vs = append(vs, tag)
}
return vs
}

func Values(s any) []any {
v := reflect.ValueOf(s)
vs := make([]interface{}, v.NumField())
t := reflect.TypeOf(s)
var vs []any
for i := 0; i < v.NumField(); i++ {
vs[i] = v.Field(i).Interface()
if t.Field(i).Tag.Get(tagName) == "" {
continue
}
vs = append(vs, v.Field(i).Interface())
}
return vs
}

func Scan(s any) []any {
vo := reflect.ValueOf(s).Elem()
vs := make([]interface{}, vo.NumField())
for i := 0; i < vo.NumField(); i++ {
f := vo.Field(i)
vs[i] = f.Addr().Interface()
v := reflect.ValueOf(s).Elem()
t := reflect.TypeOf(s).Elem()
var vs []any
for i := 0; i < v.NumField(); i++ {
if t.Field(i).Tag.Get(tagName) == "" {
continue
}
vs = append(vs, v.Field(i).Addr().Interface())
}
return vs
}
18 changes: 12 additions & 6 deletions dbmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
)

type Entity struct {
Id int `db:"id"`
Name string `db:"name"`
CreatedAt time.Time `db:"created_at"`
Id int `db:"id"`
NullableField *int `db:"nullable_field"`
Name string `db:"name"`
CreatedAt time.Time `db:"created_at"`
FieldNotIdDb int
}

func TestColumns(t *testing.T) {
Expand All @@ -28,6 +30,7 @@ func TestColumns(t *testing.T) {
},
want: []string{
"id",
"nullable_field",
"name",
"created_at",
},
Expand All @@ -47,9 +50,10 @@ func TestScan(t *testing.T) {
s any
}
e := Entity{
Id: 1,
Name: "Test",
CreatedAt: time.Now(),
Id: 1,
NullableField: nil,
Name: "Test",
CreatedAt: time.Now(),
}
tests := []struct {
name string
Expand All @@ -63,6 +67,7 @@ func TestScan(t *testing.T) {
},
want: []any{
&e.Id,
&e.NullableField,
&e.Name,
&e.CreatedAt,
},
Expand Down Expand Up @@ -98,6 +103,7 @@ func TestValues(t *testing.T) {
},
want: []any{
e.Id,
e.NullableField,
e.Name,
e.CreatedAt,
},
Expand Down

0 comments on commit 3cc29ed

Please sign in to comment.