diff --git a/.gitignore b/.gitignore index 13eedbe..0f498f5 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ # Output of the go coverage tool, specifically when used with LiteIDE *.out -vendor/ \ No newline at end of file +vendor/ +.idea/ \ No newline at end of file diff --git a/README.md b/README.md index c2fcfa5..85bcb22 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # buildsqlx -Go Database query builder library [![Tweet](http://jpillora.com/github-twitter-button/img/tweet.png)](https://twitter.com/intent/tweet?text=Go%20database%20query%20builder%20library%20&url=https://github.com/arthurkushman/buildsqlx&hashtags=go,golang,sql,builder,postgresql,sql-builder,developers) + +Go Database query builder +library [![Tweet](http://jpillora.com/github-twitter-button/img/tweet.png)](https://twitter.com/intent/tweet?text=Go%20database%20query%20builder%20library%20&url=https://github.com/arthurkushman/buildsqlx&hashtags=go,golang,sql,builder,postgresql,sql-builder,developers) [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) [![Go Report Card](https://goreportcard.com/badge/github.com/arthurkushman/buildsqlx)](https://goreportcard.com/report/github.com/arthurkushman/buildsqlx) @@ -34,159 +36,188 @@ Go Database query builder library [![Tweet](http://jpillora.com/github-twitter-b * [Chunking Results](#user-content-chunking-results) ## Installation + ```bash go get -u github.com/arthurkushman/buildsqlx ``` ## Selects, Ordering, Limit & Offset -You may not always want to select all columns from a database table. Using the select method, you can specify a custom select clause for the query: +You may not always want to select all columns from a database table. Using the select method, you can specify a custom +select clause for the query: ```go package yourpackage import ( - _ "github.com/lib/pq" + "database/sql" + "github.com/arthurkushman/buildsqlx" -) + _ "github.com/lib/pq" +) var db = buildsqlx.NewDb(buildsqlx.NewConnection("postgres", "user=postgres dbname=postgres password=postgres sslmode=disable")) func main() { - qDb := db.Table("posts").Select("title", "body") - - // If you already have a query builder instance and you wish to add a column to its existing select clause, you may use the addSelect method: - res, err := qDb.AddSelect("points").GroupBy("topic").OrderBy("points", "DESC").Limit(15).Offset(5).Get() + qDb := db.Table("posts").Select("title", "body") + + type DataStruct struct { + Foo string + Bar string + Baz *int64 + } + + dataStruct := DataStruct{} + var testStructs []DataStruct + // If you already have a query builder instance and you wish to add a column to its existing select clause, you may use the addSelect method: + err := qDb.AddSelect("points").GroupBy("topic").OrderBy("points", "DESC").Limit(15).Offset(5).EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, &dataStruct) + if err != nil { + return err + } + + testStructs = append(testStructs, dataStruct) + return nil + }) } ``` ### InRandomOrder + ```go -res, err = db.Table("users").Select("name", "post", "user_id").InRandomOrder().Get() +res, err = db.Table("users").Select("name", "post", "user_id").InRandomOrder().ScanStruct(&dataStruct) ``` ## GroupBy / Having -The GroupBy and Having methods may be used to group the query results. + +The GroupBy and Having methods may be used to group the query results. The having method's signature is similar to that of the where method: + ```go -res, err := db.table("users").GroupBy("account_id").Having("account_id", ">", 100).Get() +res, err := db.table("users").GroupBy("account_id").Having("account_id", ">", 100).ScanStruct(&dataStruct) ``` ## Where, AndWhere, OrWhere clauses -You may use the where method on a query builder instance to add where clauses to the query. -The most basic call to where requires three arguments. -The first argument is the name of the column. -The second argument is an operator, which can be any of the database's supported operators. + +You may use the where method on a query builder instance to add where clauses to the query. +The most basic call to where requires three arguments. +The first argument is the name of the column. +The second argument is an operator, which can be any of the database's supported operators. Finally, the third argument is the value to evaluate against the column. ```go -package yourpackage - -import ( - _ "github.com/lib/pq" - "github.com/arthurkushman/buildsqlx" -) - -func main() { - res, err := db.Table("table1").Select("foo", "bar", "baz").Where("foo", "=", cmp).AndWhere("bar", "!=", "foo").OrWhere("baz", "=", 123).Get() -} +res, err := db.Table("table1").Select("foo", "bar", "baz").Where("foo", "=", cmp).AndWhere("bar", "!=", "foo").OrWhere("baz", "=", 123)..ScanStruct(&dataStruct) ``` -You may chain where constraints together as well as add or clauses to the query. +You may chain where constraints together as well as add or clauses to the query. The orWhere method accepts the same arguments as the where method. -## WhereIn / WhereNotIn +## WhereIn / WhereNotIn + The whereIn method verifies that a given column's value is contained within the given slice: + ```go -res, err := db.Table("table1").WhereIn("id", []int64{1, 2, 3}).OrWhereIn("name", []string{"John", "Paul"}).Get() +res, err := db.Table("table1").WhereIn("id", []int64{1, 2, 3}).OrWhereIn("name", []string{"John", "Paul"})..ScanStruct(&dataStruct) ``` -## WhereNull / WhereNotNull +## WhereNull / WhereNotNull + The whereNull method verifies that the value of the given column is NULL: + ```go -res, err := db.Table("posts").WhereNull("points").OrWhereNotNull("title").Get() +res, err := db.Table("posts").WhereNull("points").OrWhereNotNull("title")..ScanStruct(&dataStruct) ``` ## Left / Right / Cross / Inner / Left Outer Joins -The query builder may also be used to write join statements. -To perform a basic "inner join", you may use the InnerJoin method on a query builder instance. -The first argument passed to the join method is the name of the table you need to join to, -while the remaining arguments specify the column constraints for the join. + +The query builder may also be used to write join statements. +To perform a basic "inner join", you may use the InnerJoin method on a query builder instance. +The first argument passed to the join method is the name of the table you need to join to, +while the remaining arguments specify the column constraints for the join. You can even join to multiple tables in a single query: + ```go -res, err := db.Table("users").Select("name", "post", "user_id").LeftJoin("posts", "users.id", "=", "posts.user_id").Get() +res, err := db.Table("users").Select("name", "post", "user_id").LeftJoin("posts", "users.id", "=", "posts.user_id").EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, &dataStruct) + if err != nil { + return err + } + + testStructs = append(testStructs, dataStruct) + return nil +}) ``` ## Inserts -The query builder also provides an insert method for inserting records into the database table. + +The query builder also provides an insert method for inserting records into the database table. The insert method accepts a map of column names and values: ```go -package yourpackage +// insert without getting id +err := db.Table("table1").Insert(DataStruct{ + Foo: "foo foo foo", + Bar: "bar bar bar", + Baz: &baz, +}) -import ( - _ "github.com/lib/pq" - "github.com/arthurkushman/buildsqlx" -) +// insert returning id +id, err := db.Table("table1").InsertGetId(DataStruct{ + Foo: "foo foo foo", + Bar: "bar bar bar", + Baz: &baz, +}) -func main() { - // insert without getting id - err := db.Table("table1").Insert(map[string]interface{}{"foo": "foo foo foo", "bar": "bar bar bar", "baz": int64(123)}) - - // insert returning id - id, err := db.Table("table1").InsertGetId(map[string]interface{}{"foo": "foo foo foo", "bar": "bar bar bar", "baz": int64(123)}) - - // batch insert - err := db.Table("table1").InsertBatch([]map[string]interface{}{ - 0: {"foo": "foo foo foo", "bar": "bar bar bar", "baz": 123}, - 1: {"foo": "foo foo foo foo", "bar": "bar bar bar bar", "baz": 1234}, - 2: {"foo": "foo foo foo foo foo", "bar": "bar bar bar bar bar", "baz": 12345}, - }) -} +// batch insert +err := db.Table("table1").InsertBatch([]map[string]interface{}{ + 0: {"foo": "foo foo foo", "bar": "bar bar bar", "baz": 123}, + 1: {"foo": "foo foo foo foo", "bar": "bar bar bar bar", "baz": 1234}, + 2: {"foo": "foo foo foo foo foo", "bar": "bar bar bar bar bar", "baz": 12345}, +}) ``` ## Updates -In addition to inserting records into the database, -the query builder can also update existing records using the update method. -The update method, like the insert method, accepts a slice of column and value pairs containing the columns to be updated. + +In addition to inserting records into the database, +the query builder can also update existing records using the update method. +The update method, like the insert method, accepts a slice of column and value pairs containing the columns to be +updated. You may constrain the update query using where clauses: + ```go -rows, err := db.Table("posts").Where("points", ">", 3).Update(map[string]interface{}{"title": "awesome"}) +rows, err := db.Table("posts").Where("points", ">", 3).Update(DataStruct{ + Title: "awesome", +}) ``` ## Delete -The query builder may also be used to delete records from the table via the delete method. + +The query builder may also be used to delete records from the table via the delete method. You may constrain delete statements by adding where clauses before calling the delete method: + ```go rows, err := db.Table("posts").Where("points", "=", 123).Delete() ``` ## Drop, Truncate, Rename -```go -package yourpackage - -import ( - _ "github.com/lib/pq" - "github.com/arthurkushman/buildsqlx" -) -func main() { - db.Drop("table_name") +```go +db.Drop("table_name") - db.DropIfExists("table_name") +db.DropIfExists("table_name") - db.Truncate("table_name") +db.Truncate("table_name") - db.Rename("table_name1", "table_name2") -} +db.Rename("table_name1", "table_name2") ``` ## Increment & Decrement -The query builder also provides convenient methods for incrementing or decrementing the value of a given column. +The query builder also provides convenient methods for incrementing or decrementing the value of a given column. This is a shortcut, providing a more expressive and terse interface compared to manually writing the update statement. -Both of these methods accept 2 arguments: the column to modify, a second argument to control the amount by which the column should be incremented or decremented: +Both of these methods accept 2 arguments: the column to modify, a second argument to control the amount by which the +column should be incremented or decremented: ```go db.Table("users").Increment("votes", 3) @@ -195,72 +226,78 @@ db.Table("users").Decrement("votes", 1) ``` ## Union / Union All -The query builder also provides a quick way to "union" two queries together. + +The query builder also provides a quick way to "union" two queries together. For example, you may create an initial query and use the union method to union it with a second query: + ```go union := db.Table("posts").Select("title", "likes").Union() -res, err := union.Table("users").Select("name", "points").Get() +res, err := union.Table("users").Select("name", "points").ScanStruct(&testStruct) // or if UNION ALL is of need // union := db.Table("posts").Select("title", "likes").UnionAll() ``` ## Transaction mode -You can run arbitrary queries mixed with any code in transaction mode getting an error and as a result rollback if something went wrong -or committed if everything is ok: + +You can run arbitrary queries mixed with any code in transaction mode getting an error and as a result rollback if +something went wrong +or committed if everything is ok: + ```go -err := db.InTransaction(func() (interface{}, error) { - return db.Table("users").Select("name", "post", "user_id").Get() +err := db.InTransaction(func () (interface{}, error) { + return db.Table("users").Select("name", "post", "user_id").ScanStruct(&testStruct) }) ``` ## Dump, Dd -You may use the Dd or Dump methods while building a query to dump the query bindings and SQL. -The dd method will display the debug information and then stop executing the request. -The dump method will display the debug information but allow the request to keep executing: -```go -package yourpackage -import ( - _ "github.com/lib/pq" - "github.com/arthurkushman/buildsqlx" -) +You may use the Dd or Dump methods while building a query to dump the query bindings and SQL. +The dd method will display the debug information and then stop executing the request. +The dump method will display the debug information but allow the request to keep executing: -func main() { +```go // to print raw sql query to stdout db.Table("table_name").Select("foo", "bar", "baz").Where("foo", "=", cmp).AndWhere("bar", "!=", "foo").Dump() - + // or to print to stdout and exit a.k.a dump and die - db.Table("table_name").Select("foo", "bar", "baz").Where("foo", "=", cmp).AndWhere("bar", "!=", "foo").Dd() -} + db.Table("table_name").Select("foo", "bar", "baz").Where("foo", "=", cmp).AndWhere("bar", "!=", "foo").Dd() ``` ## Check if table exists + ```go tblExists, err := db.HasTable("public", "posts") ``` ## Check if columns exist in a table within schema + ```go colsExists, err := db.HasColumns("public", "posts", "title", "user_id") ``` ## Retrieving A Single Row / Column From A Table -If you just need to retrieve a single row from the database table, you may use the `First` func. + +If you just need to retrieve a single row from the database table, you may use the `First` func. This method will return a single `map[string]interface{}`: + ```go post, err := db.Table("posts").Select("title").OrderBy("created_at", "desc").First() // usage ex: post["title"] ``` -If you don't even need an entire row, you may extract a single value from a record using the `Value` method. + +If you don't even need an entire row, you may extract a single value from a record using the `Value` method. This method will return the value of the column directly: + ```go res, err := db.Table("users").OrderBy("points", "desc").Value("name") // res -> "Alex Shmidt" ``` + To retrieve a single row by its id column value, use the `find` method: + ```go user, err := db.Table("users").Find(id) @@ -268,38 +305,49 @@ user, err := db.Table("users").Find(id) ``` ## WhereExists / WhereNotExists -The whereExists method allows you to write where exists SQL clauses. -The whereExists method accepts a *DB argument, -which will receive a query builder instance allowing you to define the query that should be placed inside of the "exists" clause: + +The whereExists method allows you to write where exists SQL clauses. +The whereExists method accepts a *DB argument, +which will receive a query builder instance allowing you to define the query that should be placed inside of the " +exists" clause: + ```go res, er := db.Table("users").Select("name").WhereExists( db.Table("users").Select("name").Where("points", ">=", int64(12345)), ).First() ``` + Any query that is of need to build one can place inside `WhereExists` clause/func. ## WhereBetween / WhereNotBetween + The whereBetween func verifies that a column's value is between two values: + ```go -res, err := db.Table(UsersTable).Select("name").WhereBetween("points", 1233, 12345).Get() +res, err := db.Table(UsersTable).Select("name").WhereBetween("points", 1233, 12345).ScanStruct(&testStruct) ``` The whereNotBetween func verifies that a column's value lies outside of two values: + ```go -res, err := db.Table(UsersTable).Select("name").WhereNotBetween("points", 123, 123456).Get() +res, err := db.Table(UsersTable).Select("name").WhereNotBetween("points", 123, 123456).ScanStruct(&testStruct) ``` ## Determining If Records Exist -Instead of using the count method to determine if any records exist that match your query's constraints, + +Instead of using the count method to determine if any records exist that match your query's constraints, you may use the exists and doesntExist methods: + ```go exists, err := db.Table(UsersTable).Select("name").Where("points", ">=", int64(12345)).Exists() // use an inverse DoesntExists() if needed ``` ## Aggregates -The query builder also provides a variety of aggregate methods such as Count, Max, Min, Avg, and Sum. + +The query builder also provides a variety of aggregate methods such as Count, Max, Min, Avg, and Sum. You may call any of these methods after constructing your query: + ```go cnt, err := db.Table(UsersTable).WHere("points", ">=", 1234).Count() @@ -313,9 +361,12 @@ sum, err := db.Table(UsersTable).Sum("points") ``` ## Create table -To create a new database table, use the CreateTable method. -The Schema method accepts two arguments. -The first is the name of the table, while the second is an anonymous function/closure which receives a Table struct that may be used to define the new table: + +To create a new database table, use the CreateTable method. +The Schema method accepts two arguments. +The first is the name of the table, while the second is an anonymous function/closure which receives a Table struct that +may be used to define the new table: + ```go res, err := db.Schema("big_tbl", func(table *Table) error { table.Increments("id") @@ -327,68 +378,76 @@ res, err := db.Schema("big_tbl", func(table *Table) error { table.DblPrecision("likes_to_points").Default(0.0) table.Char("tag", 10) table.DateTime("created_at", true) - table.DateTimeTz("updated_at", true) + table.DateTimeTz("updated_at", true) table.Decimal("tax", 2, 2) table.TsVector("body") - table.TsQuery("body_query") + table.TsQuery("body_query") table.Jsonb("settings") table.Point("pt") - table.Polygon("poly") - table.TableComment("big table for big data") - - return nil + table.Polygon("poly") + table.TableComment("big table for big data") + + return nil }) // to make a foreign key constraint from another table -_, err = db.Schema("tbl_to_ref", func(table *Table) error { +_, err = db.Schema("tbl_to_ref", func (table *Table) error { table.Increments("id") table.Integer("big_tbl_id").ForeignKey("fk_idx_big_tbl_id", "big_tbl", "id").Concurrently().IfNotExists() // to add index on existing column just repeat stmt + index e.g.: table.Char("tag", 10).Index("idx_tag").Include("likes", "created_at") table.Rename("settings", "options") - + return nil -}) +}) ``` ## Add / Modify / Drop columns -The Table structure in the Schema's 2nd argument may be used to update existing tables. Just the way you've been created it. + +The Table structure in the Schema's 2nd argument may be used to update existing tables. Just the way you've been created +it. The Change method allows you to modify some existing column types to a new type or modify the column's attributes. + ```go res, err := db.Schema("tbl_name", func(table *Table) error { table.String("title", 128).Change() - + return nil }) ``` + Use DropColumn method to remove any column: + ```go res, err := db.Schema("tbl_name", func(table *Table) error { table.DropColumn("deleted_at").IfExists() // To drop an index on the column table.DropIndex("idx_title") - + return nil }) ``` ## Chunking Results -If you need to work with thousands of database records, consider using the chunk method. + +If you need to work with thousands of database records, consider using the chunk method. This method retrieves a small chunk of the results at a time and feeds each chunk into a closure for processing. + ```go -err = db.Table("user_achievements").Select("points").Where("id", "=", id).Chunk(100, func(users []map[string]interface{}) bool { +err = db.Table("user_achievements").Select("points").Where("id", "=", id).Chunk(100, func (users []map[string]interface{}) bool { for _, m := range users { if val, ok := m["points"];ok { pointsCalc += diffFormula(val.(int64)) } - // or you can return false here to stop running chunks + // or you can return false here to stop running chunks } return true }) ``` -PS Why use buildsqlx? Because it is simple and fast, yet versatile. -The performance achieved because of structs conversion lack, as all that you need is just a columns - u can get it from an associated array/map while the conversion itself and it's processing eats more CPU/memory resources. +PS Why use buildsqlx? Because it is simple and fast, yet versatile. +The performance achieved because of structs conversion lack, as all that you need is just a columns - u can get it from +an associated array/map while the conversion itself and it's processing eats more CPU/memory resources. Supporters gratitude: diff --git a/builder_test.go b/builder_test.go index b3ec45c..b112bd8 100644 --- a/builder_test.go +++ b/builder_test.go @@ -8,17 +8,22 @@ import ( "testing" _ "github.com/lib/pq" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) const ( TestTable = "test" - PostsTable = "posts" - UsersTable = "users" + PostsTable = "test_posts" + UsersTable = "test_users" TestUserName = "Dead Beaf" ) +type DataStruct struct { + Foo string + Bar string + Baz *int64 +} + var ( dbConnInfo = fmt.Sprintf("host=%s port=%d user=%s "+ "password=%s dbname=%s sslmode=disable", @@ -26,16 +31,23 @@ var ( dataMap = map[string]interface{}{"foo": "foo foo foo", "bar": "bar bar bar", "baz": int64(123)} + baz = int64(123) + data = DataStruct{ + Foo: "foo foo foo", + Bar: "bar bar bar", + Baz: &baz, + } + db = NewDb(NewConnection("postgres", dbConnInfo)) ) func TestMain(m *testing.M) { - _, err := db.Sql().Exec("create table if not exists users (id serial primary key, name varchar(128) not null, points integer)") + _, err := db.Sql().Exec("create table if not exists test_users (id serial primary key, name varchar(128) not null, points integer)") if err != nil { panic(err) } - _, err = db.Sql().Exec("create table if not exists posts (id serial primary key, title varchar(128) not null, post text, user_id integer, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW())") + _, err = db.Sql().Exec("create table if not exists test_posts (id serial primary key, title varchar(128) not null, post text, user_id integer, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW())") if err != nil { panic(err) } @@ -55,43 +67,57 @@ func TestNewConnectionFromDB(t *testing.T) { func TestSelectAndLimit(t *testing.T) { _, err := db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) - err = db.Table(TestTable).Insert(dataMap) - assert.NoError(t, err) + err = db.Table(TestTable).Insert(data) + require.NoError(t, err) - qDb := db.Table(TestTable).Select("foo", "bar") - res, err := qDb.AddSelect("baz").Limit(15).Get() - assert.NoError(t, err) + err = db.Table(TestTable).Insert(data) + require.NoError(t, err) - for k, mapVal := range dataMap { - for _, v := range res { - assert.Equal(t, v[k], mapVal) + toScan := DataStruct{} + qDb := db.Table(TestTable).Select("foo", "bar") + var testStructs []DataStruct + err = qDb.AddSelect("baz").Limit(15).EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, &toScan) + if err != nil { + return err } + + testStructs = append(testStructs, toScan) + return nil + }) + require.NoError(t, err) + require.Equal(t, data, toScan) + require.Len(t, testStructs, 2) + for _, strct := range testStructs { + require.Equal(t, data.Foo, strct.Foo) + require.Equal(t, data.Bar, strct.Bar) + require.Equal(t, data.Baz, strct.Baz) } _, err = db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestInsert(t *testing.T) { _, err := db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) - err = db.Table(TestTable).Insert(dataMap) - assert.NoError(t, err) + err = db.Table(TestTable).Insert(data) + require.NoError(t, err) res, err := db.Table(TestTable).Select("foo", "bar", "baz").Get() - assert.NoError(t, err) + require.NoError(t, err) for k, mapVal := range dataMap { for _, v := range res { - assert.Equal(t, v[k], mapVal) + require.Equal(t, v[k], mapVal) } } _, err = db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) } var batchData = []map[string]interface{}{ @@ -102,56 +128,56 @@ var batchData = []map[string]interface{}{ func TestInsertBatchSelectMultiple(t *testing.T) { _, err := db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(TestTable).InsertBatch(batchData) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(TestTable).Select("foo", "bar", "baz").OrderBy("foo", "ASC").Get() - assert.NoError(t, err) + require.NoError(t, err) for mapKey, mapVal := range batchData { for k, mV := range mapVal { - assert.Equal(t, res[mapKey][k], mV) + require.Equal(t, res[mapKey][k], mV) } } _, err = db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestWhereOnly(t *testing.T) { var cmp = "foo foo foo" _, err := db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(TestTable).InsertBatch(batchData) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(TestTable).Select("foo", "bar", "baz").Where("foo", "=", cmp).Get() - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, res[0]["foo"], cmp) + require.Equal(t, res[0]["foo"], cmp) _, err = db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestWhereAndOr(t *testing.T) { var cmp = "foo foo foo" _, err := db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(TestTable).InsertBatch(batchData) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(TestTable).Select("foo", "bar", "baz").Where("foo", "=", cmp).AndWhere("bar", "!=", "foo").OrWhere("baz", "=", 123).Get() - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, res[0]["foo"], cmp) + require.Equal(t, res[0]["foo"], cmp) _, err = db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) } //var users = `create table users (id serial primary key, name varchar(128) not null, points integer)` @@ -172,17 +198,30 @@ var batchPosts = []map[string]interface{}{ 3: {"id": int64(4), "title": "Lorem ipsum dolor sit amet,", "post": nil, "user_id": nil, "updated_at": "2086-09-09 18:27:40"}, } +type DataStructUser struct { + ID int64 + Name string + Points int64 +} + +var batchUsersStruct = []DataStructUser{ + 0: {ID: int64(1), Name: "Alex Shmidt", Points: int64(123)}, + 1: {ID: int64(2), Name: "Darth Vader", Points: int64(1234)}, + 2: {ID: int64(3), Name: "Dead Beaf", Points: int64(12345)}, + 3: {ID: int64(4), Name: "Dead Beaf", Points: int64(12345)}, +} + func TestJoins(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) _, err = db.Truncate(PostsTable) - assert.NoError(t, err) + require.NoError(t, err) var posts []map[string]interface{} - for _, v := range batchUsers { + for _, v := range batchUsersStruct { id, err := db.Table(UsersTable).InsertGetId(v) - assert.NoError(t, err) + require.NoError(t, err) posts = append(posts, map[string]interface{}{ "title": "ttl", "post": "foo bar baz", "user_id": id, @@ -190,198 +229,222 @@ func TestJoins(t *testing.T) { } err = db.Table(PostsTable).InsertBatch(posts) - assert.NoError(t, err) + require.NoError(t, err) - res, err := db.Table(UsersTable).Select("name", "post", "user_id").LeftJoin("posts", "users.id", "=", "posts.user_id").Get() - assert.NoError(t, err) + res, err := db.Table(UsersTable).Select("name", "post", "user_id").LeftJoin( + PostsTable, UsersTable+".id", "=", PostsTable+".user_id").Get() + require.NoError(t, err) for k, val := range res { - assert.Equal(t, val["name"], batchUsers[k]["name"]) - assert.Equal(t, val["user_id"], batchUsers[k]["id"]) + require.Equal(t, val["name"], batchUsers[k]["name"]) + require.Equal(t, val["user_id"], batchUsers[k]["id"]) } _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) _, err = db.Truncate(PostsTable) - assert.NoError(t, err) + require.NoError(t, err) } var rowsToUpdate = []struct { - insert map[string]interface{} - update map[string]interface{} + insert DataStruct + update DataStruct }{ - {map[string]interface{}{"foo": "foo foo foo", "bar": "bar bar bar", "baz": 123}, map[string]interface{}{"foo": "foo changed", "baz": nil}}, + {DataStruct{Foo: "foo foo foo", Bar: "bar bar bar", Baz: &baz}, DataStruct{Foo: "foo changed", Baz: nil}}, } func TestUpdate(t *testing.T) { _, err := db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) for _, obj := range rowsToUpdate { err := db.Table(TestTable).Insert(obj.insert) - assert.NoError(t, err) + require.NoError(t, err) rows, err := db.Table(TestTable).Where("foo", "=", "foo foo foo").Update(obj.update) - assert.NoError(t, err) - assert.GreaterOrEqual(t, rows, int64(1)) + require.NoError(t, err) + require.GreaterOrEqual(t, rows, int64(1)) - res, err := db.Table(TestTable).Select("foo").Where("foo", "=", obj.update["foo"]).Get() - assert.NoError(t, err) - assert.Equal(t, obj.update["foo"], res[0]["foo"]) + res, err := db.Table(TestTable).Select("foo").Where("foo", "=", obj.update.Foo).Get() + require.NoError(t, err) + require.Equal(t, obj.update.Foo, res[0]["foo"]) } _, err = db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) } var rowsToDelete = []struct { - insert map[string]interface{} + insert DataStruct where map[string]interface{} }{ - {map[string]interface{}{"foo": "foo foo foo", "bar": "bar bar bar", "baz": 123}, map[string]interface{}{"bar": 123}}, + {DataStruct{Foo: "foo foo foo", Bar: "bar bar bar", Baz: &baz}, map[string]interface{}{"bar": 123}}, } func TestDelete(t *testing.T) { _, err := db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) for _, obj := range rowsToDelete { err := db.Table(TestTable).Insert(obj.insert) - assert.NoError(t, err) + require.NoError(t, err) rows, err := db.Table(TestTable).Where("baz", "=", obj.where["bar"]).Delete() - assert.NoError(t, err) - assert.GreaterOrEqual(t, rows, int64(1)) + require.NoError(t, err) + require.GreaterOrEqual(t, rows, int64(1)) } } var incrDecr = []struct { - insert map[string]interface{} + insert DataStruct incr uint64 incrRes uint64 decr uint64 decrRes uint64 }{ - {map[string]interface{}{"foo": "foo foo foo", "bar": "bar bar bar", "baz": 1}, 3, 4, 1, 3}, + {DataStruct{Foo: "foo foo foo", Bar: "bar bar bar", Baz: &baz}, 3, 126, 1, 125}, } func TestDB_Increment_Decrement(t *testing.T) { _, err := db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) for _, obj := range incrDecr { err = db.Table(TestTable).Insert(obj.insert) - assert.NoError(t, err) + require.NoError(t, err) _, err = db.Table(TestTable).Increment("baz", obj.incr) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(TestTable).Select("baz").Where("baz", "=", obj.incrRes).Get() - assert.NoError(t, err) - assert.GreaterOrEqual(t, len(res), 1) - assert.Equal(t, res[0]["baz"], int64(obj.incrRes)) + require.NoError(t, err) + require.GreaterOrEqual(t, len(res), 1) + require.Equal(t, res[0]["baz"], int64(obj.incrRes)) _, err = db.Table(TestTable).Decrement("baz", obj.decr) - assert.NoError(t, err) + require.NoError(t, err) res, err = db.Table(TestTable).Select("baz").Where("baz", "=", obj.decrRes).Get() - assert.NoError(t, err) + require.NoError(t, err) - assert.GreaterOrEqual(t, len(res), 1) - assert.Equal(t, res[0]["baz"], int64(obj.decrRes)) + require.GreaterOrEqual(t, len(res), 1) + require.Equal(t, res[0]["baz"], int64(obj.decrRes)) } _, err = db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) +} + +type DataStructID struct { + ID int64 + Foo string + Bar string + Baz int64 } var rowsToReplace = []struct { - insert map[string]interface{} + insert DataStructID conflict string - replace map[string]interface{} + replace DataStructID }{ - {map[string]interface{}{"id": 1, "foo": "foo foo foo", "bar": "bar bar bar", "baz": 123}, "id", map[string]interface{}{"id": 1, "foo": "baz baz baz", "bar": "bar bar bar", "baz": 123}}, + {DataStructID{ID: 1, Foo: "foo foo foo", Bar: "bar bar bar", Baz: 123}, "id", DataStructID{ID: 1, Foo: "baz baz baz", Bar: "bar bar bar", Baz: 123}}, } func TestDB_Replace(t *testing.T) { _, err := db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) for _, obj := range rowsToReplace { _, err := db.Table(TestTable).Replace(obj.insert, obj.conflict) - assert.NoError(t, err) + require.NoError(t, err) rows, err := db.Table(TestTable).Replace(obj.replace, obj.conflict) - assert.NoError(t, err) - assert.GreaterOrEqual(t, rows, int64(1)) + require.NoError(t, err) + require.GreaterOrEqual(t, rows, int64(1)) - res, err := db.Table(TestTable).Select("foo").Where("baz", "=", obj.replace["baz"]).Get() - assert.NoError(t, err) - assert.GreaterOrEqual(t, len(res), 1) - assert.Equal(t, res[0]["foo"], obj.replace["foo"]) + res, err := db.Table(TestTable).Select("foo").Where("baz", "=", obj.replace.Baz).Get() + require.NoError(t, err) + require.GreaterOrEqual(t, len(res), 1) + require.Equal(t, res[0]["foo"], obj.replace.Foo) } _, err = db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) } -var userForUnion = map[string]interface{}{"id": int64(1), "name": "Alex Shmidt", "points": int64(123)} +type User struct { + ID int64 + Name string + Points int64 +} + +var userForUnion = User{ID: int64(1), Name: "Alex Shmidt", Points: int64(123)} func TestDB_Union(t *testing.T) { _, err := db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) - err = db.Table(TestTable).Insert(dataMap) - assert.NoError(t, err) + err = db.Table(TestTable).Insert(data) + require.NoError(t, err) err = db.Table(UsersTable).Insert(userForUnion) - assert.NoError(t, err) + require.NoError(t, err) union := db.Table(TestTable).Select("bar", "baz").Union() - res, err := union.Table(UsersTable).Select("name", "points").Get() - assert.NoError(t, err) - for _, v := range res { - assert.Equal(t, v["baz"], userForUnion["points"]) + dataStruct := DataStruct{} + var dataStructs []DataStruct + err = union.Table(UsersTable).Select("name", "points").EachToStruct(func(rows *sql.Rows) error { + err = union.Next(rows, &dataStruct) + if err != nil { + return err + } + + dataStructs = append(dataStructs, dataStruct) + return nil + }) + require.NoError(t, err) + for _, v := range dataStructs { + require.Equal(t, *v.Baz, userForUnion.Points) } _, err = db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_InTransaction(t *testing.T) { var tests = map[string]struct { - dataMap map[string]interface{} + dataMap DataStruct res interface{} err error }{ "transaction commit ok": { - dataMap: dataMap, + dataMap: data, res: 1, err: nil, }, "transaction commit ok int64": { - dataMap: dataMap, + dataMap: data, res: int64(1), err: nil, }, "transaction commit ok uint64": { - dataMap: dataMap, + dataMap: data, res: uint64(1), err: nil, }, "transaction commit ok map[string]interface{}": { - dataMap: dataMap, + dataMap: data, res: map[string]interface{}{"foo": "foo foo foo", "bar": "bar bar bar", "baz": int64(123)}, err: nil, }, "transaction commit ok []map[string]interface{}": { - dataMap: dataMap, + dataMap: data, res: []map[string]interface{}{ { "foo": "foo foo foo", "bar": "bar bar bar", "baz": int64(123), @@ -390,12 +453,12 @@ func TestDB_InTransaction(t *testing.T) { err: nil, }, "transaction early exit err": { - dataMap: dataMap, + dataMap: data, res: 0, err: errors.New("some error"), }, "transaction rollback": { - dataMap: dataMap, + dataMap: data, res: 0, err: nil, }, @@ -404,11 +467,11 @@ func TestDB_InTransaction(t *testing.T) { for n, tt := range tests { t.Run(n, func(t *testing.T) { _, err := db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) defer func() { _, err = db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) }() err = db.InTransaction(func() (any, error) { @@ -418,9 +481,9 @@ func TestDB_InTransaction(t *testing.T) { }) if tt.err != nil { - assert.Error(t, tt.err, err) + require.Error(t, tt.err, err) } else { - assert.NoError(t, err) + require.NoError(t, err) } }) } @@ -428,221 +491,226 @@ func TestDB_InTransaction(t *testing.T) { func TestDB_HasTable(t *testing.T) { tblExists, err := db.HasTable("public", PostsTable) - assert.NoError(t, err) - assert.True(t, tblExists) + require.NoError(t, err) + require.True(t, tblExists) } func TestDB_HasColumns(t *testing.T) { colsExists, err := db.HasColumns("public", PostsTable, "title", "user_id") - assert.NoError(t, err) - assert.True(t, colsExists) + require.NoError(t, err) + require.True(t, colsExists) } func TestDB_First(t *testing.T) { _, err := db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) - err = db.Table(TestTable).Insert(dataMap) - assert.NoError(t, err) + err = db.Table(TestTable).Insert(data) + require.NoError(t, err) // write concurrent row to order and get the only 1st - err = db.Table(TestTable).Insert(map[string]interface{}{"foo": "foo foo foo 2", "bar": "bar bar bar 2", "baz": int64(1234)}) - assert.NoError(t, err) + bazz := int64(1234) + err = db.Table(TestTable).Insert(DataStruct{ + Foo: "foo foo foo 2", + Bar: "bar bar bar 2", + Baz: &bazz, + }) + require.NoError(t, err) res, err := db.Table(TestTable).Select("baz").OrderBy("baz", "desc").OrderBy("foo", "desc").First() - assert.NoError(t, err) - assert.Equal(t, res["baz"], int64(1234)) + require.NoError(t, err) + require.Equal(t, res["baz"], int64(1234)) _, err = db.Table(TestTable).Select("baz").OrderBy("baz", "desc").OrderBy("fo", "desc").First() - assert.Error(t, err) + require.Error(t, err) _, err = db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_Find(t *testing.T) { _, err := db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) - id, err := db.Table(TestTable).InsertGetId(dataMap) - assert.NoError(t, err) + id, err := db.Table(TestTable).InsertGetId(data) + require.NoError(t, err) res, err := db.Table(TestTable).Find(id) - assert.NoError(t, err) - assert.Equal(t, res["id"], int64(id)) + require.NoError(t, err) + require.Equal(t, res["id"], int64(id)) _, err = db.Truncate(TestTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_WhereExists(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, er := db.Table(UsersTable).Select("name").WhereExists( db.Table(UsersTable).Select("name").Where("points", ">=", int64(12345)), ).First() - assert.NoError(t, er) - assert.Equal(t, TestUserName, res["name"]) + require.NoError(t, er) + require.Equal(t, TestUserName, res["name"]) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_WhereNotExists(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, er := db.Table(UsersTable).Select("name").WhereNotExists( db.Table(UsersTable).Select("name").Where("points", ">=", int64(12345)), ).First() - assert.NoError(t, er) - assert.Equal(t, TestUserName, res["name"]) + require.NoError(t, er) + require.Equal(t, TestUserName, res["name"]) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_Value(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(UsersTable).OrderBy("points", "desc").Value("name") - assert.NoError(t, err) - assert.Equal(t, TestUserName, res) + require.NoError(t, err) + require.Equal(t, TestUserName, res) _, err = db.Table(UsersTable).OrderBy("poin", "desc").Value("name") - assert.Error(t, err) + require.Error(t, err) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_Pluck(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(UsersTable).Pluck("name") - assert.NoError(t, err) + require.NoError(t, err) for k, v := range res { resVal := v.(string) - assert.Equal(t, batchUsers[k]["name"], resVal) + require.Equal(t, batchUsers[k]["name"], resVal) } _, err = db.Table("nonexistent").Pluck("name") - assert.Error(t, err) + require.Error(t, err) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_PluckMap(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(UsersTable).PluckMap("name", "points") - assert.NoError(t, err) + require.NoError(t, err) for k, m := range res { for key, value := range m { keyVal := key.(string) valueVal := value.(int64) - assert.Equal(t, batchUsers[k]["name"], keyVal) - assert.Equal(t, batchUsers[k]["points"], valueVal) + require.Equal(t, batchUsers[k]["name"], keyVal) + require.Equal(t, batchUsers[k]["points"], valueVal) } } _, err = db.Table("nonexistent").PluckMap("name", "points") - assert.Error(t, err) + require.Error(t, err) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_Exists(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) prepared := db.Table(UsersTable).Select("name").Where("points", ">=", int64(12345)) exists, err := prepared.Exists() - assert.NoError(t, err) + require.NoError(t, err) doesntEx, err := prepared.DoesntExists() - assert.NoError(t, err) + require.NoError(t, err) - assert.True(t, exists, "The record must exist at this state of db data") - assert.False(t, doesntEx, "The record must exist at this state of db data") + require.True(t, exists, "The record must exist at this state of db data") + require.False(t, doesntEx, "The record must exist at this state of db data") _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_Count(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) cnt, err := db.Table(UsersTable).Count() - assert.NoError(t, err) + require.NoError(t, err) - assert.Equalf(t, int64(len(batchUsers)), cnt, "want: %d, got: %d", len(batchUsers), cnt) + require.Equalf(t, int64(len(batchUsers)), cnt, "want: %d, got: %d", len(batchUsers), cnt) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_Avg(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) avg, err := db.Table(UsersTable).Avg("points") - assert.NoError(t, err) + require.NoError(t, err) var cntBatch float64 for _, v := range batchUsers { cntBatch += float64(v["points"].(int64)) / float64(len(batchUsers)) } - assert.Equalf(t, cntBatch, avg, "want: %d, got: %d", cntBatch, avg) + require.Equalf(t, cntBatch, avg, "want: %d, got: %d", cntBatch, avg) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_MinMax(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) mn, err := db.Table(UsersTable).Min("points") - assert.NoError(t, err) + require.NoError(t, err) mx, err := db.Table(UsersTable).Max("points") - assert.NoError(t, err) + require.NoError(t, err) var max float64 var min = float64(123456) @@ -656,355 +724,361 @@ func TestDB_MinMax(t *testing.T) { } } - assert.Equalf(t, mn, min, "want: %d, got: %d", mn, min) - assert.Equalf(t, mx, max, "want: %d, got: %d", mx, max) + require.Equalf(t, mn, min, "want: %d, got: %d", mn, min) + require.Equalf(t, mx, max, "want: %d, got: %d", mx, max) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_Sum(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) sum, err := db.Table(UsersTable).Sum("points") - assert.NoError(t, err) + require.NoError(t, err) var cntBatch float64 for _, v := range batchUsers { cntBatch += float64(v["points"].(int64)) } - assert.Equalf(t, cntBatch, sum, "want: %d, got: %d", cntBatch, sum) + require.Equalf(t, cntBatch, sum, "want: %d, got: %d", cntBatch, sum) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_GroupByHaving(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(UsersTable).Select("points").GroupBy("points").Having("points", ">=", 123).Get() - assert.NoError(t, err) - assert.Equal(t, len(res), len(batchUsers)-1) + require.NoError(t, err) + require.Equal(t, len(res), len(batchUsers)-1) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_HavingRaw(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(UsersTable).Select("points").GroupBy("points").HavingRaw("points > 123").AndHavingRaw("points < 12345").OrHavingRaw("points = 0").Get() - assert.NoError(t, err) - assert.Equal(t, len(batchUsers)-3, len(res)) + require.NoError(t, err) + require.Equal(t, len(batchUsers)-3, len(res)) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_AllJoins(t *testing.T) { _, err := db.Truncate(PostsTable) - assert.NoError(t, err) + require.NoError(t, err) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(PostsTable).InsertBatch(batchPosts) - assert.NoError(t, err) + require.NoError(t, err) - res, err := db.Table(UsersTable).Select("name", "post", "user_id").InnerJoin(PostsTable, "users.id", "=", "posts.user_id").Get() - assert.NoError(t, err) + res, err := db.Table(UsersTable).Select("name", "post", "user_id").InnerJoin(PostsTable, UsersTable+".id", + "=", PostsTable+".user_id").Get() + require.NoError(t, err) - assert.Equal(t, len(res), len(batchPosts)-1) + require.Equal(t, len(res), len(batchPosts)-1) - res, err = db.Table(PostsTable).Select("name", "post", "user_id").RightJoin(UsersTable, "posts.user_id", "=", "users.id").Get() - assert.NoError(t, err) + res, err = db.Table(PostsTable).Select("name", "post", "user_id").RightJoin(UsersTable, PostsTable+".user_id", + "=", UsersTable+".id").Get() + require.NoError(t, err) - assert.Equal(t, len(res), len(batchUsers)) + require.Equal(t, len(res), len(batchUsers)) - res, err = db.Table(UsersTable).Select("name", "post", "user_id").FullJoin(PostsTable, "users.id", "=", "posts.user_id").Get() - assert.NoError(t, err) + res, err = db.Table(UsersTable).Select("name", "post", "user_id").FullJoin(PostsTable, UsersTable+".id", + "=", PostsTable+".user_id").Get() + require.NoError(t, err) - assert.Equal(t, len(res), len(batchUsers)+1) + require.Equal(t, len(res), len(batchUsers)+1) - res, err = db.Table(UsersTable).Select("name", "post", "user_id").FullJoin(PostsTable, "users.id", "=", "posts.user_id").Get() - assert.NoError(t, err) + res, err = db.Table(UsersTable).Select("name", "post", "user_id").FullJoin(PostsTable, + UsersTable+".id", "=", PostsTable+".user_id").Get() + require.NoError(t, err) - assert.Equal(t, len(res), len(batchUsers)+1) + require.Equal(t, len(res), len(batchUsers)+1) // note InRandomOrder check - res, err = db.Table(UsersTable).Select("name", "post", "user_id").FullJoin(PostsTable, "users.id", "=", "posts.user_id").InRandomOrder().Get() - assert.NoError(t, err) + res, err = db.Table(UsersTable).Select("name", "post", "user_id").FullJoin( + PostsTable, UsersTable+".id", "=", PostsTable+".user_id").InRandomOrder().Get() + require.NoError(t, err) - assert.Equal(t, len(res), len(batchUsers)+1) + require.Equal(t, len(res), len(batchUsers)+1) _, err = db.Truncate(PostsTable) - assert.NoError(t, err) + require.NoError(t, err) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_OrderByRaw(t *testing.T) { _, err := db.Truncate(PostsTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(PostsTable).InsertBatch(batchPosts) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(PostsTable).Select("title").OrderByRaw("updated_at - created_at DESC").First() - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, batchPosts[2]["title"], res["title"]) + require.Equal(t, batchPosts[2]["title"], res["title"]) _, err = db.Truncate(PostsTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_SelectRaw(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(UsersTable).SelectRaw("SUM(points) as pts").First() - assert.NoError(t, err) + require.NoError(t, err) var sum int64 for _, v := range batchUsers { sum += v["points"].(int64) } - assert.Equal(t, sum, res["pts"]) + require.Equal(t, sum, res["pts"]) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_AndWhereBetween(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(UsersTable).Select("name").WhereBetween("points", 1233, 12345).OrWhereBetween("points", 123456, 67891023).AndWhereNotBetween("points", 12, 23).First() - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, "Darth Vader", res["name"]) + require.Equal(t, "Darth Vader", res["name"]) res, err = db.Table(UsersTable).Select("name").WhereNotBetween("points", 12, 123).AndWhereBetween("points", 1233, 12345).OrWhereNotBetween("points", 12, 23).First() - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, "Alex Shmidt", res["name"]) + require.Equal(t, "Alex Shmidt", res["name"]) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_WhereRaw(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(UsersTable).Select("name").WhereRaw("LENGTH(name) > 15").OrWhereRaw("points > 1234").Get() - assert.NoError(t, err) - assert.Equal(t, len(res), 2) + require.NoError(t, err) + require.Equal(t, len(res), 2) cnt, err := db.Table(UsersTable).WhereRaw("points > 123").AndWhereRaw("points < 12345").Count() - assert.NoError(t, err) - assert.Equal(t, cnt, int64(1)) + require.NoError(t, err) + require.Equal(t, cnt, int64(1)) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_Offset(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(UsersTable).Offset(2).Limit(10).Get() - assert.NoError(t, err) - assert.Equal(t, len(res), 2) + require.NoError(t, err) + require.Equal(t, len(res), 2) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_Rename(t *testing.T) { tbl := "tbl1" tbl2 := "tbl2" _, err := db.DropIfExists(tbl, tbl2) - assert.NoError(t, err) + require.NoError(t, err) _, err = db.Schema(tbl, func(table *Table) error { table.Increments("id") return nil }) - assert.NoError(t, err) + require.NoError(t, err) _, err = db.Rename(tbl, tbl2) - assert.NoError(t, err) + require.NoError(t, err) exists, err := db.HasTable("public", tbl2) - assert.NoError(t, err) - assert.True(t, exists) + require.NoError(t, err) + require.True(t, exists) _, err = db.Drop(tbl2) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_WhereIn(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(UsersTable).Select("name").WhereIn("points", []int64{123, 1234}).OrWhereIn("id", []int64{1, 2}).Get() - assert.NoError(t, err) - assert.Equal(t, len(res), 2) + require.NoError(t, err) + require.Equal(t, len(res), 2) res, err = db.Table(UsersTable).Select("name").WhereIn("points", []int64{123, 1234}).AndWhereIn("id", []int64{1, 2}).Get() - assert.NoError(t, err) - assert.Equal(t, len(res), 2) + require.NoError(t, err) + require.Equal(t, len(res), 2) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_WhereNotIn(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(UsersTable).Select("name").WhereNotIn("points", []int64{123, 1234}).OrWhereNotIn("id", []int64{1, 2}).Get() - assert.NoError(t, err) - assert.Equal(t, len(res), 2) + require.NoError(t, err) + require.Equal(t, len(res), 2) res, err = db.Table(UsersTable).Select("name").WhereNotIn("points", []int64{123, 1234}).AndWhereNotIn("id", []int64{1, 2}).Get() - assert.NoError(t, err) - assert.Equal(t, len(res), 2) + require.NoError(t, err) + require.Equal(t, len(res), 2) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_WhereNull(t *testing.T) { _, err := db.Truncate(PostsTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(PostsTable).InsertBatch(batchPosts) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(PostsTable).Select("title").WhereNull("post").AndWhereNull("user_id").Get() db.Dump() - assert.NoError(t, err) - assert.Equal(t, len(res), 1) + require.NoError(t, err) + require.Equal(t, len(res), 1) res, err = db.Table(PostsTable).Select("title").WhereNull("post").OrWhereNull("user_id").Get() - assert.NoError(t, err) - assert.Equal(t, len(res), 1) + require.NoError(t, err) + require.Equal(t, len(res), 1) _, err = db.Truncate(PostsTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_WhereNotNull(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(UsersTable).Select("name").WhereNotNull("points").AndWhereNotNull("name").Get() - assert.NoError(t, err) - assert.Equal(t, len(res), len(batchUsers)) + require.NoError(t, err) + require.Equal(t, len(res), len(batchUsers)) res, err = db.Table(UsersTable).Select("name").WhereNotNull("points").OrWhereNotNull("name").Get() - assert.NoError(t, err) - assert.Equal(t, len(res), len(batchUsers)) + require.NoError(t, err) + require.Equal(t, len(res), len(batchUsers)) res, err = db.Table(UsersTable).Select("name").Where("id", "=", 1). OrWhere("id", "=", 2).AndWhereNotNull("points").Get() - assert.NoError(t, err) - assert.Equal(t, len(res), 2) + require.NoError(t, err) + require.Equal(t, len(res), 2) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_LockForUpdate(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(UsersTable).Select("name").LockForUpdate().Get() - assert.NoError(t, err) - assert.Equal(t, len(res), len(batchUsers)) + require.NoError(t, err) + require.Equal(t, len(res), len(batchUsers)) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_UnionAll(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.Table(UsersTable).Select("name").UnionAll().Table(UsersTable).Select("name").Get() - assert.NoError(t, err) - assert.Equal(t, len(res), len(batchUsers)*2) + require.NoError(t, err) + require.Equal(t, len(res), len(batchUsers)*2) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_FullOuterJoin(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) - res, err := db.Table(UsersTable).Select("name").FullOuterJoin(PostsTable, "users.id", "=", "posts.user_id").Get() - assert.NoError(t, err) - assert.Equal(t, len(res), len(batchUsers)) + res, err := db.Table(UsersTable).Select("name").FullOuterJoin(PostsTable, UsersTable+".id", "=", + PostsTable+".user_id").Get() + require.NoError(t, err) + require.Equal(t, len(res), len(batchUsers)) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_Chunk(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) var sumOfPoints int64 err = db.Table(UsersTable).Select("name", "points").Chunk(2, func(users []map[string]interface{}) bool { for _, m := range users { @@ -1015,25 +1089,25 @@ func TestDB_Chunk(t *testing.T) { return true }) - assert.NoError(t, err) + require.NoError(t, err) var initialSum int64 for _, mm := range batchUsers { if val, ok := mm["points"]; ok { initialSum += val.(int64) } } - assert.Equal(t, sumOfPoints, initialSum) + require.Equal(t, sumOfPoints, initialSum) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_ChunkFalse(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) var sumOfPoints int64 err = db.Table(UsersTable).Select("name", "points").Chunk(2, func(users []map[string]interface{}) bool { for _, m := range users { @@ -1047,19 +1121,19 @@ func TestDB_ChunkFalse(t *testing.T) { return true }) - assert.NoError(t, err) - assert.Equal(t, sumOfPoints, batchUsers[0]["points"].(int64)) + require.NoError(t, err) + require.Equal(t, sumOfPoints, batchUsers[0]["points"].(int64)) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_ChunkLessThenAmount(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) var sumOfPoints int64 err = db.Table(UsersTable).Select("name", "points").Chunk(int64(len(batchUsers)+1), func(users []map[string]interface{}) bool { @@ -1070,19 +1144,19 @@ func TestDB_ChunkLessThenAmount(t *testing.T) { } return true }) - assert.NoError(t, err) - assert.Greater(t, sumOfPoints, int64(0)) + require.NoError(t, err) + require.Greater(t, sumOfPoints, int64(0)) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_ChunkLessThenZeroErr(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) err = db.Table(UsersTable).InsertBatch(batchUsers) - assert.NoError(t, err) + require.NoError(t, err) var sumOfPoints int64 err = db.Table(UsersTable).Select("name", "points").Chunk(int64(-1), func(users []map[string]interface{}) bool { @@ -1093,67 +1167,117 @@ func TestDB_ChunkLessThenZeroErr(t *testing.T) { } return true }) - assert.Errorf(t, err, "chunk can't be <= 0, your chunk is: -1") + require.Errorf(t, err, "chunk can't be <= 0, your chunk is: -1") _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_ChunkBuilderTableErr(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) // reset prev set up table as we don't want to use Table to produce err db.Builder.table = "" err = db.InsertBatch(batchUsers) - assert.Error(t, err, errTableCallBeforeOp) + require.Error(t, err, errTableCallBeforeOp) _, err = db.Select("foo", "bar", "baz").Get() - assert.Error(t, err, errTableCallBeforeOp) + require.Error(t, err, errTableCallBeforeOp) err = db.Insert(dataMap) - assert.Error(t, err, errTableCallBeforeOp) + require.Error(t, err, errTableCallBeforeOp) _, err = db.InsertGetId(dataMap) - assert.Error(t, err, errTableCallBeforeOp) + require.Error(t, err, errTableCallBeforeOp) _, err = db.Update(dataMap) - assert.Error(t, err, errTableCallBeforeOp) + require.Error(t, err, errTableCallBeforeOp) _, err = db.Delete() - assert.Error(t, err, errTableCallBeforeOp) + require.Error(t, err, errTableCallBeforeOp) _, err = db.Replace(dataMap, "id") - assert.Error(t, err, errTableCallBeforeOp) + require.Error(t, err, errTableCallBeforeOp) _, err = db.Increment("clmn", 123) - assert.Error(t, err, errTableCallBeforeOp) + require.Error(t, err, errTableCallBeforeOp) _, err = db.Exists() - assert.Error(t, err, errTableCallBeforeOp) + require.Error(t, err, errTableCallBeforeOp) - _, err = db.Table("nonexistent").Update(dataMap) - assert.Error(t, err) + _, err = db.Table("nonexistent").Update(data) + require.Error(t, err) _, err = db.Table("nonexistent").Delete() - assert.Error(t, err) + require.Error(t, err) _, err = db.Table("nonexistent").Increment("clmn", 123) - assert.Error(t, err) + require.Error(t, err) - _, err = db.Table("nonexistent").Replace(dataMap, "id") - assert.Error(t, err) + _, err = db.Table("nonexistent").Replace(data, "id") + require.Error(t, err) _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) } func TestDB_FirsNoRecordsErr(t *testing.T) { _, err := db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) _, err = db.Table(TestTable).Select("baz").OrderBy("baz", "desc").OrderBy("foo", "desc").First() - assert.Errorf(t, err, "no records were produced by query: %s") + require.Errorf(t, err, "no records were produced by query: %s") _, err = db.Truncate(UsersTable) - assert.NoError(t, err) + require.NoError(t, err) +} + +func TestDB_ScanStruct(t *testing.T) { + _, err := db.Truncate(TestTable) + require.NoError(t, err) + + err = db.Table(TestTable).Insert(data) + require.NoError(t, err) + + type TestStruct struct { + FooSomeCol string `db:"foo"` + Bar string `db:"bar"` + Baz int64 + Title string + } + + testStruct := TestStruct{} + err = db.Table(TestTable).Select("foo", "bar", "baz").ScanStruct(&testStruct) + require.NoError(t, err) + + require.Equal(t, testStruct.FooSomeCol, data.Foo) + require.Equal(t, testStruct.Bar, data.Bar) + require.Equal(t, testStruct.Baz, *data.Baz) + require.Empty(t, testStruct.Title) + + err = db.Table(TestTable).Insert(data) + require.NoError(t, err) + + var testStructs []TestStruct + err = db.Table(TestTable).Select("foo", "bar", "baz").EachToStruct(func(rows *sql.Rows) error { + err = db.Next(rows, &testStruct) + if err != nil { + return err + } + + testStructs = append(testStructs, testStruct) + return nil + }) + require.NoError(t, err) + + require.Len(t, testStructs, 2) + for _, val := range testStructs { + require.Equal(t, val.FooSomeCol, data.Foo) + require.Equal(t, val.Bar, data.Bar) + require.Equal(t, val.Baz, *data.Baz) + require.Empty(t, val.Title) + } + + _, err = db.Truncate(TestTable) + require.NoError(t, err) } diff --git a/factory.go b/factory.go index 111fb3f..7aba22d 100644 --- a/factory.go +++ b/factory.go @@ -1,12 +1,18 @@ package buildsqlx import ( + "database/sql" + "errors" "fmt" "log" + "reflect" "strconv" "strings" + "github.com/fatih/structs" "github.com/lib/pq" + "golang.org/x/text/cases" + "golang.org/x/text/language" ) const ( @@ -20,29 +26,32 @@ var ( errTransactionModeWithoutTx = fmt.Errorf("sql: there was no *sql.Tx object set properly") ) +type EachToStructFunc func(rows *sql.Rows) error + // Get builds all sql statements chained before and executes query collecting data to the slice +// Deprecated: this method will no longer be used in future releases, because of ScanStruct and EachToStruct replacement func (r *DB) Get() ([]map[string]any, error) { - builder := r.Builder - if builder.table == "" { + bldr := r.Builder + if bldr.table == "" { return nil, errTableCallBeforeOp } query := "" - if len(builder.union) > 0 { // got union - need different logic to glue - for _, uBuilder := range builder.union { + if len(bldr.union) > 0 { // got union - need different logic to glue + for _, uBuilder := range bldr.union { query += uBuilder + " UNION " - if builder.isUnionAll { + if bldr.isUnionAll { query += "ALL " } } - query += builder.buildSelect() + query += bldr.buildSelect() // clean union (all) after ensuring selects are built r.Builder.union = []string{} r.Builder.isUnionAll = false - } else { // std builder - query = builder.buildSelect() + } else { // std bldr + query = bldr.buildSelect() } rows, err := r.Sql().Query(query, prepareValues(r.Builder.whereBindings)...) @@ -65,8 +74,7 @@ func (r *DB) Get() ([]map[string]any, error) { valuePtrs[i] = &values[i] } - err := rows.Scan(valuePtrs...) - + err = rows.Scan(valuePtrs...) if err != nil { return nil, err } @@ -88,6 +96,228 @@ func (r *DB) Get() ([]map[string]any, error) { return res, nil } +// ScanStruct scans query into specific struct +func (r *DB) ScanStruct(src any) error { + if reflect.ValueOf(src).IsNil() { + return fmt.Errorf("cannot decode into nil type %T", src) + } + + sqlBuilder := r.Builder + if sqlBuilder.table == "" { + return errTableCallBeforeOp + } + + sqlBuilder.limit = 1 + query := "" + if len(sqlBuilder.union) > 0 { // got union - need different logic to glue + for _, uBuilder := range sqlBuilder.union { + query += uBuilder + " UNION " + + if sqlBuilder.isUnionAll { + query += "ALL " + } + } + + query += sqlBuilder.buildSelect() + // clean union (all) after ensuring selects are built + r.Builder.union = []string{} + r.Builder.isUnionAll = false + } else { // std builder + query = sqlBuilder.buildSelect() + } + + rows, err := r.Sql().Query(query, prepareValues(r.Builder.whereBindings)...) + if err != nil { + return err + } + + columns, _ := rows.Columns() + count := len(columns) + values := make([]any, count) + valuePtrs := make([]any, count) + + // resource is the actual value that ptr points to. + resource := reflect.ValueOf(src).Elem() + if err = validateFields(resource, src, columns); err != nil { + return err + } + + for rows.Next() { + for i := range columns { + valuePtrs[i] = &values[i] + } + + err = rows.Scan(valuePtrs...) + if err != nil { + return err + } + + for i, col := range columns { + val := values[i] + setResourceValue(resource, src, cases.Title(language.English).String(col), val) + } + + src = resource + } + + return nil +} + +// EachToStruct scans query into specific struct per row with iterative behaviour +func (r *DB) EachToStruct(fn EachToStructFunc) error { + sqlBuilder := r.Builder + if sqlBuilder.table == "" { + return errTableCallBeforeOp + } + + query := "" + if len(sqlBuilder.union) > 0 { // got union - need different logic to glue + for _, uBuilder := range sqlBuilder.union { + query += uBuilder + " UNION " + + if sqlBuilder.isUnionAll { + query += "ALL " + } + } + + query += sqlBuilder.buildSelect() + // clean union (all) after ensuring selects are built + r.Builder.union = []string{} + r.Builder.isUnionAll = false + } else { // std builder + query = sqlBuilder.buildSelect() + } + + rows, err := r.Sql().Query(query, prepareValues(r.Builder.whereBindings)...) + if err != nil { + return err + } + + for { + err = fn(rows) + if errors.Is(err, ErrNoMoreRows) { + return nil + } + + if err != nil { + return err + } + } +} + +// ErrNoMoreRows is returned by Next when there were no more rows +var ErrNoMoreRows = errors.New("sql: no more rows") + +// Next will parse the next row into a struct passed as src parameter. +// Returns ErrNoMoreRows if there are no more row to parse +func (r *DB) Next(rows *sql.Rows, src any) error { + if reflect.ValueOf(src).IsNil() { + return fmt.Errorf("cannot decode into nil type %T", src) + } + + columns, err := rows.Columns() + if err != nil { + return err + } + + count := len(columns) + // resource is the actual value that ptr points to. + resource := reflect.ValueOf(src).Elem() + if err = validateFields(resource, src, columns); err != nil { + return err + } + + values := make([]any, count) + valuePtrs := make([]any, count) + if rows.Next() { + for i := range columns { + valuePtrs[i] = &values[i] + } + + err = rows.Scan(valuePtrs...) + if err != nil { + return err + } + + for i, col := range columns { + val := values[i] + setResourceValue(resource, src, cases.Title(language.English).String(col), val) + } + src = resource + + return nil + } + + return ErrNoMoreRows +} + +func setResourceValue(resource reflect.Value, src any, col string, value any) { + if !resource.FieldByName(col).IsValid() { // try to get field by db: tag + fields := structs.Fields(src) + for i, f := range fields { + tag := f.Tag("db") + if tag == strings.ToLower(col) { + setValue(resource.Field(i), value) + return + } + } + } + + setValue(resource.FieldByName(col), value) +} + +func setValue(field reflect.Value, val any) { + if field.Kind() == reflect.Ptr { + newVal := reflect.New(field.Type().Elem()) + newVal.Elem().Set(reflect.ValueOf(val)) + field.Set(newVal) + + return + } + + switch v := val.(type) { + case string: + field.SetString(v) + case int: + field.SetInt(int64(v)) + case int64: + field.SetInt(v) + case float64: + field.SetFloat(v) + case uint64: + field.SetUint(v) + case nil: + field.SetPointer(nil) + } + + if reflect.TypeOf(val).Kind() == reflect.Ptr { + setValue(field, reflect.ValueOf(val).Elem().Interface()) + } +} + +func validateFields(resource reflect.Value, src any, columns []string) error { + for _, col := range columns { + foundColByTag := false + fieldName := cases.Title(language.English).String(col) + if !resource.FieldByName(fieldName).IsValid() { + fields := structs.Fields(src) + for _, f := range fields { + tag := f.Tag("db") + if tag == col { + foundColByTag = true + break + } + } + + if !foundColByTag { + return fmt.Errorf("field %s not found in struct", fieldName) + } + } + } + + return nil +} + func prepareValues(values []map[string]any) []any { var vls []any for _, v := range values { @@ -192,23 +422,22 @@ func composeOrderBy(orderBy []map[string]string, orderByRaw *string) string { return "" } -// Insert inserts one row with param bindings -func (r *DB) Insert(data map[string]any) error { +// Insert inserts one row with param bindings for struct +func (r *DB) Insert(data any) error { if r.Txn != nil { return r.Txn.Insert(data) } - builder := r.Builder - if builder.table == "" { + bldr := r.Builder + if bldr.table == "" { return errTableCallBeforeOp } - columns, values, bindings := prepareBindings(data) + columns, values, bindings := prepareBindingsForStruct(data) - query := `INSERT INTO "` + builder.table + `" (` + strings.Join(columns, `, `) + `) VALUES(` + strings.Join(bindings, `, `) + `)` + query := `INSERT INTO "` + bldr.table + `" (` + strings.Join(columns, `, `) + `) VALUES(` + strings.Join(bindings, `, `) + `)` _, err := r.Sql().Exec(query, values...) - if err != nil { return err } @@ -216,23 +445,23 @@ func (r *DB) Insert(data map[string]any) error { return nil } -// Insert inserts one row with param bindings -func (r *Txn) Insert(data map[string]any) error { +// Insert inserts one row with param bindings from struct +// in transaction context +func (r *Txn) Insert(data any) error { if r.Tx == nil { return errTransactionModeWithoutTx } - builder := r.Builder - if builder.table == "" { + bldr := r.Builder + if bldr.table == "" { return errTableCallBeforeOp } - columns, values, bindings := prepareBindings(data) + columns, values, bindings := prepareBindingsForStruct(data) - query := `INSERT INTO "` + builder.table + `" (` + strings.Join(columns, `, `) + `) VALUES(` + strings.Join(bindings, `, `) + `)` + query := `INSERT INTO "` + bldr.table + `" (` + strings.Join(columns, `, `) + `) VALUES(` + strings.Join(bindings, `, `) + `)` _, err := r.Tx.Exec(query, values...) - if err != nil { return err } @@ -241,19 +470,19 @@ func (r *Txn) Insert(data map[string]any) error { } // InsertGetId inserts one row with param bindings and returning id -func (r *DB) InsertGetId(data map[string]any) (uint64, error) { +func (r *DB) InsertGetId(data any) (uint64, error) { if r.Txn != nil { return r.Txn.InsertGetId(data) } - builder := r.Builder - if builder.table == "" { + bldr := r.Builder + if bldr.table == "" { return 0, errTableCallBeforeOp } - columns, values, bindings := prepareBindings(data) + columns, values, bindings := prepareBindingsForStruct(data) - query := `INSERT INTO "` + builder.table + `" (` + strings.Join(columns, `, `) + `) VALUES(` + strings.Join(bindings, `, `) + `) RETURNING id` + query := `INSERT INTO "` + bldr.table + `" (` + strings.Join(columns, `, `) + `) VALUES(` + strings.Join(bindings, `, `) + `) RETURNING id` var id uint64 err := r.Sql().QueryRow(query, values...).Scan(&id) @@ -266,19 +495,20 @@ func (r *DB) InsertGetId(data map[string]any) (uint64, error) { } // InsertGetId inserts one row with param bindings and returning id -func (r *Txn) InsertGetId(data map[string]any) (uint64, error) { +// in transaction context +func (r *Txn) InsertGetId(data any) (uint64, error) { if r.Tx == nil { return 0, errTransactionModeWithoutTx } - builder := r.Builder - if builder.table == "" { + bldr := r.Builder + if bldr.table == "" { return 0, errTableCallBeforeOp } - columns, values, bindings := prepareBindings(data) + columns, values, bindings := prepareBindingsForStruct(data) - query := `INSERT INTO "` + builder.table + `" (` + strings.Join(columns, `, `) + `) VALUES(` + strings.Join(bindings, `, `) + `) RETURNING id` + query := `INSERT INTO "` + bldr.table + `" (` + strings.Join(columns, `, `) + `) VALUES(` + strings.Join(bindings, `, `) + `) RETURNING id` var id uint64 err := r.Tx.QueryRow(query, values...).Scan(&id) @@ -290,15 +520,31 @@ func (r *Txn) InsertGetId(data map[string]any) (uint64, error) { return id, nil } +func prepareValuesForStruct(value reflect.Value) []any { + var values []any + switch value.Kind() { + case reflect.String: + values = append(values, value.String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + values = append(values, strconv.FormatInt(value.Int(), 10)) + case reflect.Float32, reflect.Float64: + values = append(values, fmt.Sprintf("%g", value.Float())) + case reflect.Ptr: + if value.IsNil() { + values = append(values, nil) + } else { + values = prepareValuesForStruct(value.Elem()) + } + } + + return values +} + func prepareValue(value any) []any { var values []any switch v := value.(type) { case string: - //if where { // todo: left comments for further exploration, probably incorrect behaviour for pg driver - // values = append(values, "'"+v+"'") - //} else { values = append(values, v) - //} case int: values = append(values, strconv.FormatInt(int64(v), 10)) case float64: @@ -341,6 +587,44 @@ func prepareBindings(data map[string]any) (columns []string, values []any, bindi return } +// prepareBindingsForStruct prepares all bindings for SQL-query +func prepareBindingsForStruct(data any) (columns []string, values []any, bindings []string) { + j := 1 + resource := reflect.ValueOf(data) + t := reflect.TypeOf(data) + for i := 0; i < t.NumField(); i++ { + value := resource.Field(i) + col := getColumn(t.Field(i)) + + if strings.Contains(col, sqlOperatorIs) || strings.Contains(col, sqlOperatorBetween) { + continue + } + + columns = append(columns, col) + pValues := prepareValuesForStruct(value) + if len(pValues) > 0 { + values = append(values, pValues...) + + for range pValues { + bindings = append(bindings, "$"+strconv.FormatInt(int64(j), 10)) + j++ + } + } + } + + return +} + +// getColumn gets column name and value +func getColumn(structField reflect.StructField) string { + col := strings.ToLower(structField.Name) + if structField.Tag.Get("db") != "" { + col = structField.Tag.Get("db") + } + + return col +} + // InsertBatch inserts multiple rows based on transaction func (r *DB) InsertBatch(data []map[string]any) error { builder := r.Builder @@ -422,17 +706,17 @@ func prepareInsertBatch(data []map[string]any) (columns []string, values [][]any // Update builds an UPDATE sql stmt with corresponding where/from clauses if stated // returning affected rows -func (r *DB) Update(data map[string]any) (int64, error) { +func (r *DB) Update(data any) (int64, error) { if r.Txn != nil { return r.Txn.Update(data) } - builder := r.Builder - if builder.table == "" { + bldr := r.Builder + if bldr.table == "" { return 0, errTableCallBeforeOp } - columns, values, bindings := prepareBindings(data) + columns, values, bindings := prepareBindingsForStruct(data) setVal := "" l := len(columns) for k, col := range columns { @@ -460,17 +744,17 @@ func (r *DB) Update(data map[string]any) (int64, error) { // Update builds an UPDATE sql stmt with corresponding where/from clauses if stated // returning affected rows -func (r *Txn) Update(data map[string]any) (int64, error) { +func (r *Txn) Update(data any) (int64, error) { if r.Tx == nil { return 0, errTransactionModeWithoutTx } - builder := r.Builder - if builder.table == "" { + bldr := r.Builder + if bldr.table == "" { return 0, errTableCallBeforeOp } - columns, values, bindings := prepareBindings(data) + columns, values, bindings := prepareBindingsForStruct(data) setVal := "" l := len(columns) for k, col := range columns { @@ -503,8 +787,8 @@ func (r *DB) Delete() (int64, error) { return r.Txn.Delete() } - builder := r.Builder - if builder.table == "" { + bldr := r.Builder + if bldr.table == "" { return 0, errTableCallBeforeOp } @@ -525,8 +809,8 @@ func (r *Txn) Delete() (int64, error) { return 0, errTransactionModeWithoutTx } - builder := r.Builder - if builder.table == "" { + bldr := r.Builder + if bldr.table == "" { return 0, errTableCallBeforeOp } @@ -541,18 +825,18 @@ func (r *Txn) Delete() (int64, error) { } // Replace inserts data if conflicting row hasn't been found, else it will update an existing one -func (r *DB) Replace(data map[string]any, conflict string) (int64, error) { +func (r *DB) Replace(data any, conflict string) (int64, error) { if r.Txn != nil { return r.Txn.Replace(data, conflict) } - builder := r.Builder - if builder.table == "" { + bldr := r.Builder + if bldr.table == "" { return 0, errTableCallBeforeOp } - columns, values, bindings := prepareBindings(data) - query := `INSERT INTO "` + builder.table + `" (` + strings.Join(columns, `, `) + `) VALUES(` + strings.Join(bindings, `, `) + `) ON CONFLICT(` + conflict + `) DO UPDATE SET ` + columns, values, bindings := prepareBindingsForStruct(data) + query := `INSERT INTO "` + bldr.table + `" (` + strings.Join(columns, `, `) + `) VALUES(` + strings.Join(bindings, `, `) + `) ON CONFLICT(` + conflict + `) DO UPDATE SET ` for i, v := range columns { columns[i] = v + " = excluded." + v } @@ -567,18 +851,18 @@ func (r *DB) Replace(data map[string]any, conflict string) (int64, error) { } // Replace inserts data if conflicting row hasn't been found, else it will update an existing one -func (r *Txn) Replace(data map[string]any, conflict string) (int64, error) { +func (r *Txn) Replace(data any, conflict string) (int64, error) { if r.Tx == nil { return 0, errTransactionModeWithoutTx } - builder := r.Builder - if builder.table == "" { + bldr := r.Builder + if bldr.table == "" { return 0, errTableCallBeforeOp } - columns, values, bindings := prepareBindings(data) - query := `INSERT INTO "` + builder.table + `" (` + strings.Join(columns, `, `) + `) VALUES(` + strings.Join(bindings, `, `) + `) ON CONFLICT(` + conflict + `) DO UPDATE SET ` + columns, values, bindings := prepareBindingsForStruct(data) + query := `INSERT INTO "` + bldr.table + `" (` + strings.Join(columns, `, `) + `) VALUES(` + strings.Join(bindings, `, `) + `) ON CONFLICT(` + conflict + `) DO UPDATE SET ` for i, v := range columns { columns[i] = v + " = excluded." + v } diff --git a/go.mod b/go.mod index b422ff5..896b775 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,10 @@ module github.com/arthurkushman/buildsqlx go 1.18 require ( + github.com/fatih/structs v1.1.0 github.com/lib/pq v1.2.0 github.com/stretchr/testify v1.4.0 + golang.org/x/text v0.13.0 ) require ( diff --git a/go.sum b/go.sum index 5d1e726..0bcc447 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= +github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -7,6 +9,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= diff --git a/schema_test.go b/schema_test.go index 31b7c7b..8cb44cd 100644 --- a/schema_test.go +++ b/schema_test.go @@ -4,14 +4,14 @@ import ( "errors" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const TableToCreate = "big_tbl" func TestDB_CreateEmptyTable(t *testing.T) { _, err := db.DropIfExists(TableToCreate) - assert.NoError(t, err) + require.NoError(t, err) _, err = db.Schema(TableToCreate, func(table *Table) error { return nil @@ -20,7 +20,7 @@ func TestDB_CreateEmptyTable(t *testing.T) { func TestDB_CreateTable(t *testing.T) { _, err := db.DropIfExists(TableToCreate) - assert.NoError(t, err) + require.NoError(t, err) _, err = db.Schema(TableToCreate, func(table *Table) error { table.Increments("id") @@ -40,11 +40,11 @@ func TestDB_CreateTable(t *testing.T) { return nil }) - assert.NoError(t, err) + require.NoError(t, err) is, err := db.HasTable("public", TableToCreate) - assert.NoError(t, err) - assert.True(t, is) + require.NoError(t, err) + require.True(t, is) _, err = db.Schema("tbl_to_ref", func(table *Table) error { table.Increments("id") @@ -52,25 +52,25 @@ func TestDB_CreateTable(t *testing.T) { return nil }) - assert.NoError(t, err) + require.NoError(t, err) // test some err returning from fn() _, err = db.Schema(TableToCreate, func(table *Table) error { return errors.New("some err") }) - assert.Error(t, err) + require.Error(t, err) // 1st drop the referencing tbl _, err = db.Drop("tbl_to_ref") - assert.NoError(t, err) + require.NoError(t, err) // then referenced _, err = db.Drop(TableToCreate) - assert.NoError(t, err) + require.NoError(t, err) } func TestTable_BigIncrements(t *testing.T) { _, err := db.DropIfExists(TableToCreate) - assert.NoError(t, err) + require.NoError(t, err) res, err := db.SchemaIfNotExists(TableToCreate, func(table *Table) error { table.BigIncrements("id") @@ -79,14 +79,14 @@ func TestTable_BigIncrements(t *testing.T) { return nil }) - assert.NoError(t, err) + require.NoError(t, err) _, err = res.RowsAffected() - assert.NoError(t, err) + require.NoError(t, err) is, err := db.HasTable("public", TableToCreate) - assert.NoError(t, err) - assert.True(t, is) + require.NoError(t, err) + require.True(t, is) // test add columns _, err = db.Schema(TableToCreate, func(table *Table) error { @@ -97,11 +97,11 @@ func TestTable_BigIncrements(t *testing.T) { return nil }) - assert.NoError(t, err) + require.NoError(t, err) isCol, err := db.HasColumns("public", TableToCreate, "title") - assert.NoError(t, err) - assert.True(t, isCol) + require.NoError(t, err) + require.True(t, isCol) // test modify the column _, err = db.Schema(TableToCreate, func(table *Table) error { @@ -109,7 +109,7 @@ func TestTable_BigIncrements(t *testing.T) { return nil }) - assert.NoError(t, err) + require.NoError(t, err) // test drop the column _, err = db.Schema(TableToCreate, func(table *Table) error { @@ -118,19 +118,19 @@ func TestTable_BigIncrements(t *testing.T) { return nil }) - assert.NoError(t, err) + require.NoError(t, err) isCol, err = db.HasColumns("public", TableToCreate, "title") - assert.NoError(t, err) - assert.False(t, isCol) + require.NoError(t, err) + require.False(t, isCol) _, err = db.Drop(TableToCreate) - assert.NoError(t, err) + require.NoError(t, err) } func TestTable_DateTime(t *testing.T) { _, err := db.DropIfExists(TableToCreate) - assert.NoError(t, err) + require.NoError(t, err) _, err = db.Schema(TableToCreate, func(table *Table) error { table.Increments("id") @@ -143,11 +143,11 @@ func TestTable_DateTime(t *testing.T) { return nil }) - assert.NoError(t, err) + require.NoError(t, err) is, err := db.HasTable("public", TableToCreate) - assert.NoError(t, err) - assert.True(t, is) + require.NoError(t, err) + require.True(t, is) // test modify the column _, err = db.Schema(TableToCreate, func(table *Table) error { @@ -156,12 +156,12 @@ func TestTable_DateTime(t *testing.T) { return nil }) - assert.NoError(t, err) + require.NoError(t, err) isCol, err := db.HasColumns("public", TableToCreate, "options") - assert.NoError(t, err) - assert.True(t, isCol) + require.NoError(t, err) + require.True(t, isCol) _, err = db.Drop(TableToCreate) - assert.NoError(t, err) + require.NoError(t, err) }