Skip to content

Commit

Permalink
Pass context
Browse files Browse the repository at this point in the history
  • Loading branch information
kaklakariada committed Jun 28, 2024
1 parent ad6bdad commit af423b3
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 56 deletions.
8 changes: 7 additions & 1 deletion pkg/connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ func (c *Connection) createStatement(ctx context.Context, result *types.CreatePr
return NewStatement(ctx, c, result)
}

func (c *Connection) Ping(ctx context.Context) error {
fmt.Printf("Ping\n")
// FIXME
return nil
}

func (c *Connection) Prepare(query string) (driver.Stmt, error) {
return c.PrepareContext(c.Ctx, query)
}
Expand All @@ -102,7 +108,7 @@ func (c *Connection) Begin() (driver.Tx, error) {
if c.Config.Autocommit {
return nil, errors.ErrAutocommitEnabled
}
return NewTransaction(c), nil
return NewTransaction(c.Ctx, c), nil
}

func (c *Connection) query(ctx context.Context, query string, args []driver.Value) (driver.Rows, error) {
Expand Down
56 changes: 29 additions & 27 deletions pkg/connection/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ func mockExceptionError(exception types.Exception) string {
type ConnectionTestSuite struct {
suite.Suite
websocketMock *wsconn.WebsocketConnectionMock
ctx context.Context
}

func TestConnectionSuite(t *testing.T) {
Expand All @@ -30,20 +31,21 @@ func TestConnectionSuite(t *testing.T) {

func (suite *ConnectionTestSuite) SetupTest() {
suite.websocketMock = wsconn.CreateWebsocketConnectionMock()
suite.ctx = context.Background()
}

func (suite *ConnectionTestSuite) TestConnectFails() {
conn := &Connection{
Config: &config.Config{Host: "invalid", Port: 12345},
Ctx: context.Background(),
Ctx: suite.ctx,
IsClosed: true,
}
err := conn.Connect()
suite.ErrorContains(err, `failed to connect to URL "ws://invalid:12345": dial tcp`)
}

func (suite *ConnectionTestSuite) TestQueryContextNamedParametersNotSupported() {
rows, err := suite.createOpenConnection().QueryContext(context.Background(), "query", []driver.NamedValue{{Name: "arg", Ordinal: 1, Value: "value"}})
rows, err := suite.createOpenConnection().QueryContext(suite.ctx, "query", []driver.NamedValue{{Name: "arg", Ordinal: 1, Value: "value"}})
suite.EqualError(err, "E-EGOD-7: named parameters not supported")
suite.Nil(rows)
}
Expand All @@ -66,7 +68,7 @@ func (suite *ConnectionTestSuite) TestQueryContext() {
types.SqlQueryResponseResultSet{ResultType: "resultType", ResultSet: types.SqlQueryResponseResultSetData{}})
suite.websocketMock.SimulateOKResponse(types.ClosePreparedStatementCommand{Command: types.Command{Command: "closePreparedStatement"}, StatementHandle: 0, Attributes: types.Attributes{}}, nil)

rows, err := suite.createOpenConnection().QueryContext(context.Background(), "query", []driver.NamedValue{{Ordinal: 1, Value: "value"}})
rows, err := suite.createOpenConnection().QueryContext(suite.ctx, "query", []driver.NamedValue{{Ordinal: 1, Value: "value"}})
suite.NoError(err)
suite.Equal([]string{}, rows.Columns())
}
Expand Down Expand Up @@ -95,7 +97,7 @@ func (suite *ConnectionTestSuite) TestQuery() {
}

func (suite *ConnectionTestSuite) TestExecContextNamedParametersNotSupported() {
rows, err := suite.createOpenConnection().ExecContext(context.Background(), "query", []driver.NamedValue{{Name: "arg", Ordinal: 1, Value: "value"}})
rows, err := suite.createOpenConnection().ExecContext(suite.ctx, "query", []driver.NamedValue{{Name: "arg", Ordinal: 1, Value: "value"}})
suite.EqualError(err, "E-EGOD-7: named parameters not supported")
suite.Nil(rows)
}
Expand All @@ -118,7 +120,7 @@ func (suite *ConnectionTestSuite) TestExecContext() {
types.SqlQueryResponseRowCount{ResultType: "resultType", RowCount: 42})
suite.websocketMock.SimulateOKResponse(types.ClosePreparedStatementCommand{Command: types.Command{Command: "closePreparedStatement"}, StatementHandle: 0, Attributes: types.Attributes{}}, nil)

rows, err := suite.createOpenConnection().ExecContext(context.Background(), "query", []driver.NamedValue{{Ordinal: 1, Value: "value"}})
rows, err := suite.createOpenConnection().ExecContext(suite.ctx, "query", []driver.NamedValue{{Ordinal: 1, Value: "value"}})
suite.NoError(err)
rowsAffected, err := rows.RowsAffected()
suite.NoError(err)
Expand Down Expand Up @@ -153,7 +155,7 @@ func (suite *ConnectionTestSuite) TestExec() {
func (suite *ConnectionTestSuite) TestPrepareContextFailsClosed() {
conn := suite.createOpenConnection()
conn.IsClosed = true
stmt, err := conn.PrepareContext(context.Background(), "query")
stmt, err := conn.PrepareContext(suite.ctx, "query")
suite.EqualError(err, driver.ErrBadConn.Error())
suite.Nil(stmt)
}
Expand All @@ -166,7 +168,7 @@ func (suite *ConnectionTestSuite) TestPrepareContextPreparedStatementFails() {
Attributes: types.Attributes{},
},
mockException)
stmt, err := suite.createOpenConnection().PrepareContext(context.Background(), "query")
stmt, err := suite.createOpenConnection().PrepareContext(suite.ctx, "query")
suite.EqualError(err, mockExceptionError(mockException))
suite.Nil(stmt)
}
Expand All @@ -180,7 +182,7 @@ func (suite *ConnectionTestSuite) TestPrepareContextSuccess() {
},
types.CreatePreparedStatementResponse{
ParameterData: types.ParameterData{Columns: []types.SqlQueryColumn{{Name: "col", DataType: types.SqlQueryColumnType{Type: "type"}}}}})
stmt, err := suite.createOpenConnection().PrepareContext(context.Background(), "query")
stmt, err := suite.createOpenConnection().PrepareContext(suite.ctx, "query")
suite.NoError(err)
suite.NotNil(stmt)
}
Expand Down Expand Up @@ -248,7 +250,7 @@ func (suite *ConnectionTestSuite) TestBeginFailsWithAutocommitEnabled() {
func (suite *ConnectionTestSuite) TestQueryFailsConnectionClosed() {
conn := suite.createOpenConnection()
conn.IsClosed = true
rows, err := conn.query(context.Background(), "query", nil)
rows, err := conn.query(suite.ctx, "query", nil)
suite.EqualError(err, driver.ErrBadConn.Error())
suite.Nil(rows)
}
Expand All @@ -257,7 +259,7 @@ func (suite *ConnectionTestSuite) TestQueryNoArgs() {
suite.websocketMock.SimulateSQLQueriesResponse(
types.SqlCommand{Command: types.Command{Command: "execute"}, SQLText: "query", Attributes: types.Attributes{}},
types.SqlQueryResponseResultSet{ResultType: "resultType", ResultSet: types.SqlQueryResponseResultSetData{}})
rows, err := suite.createOpenConnection().query(context.Background(), "query", []driver.Value{})
rows, err := suite.createOpenConnection().query(suite.ctx, "query", []driver.Value{})
suite.NoError(err)
suite.NotNil(rows)
}
Expand All @@ -266,7 +268,7 @@ func (suite *ConnectionTestSuite) TestQueryNoArgsFails() {
suite.websocketMock.SimulateErrorResponse(
types.SqlCommand{Command: types.Command{Command: "execute"}, SQLText: "query", Attributes: types.Attributes{}},
mockException)
rows, err := suite.createOpenConnection().query(context.Background(), "query", []driver.Value{})
rows, err := suite.createOpenConnection().query(suite.ctx, "query", []driver.Value{})
suite.EqualError(err, mockExceptionError(mockException))
suite.Nil(rows)
}
Expand All @@ -289,7 +291,7 @@ func (suite *ConnectionTestSuite) TestQueryWithArgs() {
types.SqlQueryResponseResultSet{ResultType: "resultType", ResultSet: types.SqlQueryResponseResultSetData{}})
suite.websocketMock.SimulateOKResponse(types.ClosePreparedStatementCommand{Command: types.Command{Command: "closePreparedStatement"}, StatementHandle: 0, Attributes: types.Attributes{}}, nil)

rows, err := suite.createOpenConnection().query(context.Background(), "query", []driver.Value{"value"})
rows, err := suite.createOpenConnection().query(suite.ctx, "query", []driver.Value{"value"})
suite.NoError(err)
suite.NotNil(rows)
}
Expand All @@ -303,7 +305,7 @@ func (suite *ConnectionTestSuite) TestQueryWithArgsFailsInPrepare() {
},
mockException)

rows, err := suite.createOpenConnection().query(context.Background(), "query", []driver.Value{"value"})
rows, err := suite.createOpenConnection().query(suite.ctx, "query", []driver.Value{"value"})
suite.EqualError(err, mockExceptionError(mockException))
suite.Nil(rows)
}
Expand All @@ -325,7 +327,7 @@ func (suite *ConnectionTestSuite) TestQueryWithArgsFailsInExecute() {
},
mockException)

rows, err := suite.createOpenConnection().query(context.Background(), "query", []driver.Value{"value"})
rows, err := suite.createOpenConnection().query(suite.ctx, "query", []driver.Value{"value"})
suite.EqualError(err, mockExceptionError(mockException))
suite.Nil(rows)
}
Expand All @@ -334,15 +336,15 @@ func (suite *ConnectionTestSuite) TestPasswordLoginFailsInitialRequest() {
suite.websocketMock.SimulateErrorResponse(types.LoginCommand{Command: types.Command{Command: "login"}, ProtocolVersion: 42},
mockException)
conn := suite.createOpenConnection()
err := conn.Login(context.Background())
err := conn.Login(suite.ctx)
suite.EqualError(err, mockExceptionError(mockException))
}

func (suite *ConnectionTestSuite) TestPasswordLoginFailsEncryptingPasswordRequest() {
suite.websocketMock.SimulateOKResponse(types.LoginCommand{Command: types.Command{Command: "login"}, ProtocolVersion: 42},
types.PublicKeyResponse{PublicKeyPem: "", PublicKeyModulus: "", PublicKeyExponent: ""})
conn := suite.createOpenConnection()
err := conn.Login(context.Background())
err := conn.Login(suite.ctx)
suite.EqualError(err, driver.ErrBadConn.Error())
}

Expand All @@ -352,7 +354,7 @@ func (suite *ConnectionTestSuite) TestPasswordLoginSuccess() {
conn.IsClosed = true

suite.True(conn.IsClosed)
err := conn.Login(context.Background())
err := conn.Login(suite.ctx)
suite.False(conn.IsClosed)
suite.NoError(err)
}
Expand All @@ -364,7 +366,7 @@ func (suite *ConnectionTestSuite) TestAccessTokenLoginSuccess() {
conn.Config.AccessToken = "accessToken"

suite.True(conn.IsClosed)
err := conn.Login(context.Background())
err := conn.Login(suite.ctx)
suite.False(conn.IsClosed)
suite.NoError(err)
}
Expand All @@ -376,7 +378,7 @@ func (suite *ConnectionTestSuite) TestAccessTokenLoginPrepareFails() {
conn.Config.AccessToken = "accessToken"

suite.True(conn.IsClosed)
err := conn.Login(context.Background())
err := conn.Login(suite.ctx)
suite.True(conn.IsClosed)
suite.EqualError(err, "access token login failed: E-EGOD-11: execution failed with SQL error code 'mock sql code' and message 'mock error'")
}
Expand All @@ -388,7 +390,7 @@ func (suite *ConnectionTestSuite) TestRefreshTokenLoginSuccess() {
conn.Config.RefreshToken = "refreshToken"

suite.True(conn.IsClosed)
err := conn.Login(context.Background())
err := conn.Login(suite.ctx)
suite.False(conn.IsClosed)
suite.NoError(err)
}
Expand All @@ -400,7 +402,7 @@ func (suite *ConnectionTestSuite) TestRefreshTokenLoginPrepareFails() {
conn.Config.RefreshToken = "refreshToken"

suite.True(conn.IsClosed)
err := conn.Login(context.Background())
err := conn.Login(suite.ctx)
suite.True(conn.IsClosed)
suite.EqualError(err, "refresh token login failed: E-EGOD-11: execution failed with SQL error code 'mock sql code' and message 'mock error'")
}
Expand All @@ -410,7 +412,7 @@ func (suite *ConnectionTestSuite) TestLoginRestoresCompressionToTrue() {
conn := suite.createOpenConnection()
conn.Config.Compression = true

err := conn.Login(context.Background())
err := conn.Login(suite.ctx)
suite.True(conn.Config.Compression)
suite.NoError(err)
}
Expand All @@ -419,7 +421,7 @@ func (suite *ConnectionTestSuite) TestLoginRestoresCompressionToFalse() {
conn := suite.createOpenConnection()
conn.Config.Compression = false

err := conn.Login(context.Background())
err := conn.Login(suite.ctx)
suite.False(conn.Config.Compression)
suite.NoError(err)
}
Expand All @@ -429,7 +431,7 @@ func (suite *ConnectionTestSuite) TestLoginFails() {
conn := suite.createOpenConnection()
conn.IsClosed = false

err := conn.Login(context.Background())
err := conn.Login(suite.ctx)
suite.True(conn.IsClosed)
suite.EqualError(err, "failed to login: E-EGOD-11: execution failed with SQL error code 'mock sql code' and message 'mock error'")
}
Expand All @@ -439,7 +441,7 @@ func (suite *ConnectionTestSuite) TestLoginFailureRestoresCompressionToTrue() {
conn := suite.createOpenConnection()
conn.Config.Compression = true

conn.Login(context.Background())
conn.Login(suite.ctx)
suite.True(conn.Config.Compression)
}

Expand All @@ -448,7 +450,7 @@ func (suite *ConnectionTestSuite) TestLoginFailureRestoresCompressionToFalse() {
conn := suite.createOpenConnection()
conn.Config.Compression = false

conn.Login(context.Background())
conn.Login(suite.ctx)
suite.False(conn.Config.Compression)
}

Expand Down Expand Up @@ -494,7 +496,7 @@ uYIhswioGpmyPXr/wqz1NFkt5wMzm6sU3lFfCjD5SxU6arQ1zVY3AgMBAAE=
func (suite *ConnectionTestSuite) createOpenConnection() *Connection {
conn := &Connection{
Config: &config.Config{Host: "invalid", Port: 12345, User: "user", Password: "password", ApiVersion: 42},
Ctx: context.Background(),
Ctx: suite.ctx,
IsClosed: false,
websocket: suite.websocketMock,
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/connection/result_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,13 +313,14 @@ func (suite *ResultSetTestSuite) TestConvertValue() {
}

func (suite *ResultSetTestSuite) createResultSet() QueryResults {
ctx := context.Background()
return QueryResults{
ctx: context.Background(),
ctx: ctx,
data: &types.SqlQueryResponseResultSetData{
ResultSetHandle: 1, NumRows: 2, NumRowsInMessage: 2, Columns: []types.SqlQueryColumn{{}, {}},
},
con: &Connection{
websocket: suite.websocketMock, Config: &config.Config{}, Ctx: context.Background(), IsClosed: false,
websocket: suite.websocketMock, Config: &config.Config{}, Ctx: ctx, IsClosed: false,
},
fetchedRows: 0,
totalRowPointer: 0,
Expand Down
9 changes: 5 additions & 4 deletions pkg/connection/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ import (
)

type Transaction struct {
ctx context.Context
connection *Connection
}

func NewTransaction(connection *Connection) *Transaction {
return &Transaction{connection: connection}
func NewTransaction(ctx context.Context, connection *Connection) *Transaction {
return &Transaction{ctx: ctx, connection: connection}
}

func (t *Transaction) Commit() error {
Expand All @@ -24,7 +25,7 @@ func (t *Transaction) Commit() error {
logger.ErrorLogger.Print(errors.ErrClosed)
return driver.ErrBadConn
}
_, err := t.connection.SimpleExec(context.Background(), "COMMIT")
_, err := t.connection.SimpleExec(t.ctx, "COMMIT")
t.connection = nil
return err
}
Expand All @@ -37,7 +38,7 @@ func (t *Transaction) Rollback() error {
logger.ErrorLogger.Print(errors.ErrClosed)
return driver.ErrBadConn
}
_, err := t.connection.SimpleExec(context.Background(), "ROLLBACK")
_, err := t.connection.SimpleExec(t.ctx, "ROLLBACK")
t.connection = nil
return err
}
20 changes: 14 additions & 6 deletions pkg/connection/transaction_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package connection

import (
"context"
"database/sql/driver"
"testing"

Expand All @@ -16,23 +17,30 @@ func TestTransactionSuite(t *testing.T) {
}

func (suite *TransactionTestSuite) TestCommitWithEmptyConnection() {
transaction := Transaction{nil}
transaction := suite.createTransaction()
transaction.connection = nil
suite.EqualError(transaction.Commit(), "E-EGOD-1: invalid connection")
}

func (suite *TransactionTestSuite) TestRollbackWithEmptyConnection() {
transaction := Transaction{nil}
transaction := suite.createTransaction()
transaction.connection = nil
suite.EqualError(transaction.Rollback(), "E-EGOD-1: invalid connection")
}

func (suite *TransactionTestSuite) TestCommitWithClosedConnection() {
connection := Connection{IsClosed: true}
transaction := Transaction{connection: &connection}
transaction := suite.createTransaction()
transaction.connection.IsClosed = true
suite.EqualError(transaction.Commit(), driver.ErrBadConn.Error())
}

func (suite *TransactionTestSuite) TestRollbackWithClosedConnection() {
connection := Connection{IsClosed: true}
transaction := Transaction{connection: &connection}
transaction := suite.createTransaction()
transaction.connection.IsClosed = true
suite.EqualError(transaction.Rollback(), driver.ErrBadConn.Error())
}

func (suite *TransactionTestSuite) createTransaction() Transaction {
connection := Connection{IsClosed: true}
return Transaction{ctx: context.Background(), connection: &connection}
}
Loading

0 comments on commit af423b3

Please sign in to comment.