Skip to content

Commit

Permalink
Incorporating comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ankiaga committed Nov 29, 2023
1 parent 3e80473 commit 6318ce9
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 53 deletions.
41 changes: 22 additions & 19 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from google.rpc.code_pb2 import ABORTED


TRANSACTION_NOT_BEGUN_WARNING = (
CLIENT_TRANSACTION_NOT_STARTED_WARNING = (
"This method is non-operational as transaction has not begun"
)
MAX_INTERNAL_RETRIES = 50
Expand Down Expand Up @@ -125,7 +125,7 @@ def autocommit(self, value):
:type value: bool
:param value: New autocommit mode state.
"""
if value and not self._autocommit and self.inside_transaction:
if value and not self._autocommit and self.spanner_transaction_started:
self.commit()

self._autocommit = value
Expand All @@ -140,11 +140,14 @@ def database(self):
return self._database

@property
def inside_transaction(self):
"""Flag: transaction is started.
def spanner_transaction_started(self):
"""Flag: whether transaction started at SpanFE. This means that we had
made atleast one call to SpanFE. Property client_transaction_started
would always be true if this is true as transaction has to start first
at clientside than at Spanner (SpanFE)
Returns:
bool: True if transaction started, False otherwise.
bool: True if SpanFE transaction started, False otherwise.
"""
return (
self._transaction
Expand All @@ -153,8 +156,8 @@ def inside_transaction(self):
)

@property
def transaction_begun(self):
"""Flag: transaction has begun
def client_transaction_started(self):
"""Flag: whether transaction started at client side.
Returns:
bool: True if transaction begun, False otherwise.
Expand Down Expand Up @@ -187,7 +190,7 @@ def read_only(self, value):
Args:
value (bool): True for ReadOnly mode, False for ReadWrite.
"""
if self.inside_transaction:
if self.spanner_transaction_started:
raise ValueError(
"Connection read/write mode can't be changed while a transaction is in progress. "
"Commit or rollback the current transaction and try again."
Expand Down Expand Up @@ -225,7 +228,7 @@ def staleness(self, value):
Args:
value (dict): Staleness type and value.
"""
if self.inside_transaction:
if self.spanner_transaction_started:
raise ValueError(
"`staleness` option can't be changed while a transaction is in progress. "
"Commit or rollback the current transaction and try again."
Expand Down Expand Up @@ -348,8 +351,8 @@ def transaction_checkout(self):
:rtype: :class:`google.cloud.spanner_v1.transaction.Transaction`
:returns: A Cloud Spanner transaction object, ready to use.
"""
if self.transaction_begun:
if not self.inside_transaction:
if self.client_transaction_started:
if not self.spanner_transaction_started:
self._transaction = self._session_checkout().transaction()
self._transaction.begin()

Expand All @@ -364,7 +367,7 @@ def snapshot_checkout(self):
:rtype: :class:`google.cloud.spanner_v1.snapshot.Snapshot`
:returns: A Cloud Spanner snapshot object, ready to use.
"""
if self.read_only and self.transaction_begun:
if self.read_only and self.client_transaction_started:
if not self._snapshot:
self._snapshot = Snapshot(
self._session_checkout(), multi_use=True, **self.staleness
Expand All @@ -379,7 +382,7 @@ def close(self):
The connection will be unusable from this point forward. If the
connection has an active transaction, it will be rolled back.
"""
if self.inside_transaction:
if self.spanner_transaction_started:
self._transaction.rollback()

if self._own_pool and self.database:
Expand All @@ -397,7 +400,7 @@ def begin(self):
"""
if self._transaction_begin_marked:
raise OperationalError("A transaction has already begun")
if self.inside_transaction:
if self.spanner_transaction_started:
raise OperationalError(
"Beginning a new transaction is not allowed when a transaction is already running"
)
Expand All @@ -412,12 +415,12 @@ def commit(self):
raise ValueError("Database needs to be passed for this operation")
self._snapshot = None

if not self.transaction_begun:
warnings.warn(TRANSACTION_NOT_BEGUN_WARNING, UserWarning, stacklevel=2)
if not self.client_transaction_started:
warnings.warn(CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2)
return

self.run_prior_DDL_statements()
if self.inside_transaction:
if self.spanner_transaction_started:
try:
if not self.read_only:
self._transaction.commit()
Expand All @@ -437,8 +440,8 @@ def rollback(self):
"""
self._snapshot = None

if not self.transaction_begun:
warnings.warn(TRANSACTION_NOT_BEGUN_WARNING, UserWarning, stacklevel=2)
if not self.client_transaction_started:
warnings.warn(CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2)
elif self._transaction:
if not self.read_only:
self._transaction.rollback()
Expand Down
14 changes: 7 additions & 7 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def execute(self, sql, args=None):
)
if parsed_statement.statement_type == StatementType.DDL:
self._batch_DDLs(sql)
if not self.connection.transaction_begun:
if not self.connection.client_transaction_started:
self.connection.run_prior_DDL_statements()
return

Expand All @@ -264,7 +264,7 @@ def execute(self, sql, args=None):

sql, args = sql_pyformat_args_to_spanner(sql, args or None)

if self.connection.transaction_begun:
if self.connection.client_transaction_started:
statement = Statement(
sql,
args,
Expand Down Expand Up @@ -348,7 +348,7 @@ def executemany(self, operation, seq_of_params):
)
statements.append((sql, params, get_param_types(params)))

if not self.connection.transaction_begun:
if not self.connection.client_transaction_started:
self.connection.database.run_in_transaction(
self._do_batch_update, statements, many_result_set
)
Expand Down Expand Up @@ -396,7 +396,7 @@ def fetchone(self):
sequence, or None when no more data is available."""
try:
res = next(self)
if self.connection.transaction_begun and not self.connection.read_only:
if self.connection.client_transaction_started and not self.connection.read_only:
self._checksum.consume_result(res)
return res
except StopIteration:
Expand All @@ -414,7 +414,7 @@ def fetchall(self):
res = []
try:
for row in self:
if self.connection.transaction_begun and not self.connection.read_only:
if self.connection.client_transaction_started and not self.connection.read_only:
self._checksum.consume_result(row)
res.append(row)
except Aborted:
Expand Down Expand Up @@ -443,7 +443,7 @@ def fetchmany(self, size=None):
for _ in range(size):
try:
res = next(self)
if self.connection.transaction_begun and not self.connection.read_only:
if self.connection.client_transaction_started and not self.connection.read_only:
self._checksum.consume_result(res)
items.append(res)
except StopIteration:
Expand Down Expand Up @@ -473,7 +473,7 @@ def _handle_DQL(self, sql, params):
if self.connection.database is None:
raise ValueError("Database needs to be passed for this operation")
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
if self.connection.read_only and self.connection.transaction_begun:
if self.connection.read_only and self.connection.client_transaction_started:
# initiate or use the existing multi-use snapshot
self._handle_DQL_with_snapshot(
self.connection.snapshot_checkout(), sql, params
Expand Down
13 changes: 6 additions & 7 deletions tests/system/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,26 +142,25 @@ def test_begin_client_side(self, shared_instance, dbapi_database):
cursor2.execute("SELECT * FROM contacts")
conn2.commit()
got_rows = cursor2.fetchall()
cursor2.close()
conn2.close()
assert got_rows != [updated_row]

assert conn1._transaction_begin_marked is True
conn1.commit()
assert conn1._transaction_begin_marked is False
cursor1.close()
conn1.close()

# As the connection conn1 is committed a new connection should see its results
conn3 = Connection(shared_instance, dbapi_database)
cursor3 = conn3.cursor()
cursor3.execute("SELECT * FROM contacts")
conn3.commit()
got_rows = cursor3.fetchall()
assert got_rows == [updated_row]

conn1.close()
conn2.close()
conn3.close()
cursor1.close()
cursor2.close()
cursor3.close()
conn3.close()
assert got_rows == [updated_row]

def test_begin_success_post_commit(self):
"""Test beginning a new transaction post commiting an existing transaction
Expand Down
36 changes: 16 additions & 20 deletions tests/unit/spanner_dbapi/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import unittest
import warnings
import pytest
from google.cloud.spanner_dbapi.exceptions import InterfaceError, OperationalError

PROJECT = "test-project"
INSTANCE = "test-instance"
Expand All @@ -36,6 +37,10 @@ class _CredentialsWithScopes(credentials.Credentials, credentials.Scoped):


class TestConnection(unittest.TestCase):

def setUp(self):
self._under_test = self._make_connection()

def _get_client_info(self):
from google.api_core.gapic_v1.client_info import ClientInfo

Expand Down Expand Up @@ -280,7 +285,7 @@ def test_close(self, mock_client):
@mock.patch.object(warnings, "warn")
def test_commit(self, mock_warn):
from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_dbapi.connection import TRANSACTION_NOT_BEGUN_WARNING
from google.cloud.spanner_dbapi.connection import CLIENT_TRANSACTION_NOT_STARTED_WARNING

connection = Connection(INSTANCE, DATABASE)

Expand All @@ -307,7 +312,7 @@ def test_commit(self, mock_warn):
connection._autocommit = True
connection.commit()
mock_warn.assert_called_once_with(
TRANSACTION_NOT_BEGUN_WARNING, UserWarning, stacklevel=2
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)

def test_commit_database_error(self):
Expand All @@ -321,7 +326,7 @@ def test_commit_database_error(self):
@mock.patch.object(warnings, "warn")
def test_rollback(self, mock_warn):
from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_dbapi.connection import TRANSACTION_NOT_BEGUN_WARNING
from google.cloud.spanner_dbapi.connection import CLIENT_TRANSACTION_NOT_STARTED_WARNING

connection = Connection(INSTANCE, DATABASE)

Expand All @@ -348,7 +353,7 @@ def test_rollback(self, mock_warn):
connection._autocommit = True
connection.rollback()
mock_warn.assert_called_once_with(
TRANSACTION_NOT_BEGUN_WARNING, UserWarning, stacklevel=2
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)

@mock.patch("google.cloud.spanner_v1.database.Database", autospec=True)
Expand Down Expand Up @@ -386,7 +391,6 @@ def test_as_context_manager(self):
self.assertTrue(connection.is_closed)

def test_begin_cursor_closed(self):
from google.cloud.spanner_dbapi.exceptions import InterfaceError

connection = self._make_connection()
connection.close()
Expand All @@ -397,33 +401,25 @@ def test_begin_cursor_closed(self):
self.assertEqual(connection._transaction_begin_marked, False)

def test_begin_transaction_begin_marked(self):
from google.cloud.spanner_dbapi.exceptions import OperationalError

connection = self._make_connection()
connection._transaction_begin_marked = True
self._under_test._transaction_begin_marked = True

with self.assertRaises(OperationalError):
connection.begin()
self._under_test.begin()

def test_begin_inside_transaction(self):
from google.cloud.spanner_dbapi.exceptions import OperationalError

connection = self._make_connection()
mock_transaction = mock.MagicMock()
mock_transaction.committed = mock_transaction.rolled_back = False
connection._transaction = mock_transaction
self._under_test._transaction = mock_transaction

with self.assertRaises(OperationalError):
connection.begin()
self._under_test.begin()

self.assertEqual(connection._transaction_begin_marked, False)
self.assertEqual(self._under_test._transaction_begin_marked, False)

def test_begin(self):
connection = self._make_connection()

connection.begin()
self._under_test.begin()

self.assertEqual(connection._transaction_begin_marked, True)
self.assertEqual(self._under_test._transaction_begin_marked, True)

def test_run_statement_wo_retried(self):
"""Check that Connection remembers executed statements."""
Expand Down

0 comments on commit 6318ce9

Please sign in to comment.