Skip to content

Commit

Permalink
#100: Change First/Value/Find to accept structs (#101)
Browse files Browse the repository at this point in the history
Fix column name misspelling if in upper-case (ID, URL etc)
  • Loading branch information
arthurkushman authored Oct 10, 2023
1 parent 856f050 commit 52d5874
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 70 deletions.
43 changes: 21 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func main() {
### InRandomOrder

```go
res, err = db.Table("users").Select("name", "post", "user_id").InRandomOrder().ScanStruct(&dataStruct)
err = db.Table("users").Select("name", "post", "user_id").InRandomOrder().ScanStruct(dataStruct)
```

## GroupBy / Having
Expand All @@ -94,7 +94,7 @@ The GroupBy and Having methods may be used to group the query results.
The having method's signature is similar to that of the where method:

```go
res, err := db.table("users").GroupBy("account_id").Having("account_id", ">", 100).ScanStruct(&dataStruct)
err = db.table("users").GroupBy("account_id").Having("account_id", ">", 100).ScanStruct(dataStruct)
```

## Where, AndWhere, OrWhere clauses
Expand All @@ -106,7 +106,7 @@ The second argument is an operator, which can be any of the database's supported
Finally, the third argument is the value to evaluate against the column.

```go
res, err := db.Table("table1").Select("foo", "bar", "baz").Where("foo", "=", cmp).AndWhere("bar", "!=", "foo").OrWhere("baz", "=", 123)..ScanStruct(&dataStruct)
err = db.Table("table1").Select("foo", "bar", "baz").Where("foo", "=", cmp).AndWhere("bar", "!=", "foo").OrWhere("baz", "=", 123)..ScanStruct(dataStruct)
```

You may chain where constraints together as well as add or clauses to the query.
Expand All @@ -117,15 +117,15 @@ The orWhere method accepts the same arguments as the where method.
The whereIn method verifies that a given column's value is contained within the given slice:

```go
res, err := db.Table("table1").WhereIn("id", []int64{1, 2, 3}).OrWhereIn("name", []string{"John", "Paul"})..ScanStruct(&dataStruct)
err = db.Table("table1").WhereIn("id", []int64{1, 2, 3}).OrWhereIn("name", []string{"John", "Paul"}).ScanStruct(dataStruct)
```

## WhereNull / WhereNotNull

The whereNull method verifies that the value of the given column is NULL:

```go
res, err := db.Table("posts").WhereNull("points").OrWhereNotNull("title")..ScanStruct(&dataStruct)
err = db.Table("posts").WhereNull("points").OrWhereNotNull("title")..ScanStruct(dataStruct)
```

## Left / Right / Cross / Inner / Left Outer Joins
Expand All @@ -137,7 +137,7 @@ while the remaining arguments specify the column constraints for the join.
You can even join to multiple tables in a single query:

```go
res, err := db.Table("users").Select("name", "post", "user_id").LeftJoin("posts", "users.id", "=", "posts.user_id").EachToStruct(func(rows *sql.Rows) error {
err = db.Table("users").Select("name", "post", "user_id").LeftJoin("posts", "users.id", "=", "posts.user_id").EachToStruct(func(rows *sql.Rows) error {
err = db.Next(rows, &dataStruct)
if err != nil {
return err
Expand All @@ -155,7 +155,7 @@ The Insert/InsertBatch methods accept a structure (or slice of structs) of colum

```go
// insert without getting id
err := db.Table("table1").Insert(DataStruct{
err = db.Table("table1").Insert(DataStruct{
Foo: "foo foo foo",
Bar: "bar bar bar",
Baz: &baz,
Expand All @@ -169,7 +169,7 @@ id, err := db.Table("table1").InsertGetId(DataStruct{
})

// batch insert
err := db.Table("table1").InsertBatch([]DataStruct{
err = db.Table("table1").InsertBatch([]DataStruct{
{Foo: "foo foo foo", Bar: "bar bar bar", Baz: &baz},
{Foo: "foo foo foo foo", Bar: "bar bar bar bar", Baz: &baz},
{Foo: "foo foo foo foo foo", Bar: "bar bar bar bar bar", Baz: &baz},
Expand Down Expand Up @@ -232,7 +232,7 @@ For example, you may create an initial query and use the union method to union i

```go
union := db.Table("posts").Select("title", "likes").Union()
res, err := union.Table("users").Select("name", "points").ScanStruct(&testStruct)
res, err := union.Table("users").Select("name", "points").ScanStruct(dataStruct)

// or if UNION ALL is of need
// union := db.Table("posts").Select("title", "likes").UnionAll()
Expand All @@ -246,7 +246,7 @@ or committed if everything is ok:

```go
err := db.InTransaction(func () (interface{}, error) {
return db.Table("users").Select("name", "post", "user_id").ScanStruct(&testStruct)
return db.Table("users").Select("name", "post", "user_id").ScanStruct(dataStruct)
})
```

Expand Down Expand Up @@ -282,39 +282,38 @@ If you just need to retrieve a single row from the database table, you may use t
This method will return a single `map[string]interface{}`:

```go
post, err := db.Table("posts").Select("title").OrderBy("created_at", "desc").First()
err = db.Table("posts").Select("title").OrderBy("created_at", "desc").First(dataStruct)

// usage ex: post["title"]
// usage ex: dataStruct.Title
```

If you don't even need an entire row, you may extract a single value from a record using the `Value` method.
This method will return the value of the column directly:

```go
res, err := db.Table("users").OrderBy("points", "desc").Value("name")
err = db.Table("users").OrderBy("points", "desc").Value(dataStruct, "name")

// res -> "Alex Shmidt"
// dataStruct.Name -> "Alex Shmidt"
```

To retrieve a single row by its id column value, use the `find` method:

```go
user, err := db.Table("users").Find(id)
user, err := db.Table("users").Find(dataStruct, id)

// user["id"], user["name"], user["email"] etc
// dataStruct.ID, dataStruct.Name, dataStruct.Email etc
```

## WhereExists / WhereNotExists

The whereExists method allows you to write where exists SQL clauses.
The whereExists method accepts a *DB argument,
which will receive a query builder instance allowing you to define the query that should be placed inside of the "
exists" clause:
which will receive a query builder instance allowing you to define the query that should be placed inside the "exists" clause:

```go
res, er := db.Table("users").Select("name").WhereExists(
err = db.Table("users").Select("name").WhereExists(
db.Table("users").Select("name").Where("points", ">=", int64(12345)),
).First()
).First(dataStruct)
```

Any query that is of need to build one can place inside `WhereExists` clause/func.
Expand All @@ -324,13 +323,13 @@ Any query that is of need to build one can place inside `WhereExists` clause/fun
The whereBetween func verifies that a column's value is between two values:

```go
res, err := db.Table(UsersTable).Select("name").WhereBetween("points", 1233, 12345).ScanStruct(&testStruct)
err = db.Table(UsersTable).Select("name").WhereBetween("points", 1233, 12345).ScanStruct(&testStruct)
```

The whereNotBetween func verifies that a column's value lies outside of two values:

```go
res, err := db.Table(UsersTable).Select("name").WhereNotBetween("points", 123, 123456).ScanStruct(&testStruct)
err = db.Table(UsersTable).Select("name").WhereNotBetween("points", 123, 123456).ScanStruct(&testStruct)
```

## Determining If Records Exist
Expand Down
28 changes: 10 additions & 18 deletions advanced.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,28 @@ import (
)

// First getting the 1st row of query
func (r *DB) First() (map[string]interface{}, error) {
res, err := r.Get()
func (r *DB) First(src any) error {
err := r.ScanStruct(src)
if err != nil {
return nil, err
return err
}

if len(res) > 0 {
return res[0], nil
}
return nil, fmt.Errorf("no records were produced by query: %s", r.Builder.buildSelect())
return nil
}

// Value gets the value of column in first query resulting row
func (r *DB) Value(column string) (val interface{}, err error) {
r.Select(column)
res, err := r.First()
func (r *DB) Value(src any, column string) error {
err := r.Select(column).ScanStruct(src)
if err != nil {
return
}

if val, ok := res[column]; ok {
return val, err
return err
}

return
return nil
}

// Find retrieves a single row by it's id column value
func (r *DB) Find(id uint64) (map[string]interface{}, error) {
return r.Where("id", "=", id).First()
func (r *DB) Find(src any, id uint64) error {
return r.Where("id", "=", id).First(src)
}

// Pluck getting values of a particular column and place them into slice
Expand Down
66 changes: 39 additions & 27 deletions builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,11 +546,12 @@ func TestDB_First(t *testing.T) {
})
require.NoError(t, err)

res, err := db.Table(TestTable).Select("baz").OrderBy("baz", "desc").OrderBy("foo", "desc").First()
dataStruct := &DataStruct{}
err = db.Table(TestTable).Select("baz").OrderBy("baz", "desc").OrderBy("foo", "desc").First(dataStruct)
require.NoError(t, err)
require.Equal(t, res["baz"], int64(1234))
require.Equal(t, *dataStruct.Baz, int64(1234))

_, err = db.Table(TestTable).Select("baz").OrderBy("baz", "desc").OrderBy("fo", "desc").First()
err = db.Table(TestTable).Select("baz").OrderBy("baz", "desc").OrderBy("fo", "desc").First(dataStruct)
require.Error(t, err)

_, err = db.Truncate(TestTable)
Expand All @@ -564,9 +565,12 @@ func TestDB_Find(t *testing.T) {
id, err := db.Table(TestTable).InsertGetId(data)
require.NoError(t, err)

res, err := db.Table(TestTable).Find(id)
dataStruct := &DataStructID{}
err = db.Table(TestTable).Find(dataStruct, id)
require.NoError(t, err)
require.Equal(t, res["id"], int64(id))
require.Equal(t, dataStruct.Foo, data.Foo)
require.Equal(t, dataStruct.Bar, data.Bar)
require.Equal(t, dataStruct.Baz, *data.Baz)

_, err = db.Truncate(TestTable)
require.NoError(t, err)
Expand All @@ -579,11 +583,12 @@ func TestDB_WhereExists(t *testing.T) {
err = db.Table(UsersTable).InsertBatch(batchUsers)
require.NoError(t, err)

res, er := db.Table(UsersTable).Select("name").WhereExists(
dataStruct := &DataStructUser{}
er := db.Table(UsersTable).Select("name").WhereExists(
db.Table(UsersTable).Select("name").Where("points", ">=", int64(12345)),
).First()
).First(dataStruct)
require.NoError(t, er)
require.Equal(t, TestUserName, res["name"])
require.Equal(t, TestUserName, dataStruct.Name)

_, err = db.Truncate(UsersTable)
require.NoError(t, err)
Expand All @@ -596,11 +601,12 @@ func TestDB_WhereNotExists(t *testing.T) {
err = db.Table(UsersTable).InsertBatch(batchUsers)
require.NoError(t, err)

res, er := db.Table(UsersTable).Select("name").WhereNotExists(
dataStruct := &DataStructUser{}
er := db.Table(UsersTable).Select("name").WhereNotExists(
db.Table(UsersTable).Select("name").Where("points", ">=", int64(12345)),
).First()
).First(dataStruct)
require.NoError(t, er)
require.Equal(t, TestUserName, res["name"])
require.Equal(t, TestUserName, dataStruct.Name)

_, err = db.Truncate(UsersTable)
require.NoError(t, err)
Expand All @@ -612,11 +618,13 @@ func TestDB_Value(t *testing.T) {

err = db.Table(UsersTable).InsertBatch(batchUsers)
require.NoError(t, err)
res, err := db.Table(UsersTable).OrderBy("points", "desc").Value("name")

dataStruct := &DataStructUser{}
err = db.Table(UsersTable).OrderBy("points", "desc").Value(dataStruct, "name")
require.NoError(t, err)
require.Equal(t, TestUserName, res)
require.Equal(t, TestUserName, dataStruct.Name)

_, err = db.Table(UsersTable).OrderBy("poin", "desc").Value("name")
err = db.Table(UsersTable).OrderBy("poin", "desc").Value(dataStruct, "name")
require.Error(t, err)

_, err = db.Truncate(UsersTable)
Expand Down Expand Up @@ -869,10 +877,11 @@ func TestDB_OrderByRaw(t *testing.T) {
err = db.Table(PostsTable).InsertBatch(batchPosts)
require.NoError(t, err)

res, err := db.Table(PostsTable).Select("title").OrderByRaw("updated_at - created_at DESC").First()
dataStruct := &DataStructPost{}
err = db.Table(PostsTable).Select("title").OrderByRaw("updated_at - created_at DESC").First(dataStruct)
require.NoError(t, err)

require.Equal(t, batchPosts[0].Title, res["title"])
require.Equal(t, batchPosts[0].Title, dataStruct.Title)

_, err = db.Truncate(PostsTable)
require.NoError(t, err)
Expand All @@ -885,14 +894,15 @@ func TestDB_SelectRaw(t *testing.T) {
err = db.Table(UsersTable).InsertBatch(batchUsers)
require.NoError(t, err)

res, err := db.Table(UsersTable).SelectRaw("SUM(points) as pts").First()
dataStruct := &DataStructUser{}
err = db.Table(UsersTable).SelectRaw("SUM(points) as points").First(dataStruct)
require.NoError(t, err)

var sum int64
for _, v := range batchUsers {
sum += v.Points
}
require.Equal(t, sum, res["pts"])
require.Equal(t, sum, dataStruct.Points)

_, err = db.Truncate(UsersTable)
require.NoError(t, err)
Expand All @@ -905,17 +915,17 @@ func TestDB_AndWhereBetween(t *testing.T) {
err = db.Table(UsersTable).InsertBatch(batchUsers)
require.NoError(t, err)

res, err := db.Table(UsersTable).Select("name").WhereBetween("points", 1233, 12345).
OrWhereBetween("points", 123456, 67891023).AndWhereNotBetween("points", 12, 23).First()
dataStruct := &DataStructUser{}
err = db.Table(UsersTable).Select("name").WhereBetween("points", 1233, 12345).
OrWhereBetween("points", 123456, 67891023).AndWhereNotBetween("points", 12, 23).
First(dataStruct)
require.NoError(t, err)
require.Equal(t, "Darth Vader", dataStruct.Name)

require.Equal(t, "Darth Vader", res["name"])

res, err = db.Table(UsersTable).Select("name").WhereNotBetween("points", 12, 123).
AndWhereBetween("points", 1233, 12345).OrWhereNotBetween("points", 12, 23).First()
err = db.Table(UsersTable).Select("name").WhereNotBetween("points", 12, 123).
AndWhereBetween("points", 1233, 12345).OrWhereNotBetween("points", 12, 23).First(dataStruct)
require.NoError(t, err)

require.Equal(t, "Alex Shmidt", res["name"])
require.Equal(t, "Alex Shmidt", dataStruct.Name)

_, err = db.Truncate(UsersTable)
require.NoError(t, err)
Expand Down Expand Up @@ -1257,7 +1267,9 @@ func TestDB_FirsNoRecordsErr(t *testing.T) {
_, err := db.Truncate(UsersTable)
require.NoError(t, err)

_, err = db.Table(TestTable).Select("baz").OrderBy("baz", "desc").OrderBy("foo", "desc").First()
dataStruct := &DataStructUser{}
err = db.Table(TestTable).Select("baz").OrderBy("baz", "desc").
OrderBy("foo", "desc").First(dataStruct)
require.Errorf(t, err, "no records were produced by query: %s")

_, err = db.Truncate(UsersTable)
Expand Down
15 changes: 12 additions & 3 deletions factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,8 @@ func (r *DB) Next(rows *sql.Rows, src any) error {
}

func setResourceValue(resource reflect.Value, src any, col string, value any) {
if !resource.FieldByName(col).IsValid() { // try to get field by db: tag
upperCaseFieldName := cases.Upper(language.English).String(col)
if !resource.FieldByName(col).IsValid() && !resource.FieldByName(upperCaseFieldName).IsValid() { // try to get field by db: tag
fields := structs.Fields(src)
for i, f := range fields {
tag := f.Tag("db")
Expand All @@ -263,7 +264,12 @@ func setResourceValue(resource reflect.Value, src any, col string, value any) {
}
}

setValue(resource.FieldByName(col), value)
colName := col
if resource.FieldByName(upperCaseFieldName).IsValid() {
colName = upperCaseFieldName
}

setValue(resource.FieldByName(colName), value)
}

func setValue(field reflect.Value, val any) {
Expand Down Expand Up @@ -298,8 +304,11 @@ func setValue(field reflect.Value, val any) {
func validateFields(resource reflect.Value, src any, columns []string) error {
for _, col := range columns {
foundColByTag := false
// standard fields parse
fieldName := cases.Title(language.English).String(col)
if !resource.FieldByName(fieldName).IsValid() {
// uppercase letters fields parse e.g.: ID, URL etc
upperCaseFieldName := cases.Upper(language.English).String(col)
if !resource.FieldByName(fieldName).IsValid() && !resource.FieldByName(upperCaseFieldName).IsValid() {
fields := structs.Fields(src)
for _, f := range fields {
tag := f.Tag("db")
Expand Down

0 comments on commit 52d5874

Please sign in to comment.