Skip to content

Commit

Permalink
chore: check if table exists and return proper message
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Oct 14, 2021
1 parent 2ede649 commit 8e6f3be
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
22 changes: 15 additions & 7 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (

_ "github.com/go-sql-driver/mysql"
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -56,7 +55,7 @@ func pg(tb testing.TB) *bun.DB {

sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
tb.Cleanup(func() {
assert.NoError(tb, sqldb.Close())
require.NoError(tb, sqldb.Close())
})

db := bun.NewDB(sqldb, pgdialect.New())
Expand All @@ -78,7 +77,7 @@ func pgx(tb testing.TB) *bun.DB {
sqldb, err := sql.Open("pgx", dsn)
require.NoError(tb, err)
tb.Cleanup(func() {
assert.NoError(tb, sqldb.Close())
require.NoError(tb, sqldb.Close())
})

db := bun.NewDB(sqldb, pgdialect.New())
Expand All @@ -100,7 +99,7 @@ func mysql8(tb testing.TB) *bun.DB {
sqldb, err := sql.Open("mysql", dsn)
require.NoError(tb, err)
tb.Cleanup(func() {
assert.NoError(tb, sqldb.Close())
require.NoError(tb, sqldb.Close())
})

db := bun.NewDB(sqldb, mysqldialect.New())
Expand All @@ -122,7 +121,7 @@ func mysql5(tb testing.TB) *bun.DB {
sqldb, err := sql.Open("mysql", dsn)
require.NoError(tb, err)
tb.Cleanup(func() {
assert.NoError(tb, sqldb.Close())
require.NoError(tb, sqldb.Close())
})

db := bun.NewDB(sqldb, mysqldialect.New())
Expand All @@ -144,7 +143,7 @@ func mariadb(tb testing.TB) *bun.DB {
sqldb, err := sql.Open("mysql", dsn)
require.NoError(tb, err)
tb.Cleanup(func() {
assert.NoError(tb, sqldb.Close())
require.NoError(tb, sqldb.Close())
})

db := bun.NewDB(sqldb, mysqldialect.New())
Expand All @@ -161,7 +160,7 @@ func sqlite(tb testing.TB) *bun.DB {
sqldb, err := sql.Open(sqliteshim.DriverName(), filepath.Join(tb.TempDir(), "sqlite.db"))
require.NoError(tb, err)
tb.Cleanup(func() {
assert.NoError(tb, sqldb.Close())
require.NoError(tb, sqldb.Close())
})

db := bun.NewDB(sqldb, sqlitedialect.New())
Expand Down Expand Up @@ -226,6 +225,7 @@ func TestDB(t *testing.T) {
{testPointers},
{testExists},
{testScanTimeIntoString},
{testModelNonPointer},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -900,3 +900,11 @@ func testScanTimeIntoString(t *testing.T, db *bun.DB) {
require.NoError(t, err)
require.NotZero(t, str)
}

func testModelNonPointer(t *testing.T, db *bun.DB) {
type Model struct{}

_, err := db.NewInsert().Model(Model{}).ExcludeColumn("id").Returning("id").Exec(ctx)
require.Error(t, err)
require.Equal(t, "bun: Model(non-pointer dbtest_test.Model)", err.Error())
}
5 changes: 5 additions & 0 deletions query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,11 @@ func (q *baseQuery) addColumn(column schema.QueryWithArgs) {
}

func (q *baseQuery) excludeColumn(columns []string) {
if q.table == nil {
q.setErr(errNilModel)
return
}

if q.columns == nil {
for _, f := range q.table.Fields {
q.columns = append(q.columns, schema.UnsafeIdent(f.Name))
Expand Down

0 comments on commit 8e6f3be

Please sign in to comment.