From 98659a4d7c815912f41bfe5e43df9b38e8a695f5 Mon Sep 17 00:00:00 2001 From: Arthur Kushman Date: Mon, 16 Oct 2023 13:11:27 +0300 Subject: [PATCH] * #102: Change Chunk method to accept struct and returning slice of those structs Add common abstract method with EachToStruct processing data Add more tests to cover errs Add common abstract method with EachToStruct processing data Add more tests to cover errs - Remove outdated Get method --- README.md | 38 +++- advanced.go | 83 ++++--- aggregates.go | 39 ++-- builder.go | 40 ++-- builder_test.go | 576 ++++++++++++++++++++++++++++++++++++++---------- factory.go | 98 +++----- helpers.go | 2 +- 7 files changed, 619 insertions(+), 257 deletions(-) diff --git a/README.md b/README.md index 7a3c189..c49115a 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ library [![Tweet](http://jpillora.com/github-twitter-button/img/tweet.png)](http * [Create table](#user-content-create-table) * [Add / Modify / Drop columns](#user-content-add--modify--drop-columns) * [Chunking Results](#user-content-chunking-results) +* [Pluck / PluckMap](#user-content-pluck--pluckmap) ## Installation @@ -433,17 +434,42 @@ If you need to work with thousands of database records, consider using the chunk This method retrieves a small chunk of the results at a time and feeds each chunk into a closure for processing. ```go -err = db.Table("user_achievements").Select("points").Where("id", "=", id).Chunk(100, func (users []map[string]interface{}) bool { - for _, m := range users { - if val, ok := m["points"];ok { - pointsCalc += diffFormula(val.(int64)) - } - // or you can return false here to stop running chunks +var sumOfPoints int64 +dataStruct := &DataStructUser{} +err = db.Table(UsersTable).Select("name", "points").Chunk(dataStruct, 100, func(users []any) bool { + for _, v := range users { + user := v.(DataStructUser) + // your code goes here e.g.: + sumOfPoints += user.Points } + + // or you can return false here to stop running chunks return true }) ``` +## Pluck / PluckMap + +If you would like to get values of a particular column(s) of a struct and place them into slice - use `Pluck` method: +```go + dataStruct := &DataStructUser{} + res, err := db.Table(UsersTable).Pluck(dataStruct) + for k, v := range res { + val := v.(DataStructUser) + fmt.Println(val.Name) // f.e.: Alex Shmidt + } + + // or use a PluckMap method to aggregate key/value pairs to a map + res, err := db.Table(UsersTable).PluckMap(dataStruct, "name", "points") + for k, m := range res { + for key, value := range m { + keyVal := key.(string) + valueVal := value.(DataStructUser) + // rest of the code ... + } + } +``` + PS Why use buildsqlx? Because it is simple and fast, yet versatile. The builder code-style has been inherited from greatest web-frameworks, so u can easily query anything from db. Supporters gratitude: diff --git a/advanced.go b/advanced.go index f41846c..77685fa 100644 --- a/advanced.go +++ b/advanced.go @@ -1,8 +1,10 @@ package buildsqlx import ( + "database/sql" "fmt" "math" + "reflect" "strconv" ) @@ -31,44 +33,50 @@ func (r *DB) Find(src any, id uint64) error { return r.Where("id", "=", id).First(src) } -// Pluck getting values of a particular column and place them into slice -func (r *DB) Pluck(column string) (val []interface{}, err error) { - res, err := r.Get() +// Pluck getting values of a particular column(s) of a struct and place them into slice +func (r *DB) Pluck(src any) ([]any, error) { + res, err := r.eachToStructRows(src, r.Builder.offset, r.Builder.limit) if err != nil { return nil, err } - val = make([]interface{}, len(res)) - for k, m := range res { - val[k] = m[column] - } - return + return res, nil } // PluckMap getting values of a particular key/value columns and place them into map -func (r *DB) PluckMap(colKey, colValue string) (val []map[interface{}]interface{}, err error) { - res, err := r.Get() +// values of the returning map is a structure passed as src and filled with data from DB +func (r *DB) PluckMap(src any, colKey, colValue string) (val []map[any]any, err error) { + resource := reflect.ValueOf(src).Elem() + if err = validateFields(resource, src, []string{colKey, colValue}); err != nil { + return nil, err + } + + res, err := r.eachToStructRows(src, r.Builder.offset, r.Builder.limit) if err != nil { return nil, err } - val = make([]map[interface{}]interface{}, len(res)) + val = make([]map[any]any, len(res)) for k, m := range res { - val[k] = make(map[interface{}]interface{}) - val[k][m[colKey]] = m[colValue] + val[k] = make(map[any]any) + + fieldKeyData := getFieldValue(m, colKey) + val[k][fieldKeyData] = reflect.ValueOf(m).Interface() } + return } // Exists checks whether conditional rows are existing (returns true) or not (returns false) func (r *DB) Exists() (exists bool, err error) { - builder := r.Builder - if builder.table == "" { + bldr := r.Builder + if bldr.table == "" { return false, errTableCallBeforeOp } - query := `SELECT EXISTS(SELECT 1 FROM "` + builder.table + `" ` + builder.buildClauses() + `)` + query := `SELECT EXISTS(SELECT 1 FROM "` + bldr.table + `" ` + bldr.buildClauses() + `)` err = r.Sql().QueryRow(query, prepareValues(r.Builder.whereBindings)...).Scan(&exists) + return } @@ -78,6 +86,7 @@ func (r *DB) DoesntExists() (bool, error) { if err != nil { return false, err } + return !ex, nil } @@ -93,8 +102,8 @@ func (r *DB) Decrement(column string, on uint64) (int64, error) { // increments or decrements depending on sign func (r *DB) incrDecr(column, sign string, on uint64) (int64, error) { - builder := r.Builder - if builder.table == "" { + bldr := r.Builder + if bldr.table == "" { return 0, errTableCallBeforeOp } @@ -110,7 +119,7 @@ func (r *DB) incrDecr(column, sign string, on uint64) (int64, error) { // Chunk run queries by chinks by passing user-land function with an ability to stop execution when needed // by returning false and proceed to execute queries when return true -func (r *DB) Chunk(amount int64, fn func(rows []map[string]interface{}) bool) error { +func (r *DB) Chunk(src any, amount int64, fn func(rows []any) bool) error { cols := r.Builder.columns cnt, err := r.Count() if err != nil { @@ -123,26 +132,50 @@ func (r *DB) Chunk(amount int64, fn func(rows []map[string]interface{}) bool) er } if cnt < amount { - res, err := r.Get() + structRows, err := r.eachToStructRows(src, 0, 0) if err != nil { return err } - fn(res) // execute all resulting records + + fn(structRows) // execute all resulting records + return nil } // executing chunks amount < cnt c := int64(math.Ceil(float64(cnt / amount))) - var i int64 - for i = 0; i < c; i++ { - rows, err := r.Offset(i * amount).Limit(amount).Get() // by 100 rows from 100 x n + for i := int64(0); i < c; i++ { + structRows, err := r.eachToStructRows(src, i*amount, amount) if err != nil { return err } - res := fn(rows) + + res := fn(structRows) if !res { // stop an execution when false returned by user break } } + return nil } + +func (r *DB) eachToStructRows(src any, offset, limit int64) ([]any, error) { + var structRows []any + if limit > 0 { + r.Offset(offset).Limit(limit) + } + + err := r.EachToStruct(func(rows *sql.Rows) error { + err := r.Next(rows, src) + if err != nil { + return err + } + + v := reflect.ValueOf(src).Elem().Interface() + structRows = append(structRows, v) + + return nil + }) + + return structRows, err +} diff --git a/aggregates.go b/aggregates.go index 418667f..a49e9af 100644 --- a/aggregates.go +++ b/aggregates.go @@ -2,45 +2,50 @@ package buildsqlx // Count counts resulting rows based on clause func (r *DB) Count() (cnt int64, err error) { - builder := r.Builder - builder.columns = []string{"COUNT(*)"} - query := builder.buildSelect() + bldr := r.Builder + bldr.columns = []string{"COUNT(*)"} + query := bldr.buildSelect() err = r.Sql().QueryRow(query, prepareValues(r.Builder.whereBindings)...).Scan(&cnt) + return } // Avg calculates average for specified column func (r *DB) Avg(column string) (avg float64, err error) { - builder := r.Builder - builder.columns = []string{"AVG(" + column + ")"} - query := builder.buildSelect() + bldr := r.Builder + bldr.columns = []string{"AVG(" + column + ")"} + query := bldr.buildSelect() err = r.Sql().QueryRow(query, prepareValues(r.Builder.whereBindings)...).Scan(&avg) + return } // Min calculates minimum for specified column func (r *DB) Min(column string) (min float64, err error) { - builder := r.Builder - builder.columns = []string{"MIN(" + column + ")"} - query := builder.buildSelect() + bldr := r.Builder + bldr.columns = []string{"MIN(" + column + ")"} + query := bldr.buildSelect() err = r.Sql().QueryRow(query, prepareValues(r.Builder.whereBindings)...).Scan(&min) + return } // Max calculates maximum for specified column func (r *DB) Max(column string) (max float64, err error) { - builder := r.Builder - builder.columns = []string{"MAX(" + column + ")"} - query := builder.buildSelect() + bldr := r.Builder + bldr.columns = []string{"MAX(" + column + ")"} + query := bldr.buildSelect() err = r.Sql().QueryRow(query, prepareValues(r.Builder.whereBindings)...).Scan(&max) + return } // Sum calculates sum for specified column -func (r *DB) Sum(column string) (max float64, err error) { - builder := r.Builder - builder.columns = []string{"SUM(" + column + ")"} - query := builder.buildSelect() - err = r.Sql().QueryRow(query, prepareValues(r.Builder.whereBindings)...).Scan(&max) +func (r *DB) Sum(column string) (sum float64, err error) { + bldr := r.Builder + bldr.columns = []string{"SUM(" + column + ")"} + query := bldr.buildSelect() + err = r.Sql().QueryRow(query, prepareValues(r.Builder.whereBindings)...).Scan(&sum) + return } diff --git a/builder.go b/builder.go index 9c793e8..ed66252 100644 --- a/builder.go +++ b/builder.go @@ -371,9 +371,10 @@ func (r *DB) Rename(from, to string) (sql.Result, error) { // WhereIn appends IN (val1, val2, val3...) stmt to WHERE clause func (r *DB) WhereIn(field string, in any) *DB { ins, err := interfaceToSlice(in) - if err != nil { - return nil + if err != nil { // don't want the code run on prod falling just because user didn't pass slice as `in` param + log.Panicln(err) } + r.buildWhere("", field, "IN", ins) return r } @@ -382,8 +383,9 @@ func (r *DB) WhereIn(field string, in any) *DB { func (r *DB) WhereNotIn(field string, in any) *DB { ins, err := interfaceToSlice(in) if err != nil { - return nil + log.Panicln(err) } + r.buildWhere("", field, "NOT IN", ins) return r } @@ -392,8 +394,9 @@ func (r *DB) WhereNotIn(field string, in any) *DB { func (r *DB) OrWhereIn(field string, in any) *DB { ins, err := interfaceToSlice(in) if err != nil { - return nil + log.Panicln(err) } + r.buildWhere("OR", field, "IN", ins) return r } @@ -402,8 +405,9 @@ func (r *DB) OrWhereIn(field string, in any) *DB { func (r *DB) OrWhereNotIn(field string, in any) *DB { ins, err := interfaceToSlice(in) if err != nil { - return nil + log.Panicln(err) } + r.buildWhere("OR", field, "NOT IN", ins) return r } @@ -412,8 +416,9 @@ func (r *DB) OrWhereNotIn(field string, in any) *DB { func (r *DB) AndWhereIn(field string, in any) *DB { ins, err := interfaceToSlice(in) if err != nil { - return nil + log.Panicln(err) } + r.buildWhere("AND", field, "IN", ins) // r.buildWhere("AND", field, "IN", prepareSlice(ins)) return r @@ -423,8 +428,9 @@ func (r *DB) AndWhereIn(field string, in any) *DB { func (r *DB) AndWhereNotIn(field string, in any) *DB { ins, err := interfaceToSlice(in) if err != nil { - return nil + log.Panicln(err) } + r.buildWhere("AND", field, "NOT IN", ins) return r } @@ -459,26 +465,6 @@ func (r *DB) AndWhereNotNull(field string) *DB { return r.buildWhere(sqlOperatorAnd, field, sqlOperatorIs, sqlSpecificValueNotNull) } -// prepares slice for Where bindings, IN/NOT IN etc -func prepareSlice(in []any) (out []string) { - for _, value := range in { - switch v := value.(type) { - case string: - out = append(out, v) - case int: - out = append(out, strconv.FormatInt(int64(v), 10)) - case float64: - out = append(out, fmt.Sprintf("%g", v)) - case int64: - out = append(out, strconv.FormatInt(v, 10)) - case uint64: - out = append(out, strconv.FormatUint(v, 10)) - } - } - - return -} - // From prepares sql stmt to set data from another table, ex.: // UPDATE employees SET sales_count = sales_count + 1 FROM accounts func (r *DB) From(fromTbl string) *DB { diff --git a/builder_test.go b/builder_test.go index 771a681..1164d43 100644 --- a/builder_test.go +++ b/builder_test.go @@ -101,6 +101,7 @@ func TestSelectAndLimit(t *testing.T) { require.NoError(t, err) require.Equal(t, data, toScan) require.Len(t, testStructs, 2) + for _, strct := range testStructs { require.Equal(t, data.Foo, strct.Foo) require.Equal(t, data.Bar, strct.Bar) @@ -111,21 +112,51 @@ func TestSelectAndLimit(t *testing.T) { require.NoError(t, err) } -func TestInsert(t *testing.T) { +func TestDB_EachToStruct_RetErr(t *testing.T) { + err := db.Table(TestTable).Select("foo", "bar").EachToStruct(func(rows *sql.Rows) error { + return errors.New("some err") + }) + require.EqualError(t, err, "some err") +} + +func TestDB_ChunkErrs(t *testing.T) { _, err := db.Truncate(TestTable) require.NoError(t, err) err = db.Table(TestTable).Insert(data) require.NoError(t, err) - res, err := db.Table(TestTable).Select("foo", "bar", "baz").Get() + err = db.Table("").Chunk(&DataStruct{}, 1, func(rows []any) bool { + return false + }) + require.EqualError(t, err, "pq: zero-length delimited identifier at or near \"\"\"\"") + + var nilPtr *int = nil + err = db.Table(TestTable).Chunk(nilPtr, 1, func(rows []any) bool { + return false + }) + require.EqualError(t, err, "cannot decode into nil type *int") + + err = db.Table(TestTable).Chunk(nilPtr, 2, func(rows []any) bool { + return false + }) + require.EqualError(t, err, "cannot decode into nil type *int") + + _, err = db.Truncate(TestTable) require.NoError(t, err) +} - for k, mapVal := range dataMap { - for _, v := range res { - require.Equal(t, v[k], mapVal) - } - } +func TestInsert(t *testing.T) { + _, err := db.Truncate(TestTable) + require.NoError(t, err) + + err = db.Table(TestTable).Insert(data) + require.NoError(t, err) + + dataStruct := &DataStruct{} + err = db.Table(TestTable).Select("foo", "bar", "baz").ScanStruct(dataStruct) + require.NoError(t, err) + require.Equal(t, data, data) _, err = db.Truncate(TestTable) require.NoError(t, err) @@ -174,10 +205,12 @@ func TestWhereOnly(t *testing.T) { err = db.Table(TestTable).InsertBatch(batchDataStruct) require.NoError(t, err) - res, err := db.Table(TestTable).Select("foo", "bar", "baz").Where("foo", "=", cmp).Get() + + dataStruct := &DataStruct{} + err = db.Table(TestTable).Select("foo", "bar", "baz").Where("foo", "=", cmp).ScanStruct(dataStruct) require.NoError(t, err) - require.Equal(t, res[0]["foo"], cmp) + require.Equal(t, dataStruct.Foo, cmp) _, err = db.Truncate(TestTable) require.NoError(t, err) @@ -191,10 +224,13 @@ func TestWhereAndOr(t *testing.T) { err = db.Table(TestTable).InsertBatch(batchDataStruct) require.NoError(t, err) - res, err := db.Table(TestTable).Select("foo", "bar", "baz").Where("foo", "=", cmp).AndWhere("bar", "!=", "foo").OrWhere("baz", "=", 123).Get() - require.NoError(t, err) - require.Equal(t, res[0]["foo"], cmp) + dataStruct := &DataStruct{} + err = db.Table(TestTable).Select("foo", "bar", "baz").Where("foo", "=", cmp). + AndWhere("bar", "!=", "foo"). + OrWhere("baz", "=", 123).ScanStruct(dataStruct) + require.NoError(t, err) + require.Equal(t, dataStruct.Foo, cmp) _, err = db.Truncate(TestTable) require.NoError(t, err) @@ -232,6 +268,14 @@ type DataStructUser struct { Points int64 } +type DataStructUserPosts struct { + ID int64 + Name *string + Post *string + Points *int64 + UserID *int64 `db:"user_id"` +} + var batchUsersStruct = []DataStructUser{ 0: {ID: int64(1), Name: "Alex Shmidt", Points: int64(123)}, 1: {ID: int64(2), Name: "Darth Vader", Points: int64(1234)}, @@ -260,13 +304,24 @@ func TestJoins(t *testing.T) { err = db.Table(PostsTable).InsertBatch(posts) require.NoError(t, err) - res, err := db.Table(UsersTable).Select("name", "post", "user_id").LeftJoin( - PostsTable, UsersTable+".id", "=", PostsTable+".user_id").Get() + var dataStructs []DataStructUserPosts + dataStruct := &DataStructUserPosts{} + err = db.Table(UsersTable).Select("name", "post", "user_id").LeftJoin( + PostsTable, UsersTable+".id", "=", PostsTable+".user_id").EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + return nil + }) require.NoError(t, err) - for k, val := range res { - require.Equal(t, val["name"], batchUsers[k].Name) - require.Equal(t, val["user_id"], batchUsers[k].ID) + for k, val := range dataStructs { + require.Equal(t, *val.Name, batchUsers[k].Name) + require.Equal(t, *val.Post, *posts[k].Post) + require.Equal(t, *val.UserID, *posts[k].UserID) } _, err = db.Truncate(UsersTable) @@ -288,16 +343,19 @@ func TestUpdate(t *testing.T) { require.NoError(t, err) for _, obj := range rowsToUpdate { - err := db.Table(TestTable).Insert(obj.insert) + err = db.Table(TestTable).Insert(obj.insert) require.NoError(t, err) rows, err := db.Table(TestTable).Where("foo", "=", "foo foo foo").Update(obj.update) require.NoError(t, err) require.GreaterOrEqual(t, rows, int64(1)) - res, err := db.Table(TestTable).Select("foo").Where("foo", "=", obj.update.Foo).Get() + dataStruct := &DataStruct{} + err = db.Table(TestTable).Select("foo").Where("foo", "=", obj.update.Foo). + ScanStruct(dataStruct) + require.NoError(t, err) - require.Equal(t, obj.update.Foo, res[0]["foo"]) + require.Equal(t, obj.update.Foo, dataStruct.Foo) } _, err = db.Truncate(TestTable) @@ -346,19 +404,19 @@ func TestDB_Increment_Decrement(t *testing.T) { _, err = db.Table(TestTable).Increment("baz", obj.incr) require.NoError(t, err) - res, err := db.Table(TestTable).Select("baz").Where("baz", "=", obj.incrRes).Get() + dataStruct := &DataStruct{} + err = db.Table(TestTable).Select("baz").Where("baz", "=", obj.incrRes). + ScanStruct(dataStruct) require.NoError(t, err) - require.GreaterOrEqual(t, len(res), 1) - require.Equal(t, res[0]["baz"], int64(obj.incrRes)) + require.Equal(t, *dataStruct.Baz, int64(obj.incrRes)) _, err = db.Table(TestTable).Decrement("baz", obj.decr) require.NoError(t, err) - res, err = db.Table(TestTable).Select("baz").Where("baz", "=", obj.decrRes).Get() + err = db.Table(TestTable).Select("baz").Where("baz", "=", obj.decrRes).ScanStruct(dataStruct) require.NoError(t, err) - require.GreaterOrEqual(t, len(res), 1) - require.Equal(t, res[0]["baz"], int64(obj.decrRes)) + require.Equal(t, *dataStruct.Baz, int64(obj.decrRes)) } _, err = db.Truncate(TestTable) @@ -392,10 +450,11 @@ func TestDB_Replace(t *testing.T) { require.NoError(t, err) require.GreaterOrEqual(t, rows, int64(1)) - res, err := db.Table(TestTable).Select("foo").Where("baz", "=", obj.replace.Baz).Get() + dataStruct := &DataStruct{} + err = db.Table(TestTable).Select("foo").Where("baz", "=", obj.replace.Baz). + ScanStruct(dataStruct) require.NoError(t, err) - require.GreaterOrEqual(t, len(res), 1) - require.Equal(t, res[0]["foo"], obj.replace.Foo) + require.Equal(t, dataStruct.Foo, obj.replace.Foo) } _, err = db.Truncate(TestTable) @@ -637,15 +696,17 @@ func TestDB_Pluck(t *testing.T) { err = db.Table(UsersTable).InsertBatch(batchUsers) require.NoError(t, err) - res, err := db.Table(UsersTable).Pluck("name") + + dataStruct := &DataStructUser{} + res, err := db.Table(UsersTable).Pluck(dataStruct) require.NoError(t, err) for k, v := range res { - resVal := v.(string) - require.Equal(t, batchUsers[k].Name, resVal) + resVal := v.(DataStructUser) + require.Equal(t, batchUsers[k].Name, resVal.Name) } - _, err = db.Table("nonexistent").Pluck("name") + _, err = db.Table("nonexistent").Pluck(dataStruct) require.Error(t, err) _, err = db.Truncate(UsersTable) @@ -658,21 +719,26 @@ func TestDB_PluckMap(t *testing.T) { err = db.Table(UsersTable).InsertBatch(batchUsers) require.NoError(t, err) - res, err := db.Table(UsersTable).PluckMap("name", "points") + + dataStruct := &DataStructUser{} + res, err := db.Table(UsersTable).PluckMap(dataStruct, "name", "points") require.NoError(t, err) for k, m := range res { for key, value := range m { keyVal := key.(string) - valueVal := value.(int64) + valueVal := value.(DataStructUser) require.Equal(t, batchUsers[k].Name, keyVal) - require.Equal(t, batchUsers[k].Points, valueVal) + require.Equal(t, batchUsers[k].Points, valueVal.Points) } } - _, err = db.Table("nonexistent").PluckMap("name", "points") + _, err = db.Table("nonexistent").PluckMap(dataStruct, "name", "points") require.Error(t, err) + _, err = db.Table("nonexistent").PluckMap(dataStruct, "namee", "points") + require.EqualError(t, err, "field 'Namee' not found in struct") + _, err = db.Truncate(UsersTable) require.NoError(t, err) } @@ -796,9 +862,20 @@ func TestDB_GroupByHaving(t *testing.T) { err = db.Table(UsersTable).InsertBatch(batchUsers) require.NoError(t, err) - res, err := db.Table(UsersTable).Select("points").GroupBy("points").Having("points", ">=", 123).Get() + dataStruct := &DataStructUser{} + var dataStructs []DataStructUser + err = db.Table(UsersTable).Select("points").GroupBy("points"). + Having("points", ">=", 123).EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + return nil + }) require.NoError(t, err) - require.Equal(t, len(res), len(batchUsers)-1) + require.Equal(t, len(dataStructs), len(batchUsers)-1) _, err = db.Truncate(UsersTable) require.NoError(t, err) @@ -811,9 +888,21 @@ func TestDB_HavingRaw(t *testing.T) { err = db.Table(UsersTable).InsertBatch(batchUsers) require.NoError(t, err) - res, err := db.Table(UsersTable).Select("points").GroupBy("points").HavingRaw("points > 123").AndHavingRaw("points < 12345").OrHavingRaw("points = 0").Get() + dataStruct := &DataStructUser{} + var dataStructs []DataStructUser + err = db.Table(UsersTable).Select("points").GroupBy("points"). + HavingRaw("points > 123").AndHavingRaw("points < 12345").OrHavingRaw("points = 0"). + EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + return nil + }) require.NoError(t, err) - require.Equal(t, len(batchUsers)-3, len(res)) + require.Equal(t, len(batchUsers)-3, len(dataStructs)) _, err = db.Truncate(UsersTable) require.NoError(t, err) @@ -832,36 +921,86 @@ func TestDB_AllJoins(t *testing.T) { err = db.Table(PostsTable).InsertBatch(batchPosts) require.NoError(t, err) - res, err := db.Table(UsersTable).Select("name", "post", "user_id").InnerJoin(PostsTable, UsersTable+".id", - "=", PostsTable+".user_id").Get() - require.NoError(t, err) + dataStruct := &DataStructUserPosts{} + var dataStructs []DataStructUserPosts + err = db.Table(UsersTable).Select("name", "post", "user_id").InnerJoin(PostsTable, UsersTable+".id", + "=", PostsTable+".user_id").EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } - require.Equal(t, len(res), len(batchPosts)-1) + dataStructs = append(dataStructs, *dataStruct) - res, err = db.Table(PostsTable).Select("name", "post", "user_id").RightJoin(UsersTable, PostsTable+".user_id", - "=", UsersTable+".id").Get() + return nil + }) require.NoError(t, err) + require.Equal(t, len(dataStructs), len(batchPosts)-1) - require.Equal(t, len(res), len(batchUsers)) + dataStruct = &DataStructUserPosts{} + dataStructs = []DataStructUserPosts{} + err = db.Table(PostsTable).Select("name", "post", "user_id").RightJoin(UsersTable, PostsTable+".user_id", + "=", UsersTable+".id").EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) - res, err = db.Table(UsersTable).Select("name", "post", "user_id").FullJoin(PostsTable, UsersTable+".id", - "=", PostsTable+".user_id").Get() + return nil + }) require.NoError(t, err) + require.Equal(t, len(dataStructs), len(batchUsers)) + + dataStruct = &DataStructUserPosts{} + dataStructs = []DataStructUserPosts{} + err = db.Table(UsersTable).Select("name", "post", "user_id").FullJoin(PostsTable, UsersTable+".id", + "=", PostsTable+".user_id").EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } - require.Equal(t, len(res), len(batchUsers)+1) + dataStructs = append(dataStructs, *dataStruct) - res, err = db.Table(UsersTable).Select("name", "post", "user_id").FullJoin(PostsTable, - UsersTable+".id", "=", PostsTable+".user_id").Get() + return nil + }) require.NoError(t, err) + require.Equal(t, len(dataStructs), len(batchUsers)+1) + + dataStruct = &DataStructUserPosts{} + dataStructs = []DataStructUserPosts{} + err = db.Table(UsersTable).Select("name", "post", "user_id").FullOuterJoin(PostsTable, + UsersTable+".id", "=", PostsTable+".user_id").EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } - require.Equal(t, len(res), len(batchUsers)+1) + dataStructs = append(dataStructs, *dataStruct) - // note InRandomOrder check - res, err = db.Table(UsersTable).Select("name", "post", "user_id").FullJoin( - PostsTable, UsersTable+".id", "=", PostsTable+".user_id").InRandomOrder().Get() + return nil + }) require.NoError(t, err) + require.Equal(t, len(dataStructs), len(batchUsers)+1) + + // note InRandomOrder check + dataStruct = &DataStructUserPosts{} + dataStructs = []DataStructUserPosts{} + err = db.Table(UsersTable).Select("name", "post", "user_id").FullJoin( + PostsTable, UsersTable+".id", "=", PostsTable+".user_id").InRandomOrder().EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } - require.Equal(t, len(res), len(batchUsers)+1) + dataStructs = append(dataStructs, *dataStruct) + + return nil + }) + require.NoError(t, err) + require.Equal(t, len(dataStructs), len(batchUsers)+1) _, err = db.Truncate(PostsTable) require.NoError(t, err) @@ -938,9 +1077,21 @@ func TestDB_WhereRaw(t *testing.T) { err = db.Table(UsersTable).InsertBatch(batchUsers) require.NoError(t, err) - res, err := db.Table(UsersTable).Select("name").WhereRaw("LENGTH(name) > 15").OrWhereRaw("points > 1234").Get() + dataStruct := &DataStructUser{} + var dataStructs []DataStructUser + err = db.Table(UsersTable).Select("name").WhereRaw("LENGTH(name) > 15"). + OrWhereRaw("points > 1234").EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + + return nil + }) require.NoError(t, err) - require.Equal(t, len(res), 2) + require.Equal(t, len(dataStructs), 2) cnt, err := db.Table(UsersTable).WhereRaw("points > 123").AndWhereRaw("points < 12345").Count() require.NoError(t, err) @@ -957,9 +1108,20 @@ func TestDB_Offset(t *testing.T) { err = db.Table(UsersTable).InsertBatch(batchUsers) require.NoError(t, err) - res, err := db.Table(UsersTable).Offset(2).Limit(10).Get() + dataStruct := &DataStructUser{} + var dataStructs []DataStructUser + err = db.Table(UsersTable).Offset(2).Limit(10).EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + + return nil + }) require.NoError(t, err) - require.Equal(t, len(res), 2) + require.Equal(t, len(dataStructs), 2) _, err = db.Truncate(UsersTable) require.NoError(t, err) @@ -995,13 +1157,52 @@ func TestDB_WhereIn(t *testing.T) { err = db.Table(UsersTable).InsertBatch(batchUsers) require.NoError(t, err) - res, err := db.Table(UsersTable).Select("name").WhereIn("points", []int64{123, 1234}).OrWhereIn("id", []int64{1, 2}).Get() + dataStruct := &DataStructUser{} + var dataStructs []DataStructUser + err = db.Table(UsersTable).Select("name").WhereIn("points", []int64{123, 1234}). + OrWhereIn("id", []int64{1, 2}).EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + + return nil + }) require.NoError(t, err) - require.Equal(t, len(res), 2) + require.Equal(t, len(dataStructs), 2) - res, err = db.Table(UsersTable).Select("name").WhereIn("points", []int64{123, 1234}).AndWhereIn("id", []int64{1, 2}).Get() + dataStruct = &DataStructUser{} + dataStructs = []DataStructUser{} + err = db.Table(UsersTable).Select("name").WhereIn("points", []int64{123, 1234}). + AndWhereIn("id", []int64{1, 2}).EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + + return nil + }) require.NoError(t, err) - require.Equal(t, len(res), 2) + require.Equal(t, len(dataStructs), 2) + + require.Panicsf(t, func() { + _ = db.Table(UsersTable).Select("name").WhereIn("points", DataStruct{}). + AndWhereIn("id", []int64{1, 2}) + }, "interfaceToSlice() given a non-slice type") + + require.Panicsf(t, func() { + _ = db.Table(UsersTable).Select("name").WhereIn("points", []int64{123, 1234}). + AndWhereIn("id", DataStruct{}) + }, "interfaceToSlice() given a non-slice type") + + require.Panicsf(t, func() { + _ = db.Table(UsersTable).Select("name").WhereIn("points", []int64{123, 1234}). + OrWhereIn("id", DataStruct{}) + }, "interfaceToSlice() given a non-slice type") _, err = db.Truncate(UsersTable) require.NoError(t, err) @@ -1013,13 +1214,53 @@ func TestDB_WhereNotIn(t *testing.T) { err = db.Table(UsersTable).InsertBatch(batchUsers) require.NoError(t, err) - res, err := db.Table(UsersTable).Select("name").WhereNotIn("points", []int64{123, 1234}).OrWhereNotIn("id", []int64{1, 2}).Get() + dataStruct := &DataStructUser{} + var dataStructs []DataStructUser + err = db.Table(UsersTable).Select("name").WhereNotIn("points", []int64{123, 1234}). + OrWhereNotIn("id", []int64{1, 2}).EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + + return nil + }) require.NoError(t, err) - require.Equal(t, len(res), 2) + require.Equal(t, len(dataStructs), 2) + + require.Panicsf(t, func() { + _ = db.Table(UsersTable).Select("name").WhereNotIn("points", DataStruct{}). + OrWhereNotIn("id", []int64{1, 2}) + }, "interfaceToSlice() given a non-slice type") + + require.Panicsf(t, func() { + _ = db.Table(UsersTable).Select("name").WhereNotIn("points", []int64{123, 1234}). + OrWhereNotIn("id", DataStruct{}) + }, "interfaceToSlice() given a non-slice type") + + dataStruct = &DataStructUser{} + dataStructs = []DataStructUser{} + err = db.Table(UsersTable).Select("name").WhereNotIn("points", []int64{123, 1234}). + AndWhereNotIn("id", []int64{1, 2}).EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) - res, err = db.Table(UsersTable).Select("name").WhereNotIn("points", []int64{123, 1234}).AndWhereNotIn("id", []int64{1, 2}).Get() + return nil + }) require.NoError(t, err) - require.Equal(t, len(res), 2) + require.Equal(t, len(dataStructs), 2) + + require.Panicsf(t, func() { + _ = db.Table(UsersTable).Select("name").WhereNotIn("points", []int64{123, 1234}). + AndWhereNotIn("id", DataStruct{}) + }, "interfaceToSlice() given a non-slice type") + _, err = db.Truncate(UsersTable) require.NoError(t, err) } @@ -1031,14 +1272,38 @@ func TestDB_WhereNull(t *testing.T) { err = db.Table(PostsTable).InsertBatch(batchPosts) require.NoError(t, err) - res, err := db.Table(PostsTable).Select("title").WhereNull("post").AndWhereNull("user_id").Get() + dataStruct := &DataStructPost{} + var dataStructs []DataStructPost + err = db.Table(PostsTable).Select("title").WhereNull("post").AndWhereNull("user_id"). + EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + + return nil + }) db.Dump() require.NoError(t, err) - require.Equal(t, len(res), 1) + require.Equal(t, len(dataStructs), 1) - res, err = db.Table(PostsTable).Select("title").WhereNull("post").OrWhereNull("user_id").Get() + dataStruct = &DataStructPost{} + dataStructs = []DataStructPost{} + err = db.Table(PostsTable).Select("title").WhereNull("post").OrWhereNull("user_id"). + EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + + return nil + }) require.NoError(t, err) - require.Equal(t, len(res), 1) + require.Equal(t, len(dataStructs), 1) _, err = db.Truncate(PostsTable) require.NoError(t, err) @@ -1051,18 +1316,54 @@ func TestDB_WhereNotNull(t *testing.T) { err = db.Table(UsersTable).InsertBatch(batchUsers) require.NoError(t, err) - res, err := db.Table(UsersTable).Select("name").WhereNotNull("points").AndWhereNotNull("name").Get() + dataStruct := &DataStructUser{} + var dataStructs []DataStructUser + err = db.Table(UsersTable).Select("name").WhereNotNull("points"). + AndWhereNotNull("name").EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + + return nil + }) require.NoError(t, err) - require.Equal(t, len(res), len(batchUsers)) + require.Equal(t, len(dataStructs), len(batchUsers)) - res, err = db.Table(UsersTable).Select("name").WhereNotNull("points").OrWhereNotNull("name").Get() + dataStruct = &DataStructUser{} + dataStructs = []DataStructUser{} + err = db.Table(UsersTable).Select("name").WhereNotNull("points").OrWhereNotNull("name"). + EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + + return nil + }) require.NoError(t, err) - require.Equal(t, len(res), len(batchUsers)) + require.Equal(t, len(dataStructs), len(batchUsers)) - res, err = db.Table(UsersTable).Select("name").Where("id", "=", 1). - OrWhere("id", "=", 2).AndWhereNotNull("points").Get() + dataStruct = &DataStructUser{} + dataStructs = []DataStructUser{} + err = db.Table(UsersTable).Select("name").Where("id", "=", 1). + OrWhere("id", "=", 2).AndWhereNotNull("points"). + EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + + return nil + }) require.NoError(t, err) - require.Equal(t, len(res), 2) + require.Equal(t, len(dataStructs), 2) _, err = db.Truncate(UsersTable) require.NoError(t, err) @@ -1075,9 +1376,21 @@ func TestDB_LockForUpdate(t *testing.T) { err = db.Table(UsersTable).InsertBatch(batchUsers) require.NoError(t, err) - res, err := db.Table(UsersTable).Select("name").LockForUpdate().Get() + dataStruct := &DataStructUser{} + var dataStructs []DataStructUser + err = db.Table(UsersTable).Select("name").LockForUpdate(). + EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + + return nil + }) require.NoError(t, err) - require.Equal(t, len(res), len(batchUsers)) + require.Equal(t, len(dataStructs), len(batchUsers)) _, err = db.Truncate(UsersTable) require.NoError(t, err) @@ -1090,9 +1403,21 @@ func TestDB_UnionAll(t *testing.T) { err = db.Table(UsersTable).InsertBatch(batchUsers) require.NoError(t, err) - res, err := db.Table(UsersTable).Select("name").UnionAll().Table(UsersTable).Select("name").Get() + dataStruct := &DataStructUser{} + var dataStructs []DataStructUser + err = db.Table(UsersTable).Select("name").UnionAll().Table(UsersTable).Select("name"). + EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + + return nil + }) require.NoError(t, err) - require.Equal(t, len(res), len(batchUsers)*2) + require.Equal(t, len(dataStructs), len(batchUsers)*2) _, err = db.Truncate(UsersTable) require.NoError(t, err) @@ -1108,10 +1433,22 @@ func TestDB_FullOuterJoin(t *testing.T) { err = db.Table(UsersTable).InsertBatch(batchUsers) require.NoError(t, err) - res, err := db.Table(UsersTable).Select("name").FullOuterJoin(PostsTable, UsersTable+".id", "=", - PostsTable+".user_id").Get() + dataStruct := &DataStructUser{} + var dataStructs []DataStructUser + err = db.Table(UsersTable).Select("name").FullOuterJoin(PostsTable, UsersTable+".id", "=", + PostsTable+".user_id"). + EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + + return nil + }) require.NoError(t, err) - require.Equal(t, len(res), len(batchUsers)) + require.Equal(t, len(dataStructs), len(batchUsers)) _, err = db.Truncate(UsersTable) require.NoError(t, err) @@ -1123,17 +1460,19 @@ func TestDB_Chunk(t *testing.T) { err = db.Table(UsersTable).InsertBatch(batchUsers) require.NoError(t, err) + var sumOfPoints int64 - err = db.Table(UsersTable).Select("name", "points").Chunk(2, func(users []map[string]interface{}) bool { - for _, m := range users { - if val, ok := m["points"]; ok { - sumOfPoints += val.(int64) - } + dataStruct := &DataStructUser{} + err = db.Table(UsersTable).Select("name", "points").Chunk(dataStruct, 2, func(users []any) bool { + for _, v := range users { + user := v.(DataStructUser) + sumOfPoints += user.Points } + return true }) - require.NoError(t, err) + var initialSum int64 for _, mm := range batchUsers { initialSum += mm.Points @@ -1151,15 +1490,17 @@ func TestDB_ChunkFalse(t *testing.T) { err = db.Table(UsersTable).InsertBatch(batchUsers) require.NoError(t, err) var sumOfPoints int64 - err = db.Table(UsersTable).Select("name", "points").Chunk(2, func(users []map[string]interface{}) bool { - for _, m := range users { + dataStruct := &DataStructUser{} + err = db.Table(UsersTable).Select("name", "points").Chunk(dataStruct, 2, func(users []any) bool { + for _, v := range users { if sumOfPoints > 0 { return false } - if val, ok := m["points"]; ok { - sumOfPoints += val.(int64) - } + + user := v.(DataStructUser) + sumOfPoints += user.Points } + return true }) @@ -1170,7 +1511,7 @@ func TestDB_ChunkFalse(t *testing.T) { require.NoError(t, err) } -func TestDB_ChunkLessThenAmount(t *testing.T) { +func TestDB_ChunkLessThanAmount(t *testing.T) { _, err := db.Truncate(UsersTable) require.NoError(t, err) @@ -1178,12 +1519,13 @@ func TestDB_ChunkLessThenAmount(t *testing.T) { require.NoError(t, err) var sumOfPoints int64 - err = db.Table(UsersTable).Select("name", "points").Chunk(int64(len(batchUsers)+1), func(users []map[string]interface{}) bool { - for _, m := range users { - if val, ok := m["points"]; ok { - sumOfPoints += val.(int64) - } + dataStruct := &DataStructUser{} + err = db.Table(UsersTable).Select("name", "points").Chunk(dataStruct, int64(len(batchUsers)+1), func(users []any) bool { + for _, v := range users { + user := v.(DataStructUser) + sumOfPoints += user.Points } + return true }) require.NoError(t, err) @@ -1201,12 +1543,13 @@ func TestDB_ChunkLessThenZeroErr(t *testing.T) { require.NoError(t, err) var sumOfPoints int64 - err = db.Table(UsersTable).Select("name", "points").Chunk(int64(-1), func(users []map[string]interface{}) bool { - for _, m := range users { - if val, ok := m["points"]; ok { - sumOfPoints += val.(int64) - } + dataStruct := &DataStructUser{} + err = db.Table(UsersTable).Select("name", "points").Chunk(dataStruct, int64(-1), func(users []any) bool { + for _, v := range users { + user := v.(DataStructUser) + sumOfPoints += user.Points } + return true }) require.Errorf(t, err, "chunk can't be <= 0, your chunk is: -1") @@ -1223,7 +1566,18 @@ func TestDB_ChunkBuilderTableErr(t *testing.T) { err = db.InsertBatch(batchUsers) require.Error(t, err, errTableCallBeforeOp) - _, err = db.Select("foo", "bar", "baz").Get() + dataStruct := &DataStruct{} + var dataStructs []DataStruct + err = db.Select("foo", "bar", "baz").EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, *dataStruct) + + return nil + }) require.Error(t, err, errTableCallBeforeOp) err = db.Insert(dataMap) diff --git a/factory.go b/factory.go index f9bd7d8..840083d 100644 --- a/factory.go +++ b/factory.go @@ -28,74 +28,6 @@ var ( type EachToStructFunc func(rows *sql.Rows) error -// Get builds all sql statements chained before and executes query collecting data to the slice -// Deprecated: this method will no longer be used in future releases, because of ScanStruct and EachToStruct replacement -func (r *DB) Get() ([]map[string]any, error) { - bldr := r.Builder - if bldr.table == "" { - return nil, errTableCallBeforeOp - } - - query := "" - if len(bldr.union) > 0 { // got union - need different logic to glue - for _, uBuilder := range bldr.union { - query += uBuilder + " UNION " - - if bldr.isUnionAll { - query += "ALL " - } - } - - query += bldr.buildSelect() - // clean union (all) after ensuring selects are built - r.Builder.union = []string{} - r.Builder.isUnionAll = false - } else { // std bldr - query = bldr.buildSelect() - } - - rows, err := r.Sql().Query(query, prepareValues(r.Builder.whereBindings)...) - if err != nil { - return nil, err - } - - columns, _ := rows.Columns() - count := len(columns) - values := make([]any, count) - valuePtrs := make([]any, count) - - // collecting data from struct with fields - var res []map[string]any - - for rows.Next() { - collect := make(map[string]any, count) - - for i := range columns { - valuePtrs[i] = &values[i] - } - - err = rows.Scan(valuePtrs...) - if err != nil { - return nil, err - } - - for i, col := range columns { - val := values[i] - - b, ok := val.([]byte) - if ok { - collect[col] = string(b) - } else { - collect[col] = val - } - } - - res = append(res, collect) - } - - return res, nil -} - // ScanStruct scans query into specific struct func (r *DB) ScanStruct(src any) error { if reflect.ValueOf(src).IsNil() { @@ -275,7 +207,9 @@ func setResourceValue(resource reflect.Value, src any, col string, value any) { func setValue(field reflect.Value, val any) { if field.Kind() == reflect.Ptr { newVal := reflect.New(field.Type().Elem()) - newVal.Elem().Set(reflect.ValueOf(val)) + if val != nil { + newVal.Elem().Set(reflect.ValueOf(val)) + } field.Set(newVal) return @@ -319,7 +253,7 @@ func validateFields(resource reflect.Value, src any, columns []string) error { } if !foundColByTag { - return fmt.Errorf("field %s not found in struct", fieldName) + return fmt.Errorf("field '%s' not found in struct", fieldName) } } } @@ -634,6 +568,30 @@ func getColumn(structField reflect.StructField) string { return col } +func getFieldValue(src any, col string) any { + fieldTitleName := cases.Title(language.English).String(col) + fieldUpperName := cases.Upper(language.English).String(col) + + resource := reflect.ValueOf(src) + if resource.FieldByName(fieldTitleName).IsValid() { + return resource.FieldByName(fieldTitleName).Interface() + } + + if !resource.FieldByName(fieldUpperName).IsValid() { + return resource.FieldByName(fieldUpperName).Interface() + } + + resourceType := reflect.TypeOf(src) + for i := 0; i < resourceType.NumField(); i++ { + f := resourceType.Field(i) + if f.Tag.Get("db") == col { + return resource.Field(i).Interface() + } + } + + return nil +} + // InsertBatch inserts multiple rows based on transaction func (r *DB) InsertBatch(data any) error { bldr := r.Builder diff --git a/helpers.go b/helpers.go index 4d4f9be..9ddb81b 100644 --- a/helpers.go +++ b/helpers.go @@ -9,7 +9,7 @@ func interfaceToSlice(slice interface{}) ([]interface{}, error) { var err error s := reflect.ValueOf(slice) if s.Kind() != reflect.Slice { - err = errors.New("interfaceToSlice() given a non-slice type") + return nil, errors.New("interfaceToSlice() given a non-slice type") } ret := make([]interface{}, s.Len())