From 8e6f3be93c1366405496c41d07812aabb163c3b1 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Thu, 14 Oct 2021 14:22:08 +0300 Subject: [PATCH] chore: check if table exists and return proper message --- internal/dbtest/db_test.go | 22 +++++++++++++++------- query_base.go | 5 +++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index cf54c4f21..0e05e625f 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -24,7 +24,6 @@ import ( _ "github.com/go-sql-driver/mysql" _ "github.com/jackc/pgx/v4/stdlib" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -56,7 +55,7 @@ func pg(tb testing.TB) *bun.DB { sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) tb.Cleanup(func() { - assert.NoError(tb, sqldb.Close()) + require.NoError(tb, sqldb.Close()) }) db := bun.NewDB(sqldb, pgdialect.New()) @@ -78,7 +77,7 @@ func pgx(tb testing.TB) *bun.DB { sqldb, err := sql.Open("pgx", dsn) require.NoError(tb, err) tb.Cleanup(func() { - assert.NoError(tb, sqldb.Close()) + require.NoError(tb, sqldb.Close()) }) db := bun.NewDB(sqldb, pgdialect.New()) @@ -100,7 +99,7 @@ func mysql8(tb testing.TB) *bun.DB { sqldb, err := sql.Open("mysql", dsn) require.NoError(tb, err) tb.Cleanup(func() { - assert.NoError(tb, sqldb.Close()) + require.NoError(tb, sqldb.Close()) }) db := bun.NewDB(sqldb, mysqldialect.New()) @@ -122,7 +121,7 @@ func mysql5(tb testing.TB) *bun.DB { sqldb, err := sql.Open("mysql", dsn) require.NoError(tb, err) tb.Cleanup(func() { - assert.NoError(tb, sqldb.Close()) + require.NoError(tb, sqldb.Close()) }) db := bun.NewDB(sqldb, mysqldialect.New()) @@ -144,7 +143,7 @@ func mariadb(tb testing.TB) *bun.DB { sqldb, err := sql.Open("mysql", dsn) require.NoError(tb, err) tb.Cleanup(func() { - assert.NoError(tb, sqldb.Close()) + require.NoError(tb, sqldb.Close()) }) db := bun.NewDB(sqldb, mysqldialect.New()) @@ -161,7 +160,7 @@ func sqlite(tb testing.TB) *bun.DB { sqldb, err := sql.Open(sqliteshim.DriverName(), filepath.Join(tb.TempDir(), "sqlite.db")) require.NoError(tb, err) tb.Cleanup(func() { - assert.NoError(tb, sqldb.Close()) + require.NoError(tb, sqldb.Close()) }) db := bun.NewDB(sqldb, sqlitedialect.New()) @@ -226,6 +225,7 @@ func TestDB(t *testing.T) { {testPointers}, {testExists}, {testScanTimeIntoString}, + {testModelNonPointer}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -900,3 +900,11 @@ func testScanTimeIntoString(t *testing.T, db *bun.DB) { require.NoError(t, err) require.NotZero(t, str) } + +func testModelNonPointer(t *testing.T, db *bun.DB) { + type Model struct{} + + _, err := db.NewInsert().Model(Model{}).ExcludeColumn("id").Returning("id").Exec(ctx) + require.Error(t, err) + require.Equal(t, "bun: Model(non-pointer dbtest_test.Model)", err.Error()) +} diff --git a/query_base.go b/query_base.go index 60226890f..8b11c6cef 100644 --- a/query_base.go +++ b/query_base.go @@ -260,6 +260,11 @@ func (q *baseQuery) addColumn(column schema.QueryWithArgs) { } func (q *baseQuery) excludeColumn(columns []string) { + if q.table == nil { + q.setErr(errNilModel) + return + } + if q.columns == nil { for _, f := range q.table.Fields { q.columns = append(q.columns, schema.UnsafeIdent(f.Name))