diff --git a/dbmap.go b/dbmap.go index 01b218d..8a761d3 100644 --- a/dbmap.go +++ b/dbmap.go @@ -4,29 +4,42 @@ import ( "reflect" ) +const tagName = "db" + func Columns(t reflect.Type) []string { - vs := make([]string, t.NumField()) + var vs []string for i := 0; i < t.NumField(); i++ { - vs[i] = t.Field(i).Tag.Get("db") + tag := t.Field(i).Tag.Get(tagName) + if tag == "" { + continue + } + vs = append(vs, tag) } return vs } func Values(s any) []any { v := reflect.ValueOf(s) - vs := make([]interface{}, v.NumField()) + t := reflect.TypeOf(s) + var vs []any for i := 0; i < v.NumField(); i++ { - vs[i] = v.Field(i).Interface() + if t.Field(i).Tag.Get(tagName) == "" { + continue + } + vs = append(vs, v.Field(i).Interface()) } return vs } func Scan(s any) []any { - vo := reflect.ValueOf(s).Elem() - vs := make([]interface{}, vo.NumField()) - for i := 0; i < vo.NumField(); i++ { - f := vo.Field(i) - vs[i] = f.Addr().Interface() + v := reflect.ValueOf(s).Elem() + t := reflect.TypeOf(s).Elem() + var vs []any + for i := 0; i < v.NumField(); i++ { + if t.Field(i).Tag.Get(tagName) == "" { + continue + } + vs = append(vs, v.Field(i).Addr().Interface()) } return vs } diff --git a/dbmap_test.go b/dbmap_test.go index 5ac5bd9..db7d3d8 100644 --- a/dbmap_test.go +++ b/dbmap_test.go @@ -7,9 +7,11 @@ import ( ) type Entity struct { - Id int `db:"id"` - Name string `db:"name"` - CreatedAt time.Time `db:"created_at"` + Id int `db:"id"` + NullableField *int `db:"nullable_field"` + Name string `db:"name"` + CreatedAt time.Time `db:"created_at"` + FieldNotIdDb int } func TestColumns(t *testing.T) { @@ -28,6 +30,7 @@ func TestColumns(t *testing.T) { }, want: []string{ "id", + "nullable_field", "name", "created_at", }, @@ -47,9 +50,10 @@ func TestScan(t *testing.T) { s any } e := Entity{ - Id: 1, - Name: "Test", - CreatedAt: time.Now(), + Id: 1, + NullableField: nil, + Name: "Test", + CreatedAt: time.Now(), } tests := []struct { name string @@ -63,6 +67,7 @@ func TestScan(t *testing.T) { }, want: []any{ &e.Id, + &e.NullableField, &e.Name, &e.CreatedAt, }, @@ -98,6 +103,7 @@ func TestValues(t *testing.T) { }, want: []any{ e.Id, + e.NullableField, e.Name, e.CreatedAt, },