Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implementation for Begin and Rollback clientside statements #1041

Merged
merged 13 commits into from
Dec 4, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,15 @@ def execute(connection, parsed_statement: ParsedStatement):

It is an internal method that can make backwards-incompatible changes.

:type connection: Connection
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
:param connection: Connection object of the dbApi

:type parsed_statement: ParsedStatement
:param parsed_statement: parsed_statement based on the sql query
"""
if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT:
return connection.commit()
if parsed_statement.client_side_statement_type == ClientSideStatementType.BEGIN:
return connection.begin()
if parsed_statement.client_side_statement_type == ClientSideStatementType.ROLLBACK:
return connection.rollback()
10 changes: 10 additions & 0 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
ClientSideStatementType,
)

RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE)
RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(TRANSACTION)?", re.IGNORECASE)


def parse_stmt(query):
Expand All @@ -39,4 +41,12 @@ def parse_stmt(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT
)
if RE_BEGIN.match(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.BEGIN
)
if RE_ROLLBACK.match(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.ROLLBACK
)
return None
48 changes: 38 additions & 10 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
from google.rpc.code_pb2 import ABORTED


AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode"
TRANSACTION_NOT_BEGUN_WARNING = (
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
"This method is non-operational as transaction has not begun"
)
MAX_INTERNAL_RETRIES = 50


Expand Down Expand Up @@ -104,6 +106,7 @@ def __init__(self, instance, database=None, read_only=False):
self._read_only = read_only
self._staleness = None
self.request_priority = None
self._transaction_begin_marked = False

@property
def autocommit(self):
Expand Down Expand Up @@ -141,14 +144,23 @@ def inside_transaction(self):
"""Flag: transaction is started.

Returns:
bool: True if transaction begun, False otherwise.
bool: True if transaction started, False otherwise.
"""
return (
self._transaction
and not self._transaction.committed
and not self._transaction.rolled_back
)

@property
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
def transaction_begun(self):
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
"""Flag: transaction has begun

Returns:
bool: True if transaction begun, False otherwise.
"""
return (not self._autocommit) or self._transaction_begin_marked
olavloite marked this conversation as resolved.
Show resolved Hide resolved

@property
def instance(self):
"""Instance to which this connection relates.
Expand Down Expand Up @@ -333,12 +345,10 @@ def transaction_checkout(self):
Begin a new transaction, if there is no transaction in
this connection yet. Return the begun one otherwise.

The method is non operational in autocommit mode.
ankiaga marked this conversation as resolved.
Show resolved Hide resolved

:rtype: :class:`google.cloud.spanner_v1.transaction.Transaction`
:returns: A Cloud Spanner transaction object, ready to use.
"""
if not self.autocommit:
if self.transaction_begun:
if not self.inside_transaction:
self._transaction = self._session_checkout().transaction()
self._transaction.begin()
Expand All @@ -354,7 +364,7 @@ def snapshot_checkout(self):
:rtype: :class:`google.cloud.spanner_v1.snapshot.Snapshot`
:returns: A Cloud Spanner snapshot object, ready to use.
"""
if self.read_only and not self.autocommit:
if self.read_only and self.transaction_begun:
if not self._snapshot:
self._snapshot = Snapshot(
self._session_checkout(), multi_use=True, **self.staleness
Expand All @@ -377,6 +387,22 @@ def close(self):

self.is_closed = True

@check_not_closed
def begin(self):
"""
Marks the transaction as started.
olavloite marked this conversation as resolved.
Show resolved Hide resolved

:raises: :class:`InterfaceError`: if this connection is closed.
:raises: :class:`OperationalError`: if there is an existing transaction that has begin or is running
"""
if self._transaction_begin_marked:
raise OperationalError("A transaction has already begun")
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
if self.inside_transaction:
raise OperationalError(
"Beginning a new transaction is not allowed when a transaction is already running"
)
self._transaction_begin_marked = True

def commit(self):
"""Commits any pending transaction to the database.

Expand All @@ -386,8 +412,8 @@ def commit(self):
raise ValueError("Database needs to be passed for this operation")
self._snapshot = None

if self._autocommit:
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
if not self.transaction_begun:
warnings.warn(TRANSACTION_NOT_BEGUN_WARNING, UserWarning, stacklevel=2)
return

self.run_prior_DDL_statements()
Expand All @@ -398,6 +424,7 @@ def commit(self):

self._release_session()
self._statements = []
self._transaction_begin_marked = False
except Aborted:
self.retry_transaction()
self.commit()
Expand All @@ -410,14 +437,15 @@ def rollback(self):
"""
self._snapshot = None

if self._autocommit:
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
if not self.transaction_begun:
warnings.warn(TRANSACTION_NOT_BEGUN_WARNING, UserWarning, stacklevel=2)
elif self._transaction:
if not self.read_only:
self._transaction.rollback()

self._release_session()
self._statements = []
self._transaction_begin_marked = False

@check_not_closed
def cursor(self):
Expand Down
14 changes: 7 additions & 7 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def execute(self, sql, args=None):
)
if parsed_statement.statement_type == StatementType.DDL:
self._batch_DDLs(sql)
if self.connection.autocommit:
if not self.connection.transaction_begun:
self.connection.run_prior_DDL_statements()
return

Expand All @@ -264,7 +264,7 @@ def execute(self, sql, args=None):

sql, args = sql_pyformat_args_to_spanner(sql, args or None)

if not self.connection.autocommit:
if self.connection.transaction_begun:
statement = Statement(
sql,
args,
Expand Down Expand Up @@ -348,7 +348,7 @@ def executemany(self, operation, seq_of_params):
)
statements.append((sql, params, get_param_types(params)))

if self.connection.autocommit:
if not self.connection.transaction_begun:
self.connection.database.run_in_transaction(
self._do_batch_update, statements, many_result_set
)
Expand Down Expand Up @@ -396,7 +396,7 @@ def fetchone(self):
sequence, or None when no more data is available."""
try:
res = next(self)
if not self.connection.autocommit and not self.connection.read_only:
if self.connection.transaction_begun and not self.connection.read_only:
self._checksum.consume_result(res)
return res
except StopIteration:
Expand All @@ -414,7 +414,7 @@ def fetchall(self):
res = []
try:
for row in self:
if not self.connection.autocommit and not self.connection.read_only:
if self.connection.transaction_begun and not self.connection.read_only:
self._checksum.consume_result(row)
res.append(row)
except Aborted:
Expand Down Expand Up @@ -443,7 +443,7 @@ def fetchmany(self, size=None):
for _ in range(size):
try:
res = next(self)
if not self.connection.autocommit and not self.connection.read_only:
if self.connection.transaction_begun and not self.connection.read_only:
self._checksum.consume_result(res)
items.append(res)
except StopIteration:
Expand Down Expand Up @@ -473,7 +473,7 @@ def _handle_DQL(self, sql, params):
if self.connection.database is None:
raise ValueError("Database needs to be passed for this operation")
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
if self.connection.read_only and not self.connection.autocommit:
if self.connection.read_only and self.connection.transaction_begun:
# initiate or use the existing multi-use snapshot
self._handle_DQL_with_snapshot(
self.connection.snapshot_checkout(), sql, params
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_dbapi/parsed_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class StatementType(Enum):
class ClientSideStatementType(Enum):
COMMIT = 1
BEGIN = 2
ROLLBACK = 3


@dataclass
Expand Down
Loading
Loading