From a32bdc7e80c9b4c22b9b9dcc0f79e4ebf28c03bf Mon Sep 17 00:00:00 2001 From: Kevin Zhang <54437031+kzdev420@users.noreply.github.com> Date: Wed, 27 Nov 2024 01:36:25 +0800 Subject: [PATCH] 22942 fix attr name issue (#3093) --- .../sql_versioning/versioning.py | 51 ++++++++++++------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/python/common/sql-versioning/sql_versioning/versioning.py b/python/common/sql-versioning/sql_versioning/versioning.py index 8a34ff5dd2..3739b3e797 100644 --- a/python/common/sql-versioning/sql_versioning/versioning.py +++ b/python/common/sql-versioning/sql_versioning/versioning.py @@ -93,7 +93,7 @@ def _create_version(session, target, operation_type): return transaction_manager = TransactionManager(session) - transaction_id = transaction_manager.create_transaction() + transaction_id = transaction_manager.get_current_transaction_id() if transaction_id is None: print(f'\033[31mError - Unable to create transaction for {target.__class__.__name__} (id={target.id})\033[0m') @@ -119,15 +119,18 @@ def _create_version(session, target, operation_type): 'operation_type': {'I': 0, 'U': 1, 'D': 2}.get(operation_type, 1) } - for column in inspect(target.__class__).columns: + mapper = inspect(target.__class__) + + for column in mapper.columns: if column.name not in ['transaction_id', 'end_transaction_id', 'operation_type']: - if hasattr(target, column.name): - new_version_data[column.name] = getattr(target, column.name) + property_name = mapper.get_property_by_column(column).key + if hasattr(target, property_name): + new_version_data[column.name] = getattr(target, property_name) if existing_version: # Update the existing version session.execute( - update(VersionClass). + update(VersionClass.__table__). where(and_( VersionClass.id == target.id, VersionClass.transaction_id == transaction_id @@ -136,11 +139,11 @@ def _create_version(session, target, operation_type): ) else: # Insert a new version - session.execute(insert(VersionClass).values(new_version_data)) + session.execute(insert(VersionClass.__table__).values(new_version_data)) # Close any open versions session.execute( - update(VersionClass). + update(VersionClass.__table__). where(and_( VersionClass.id == target.id, VersionClass.end_transaction_id.is_(None), @@ -197,14 +200,14 @@ def __init__(self, session): @debug def create_transaction(self): - """Create a new transaction or reuses the existing one in the session. + """Create a new transaction in the session. - :return: The ID of the created or reused transaction. + :return: The ID of the created transaction. """ if 'current_transaction_id' in self.session.info: - print(f"\033[32mReusing existing transaction: {self.session.info['current_transaction_id']}\033[0m") - return self.session.info['current_transaction_id'] + print(f"\033[32mPoping out existing transaction: {self.session.info['current_transaction_id']}\033[0m") + self.session.info.pop('current_transaction_id', None) # Use insert().returning() to get the ID and issued_at without committing stmt = insert(self.transaction_model).values( @@ -224,7 +227,10 @@ def get_current_transaction_id(self): :return: The current transaction ID in the session. """ - return self.session.info.get('current_transaction_id') + if 'current_transaction_id' in self.session.info: + return self.session.info.get('current_transaction_id') + else: + return self.create_transaction() @debug def clear_current_transaction(self): @@ -232,6 +238,9 @@ def clear_current_transaction(self): :return: None """ + if self.session.transaction.nested: + print(f"\033[32mSkip clearing nested transaction\033[0m") + return print(f"\033[32mClearing current transaction: {self.session.info.get('current_transaction_id')}\033[0m") self.session.info.pop('current_transaction_id', None) @@ -244,8 +253,13 @@ def _before_flush(session, flush_context, instances): if not _is_session_modified(session): print('\033[31mThere is no modified versioned object in this session.\033[0m') return - transaction_manager = TransactionManager(session) - transaction_manager.create_transaction() + + if 'current_transaction_id' in session.info: + print(f"\033[31mtransaction_id={session.info['current_transaction_id']} exists before flush.\033[0m") + else: + print('\033[31mCreating transaction before flush.\033[0m') + transaction_manager = TransactionManager(session) + transaction_manager.create_transaction() except Exception as e: raise e @@ -344,10 +358,13 @@ def _after_configured(cls): if hasattr(cls, '_pending_version_classes'): for pending_cls in cls._pending_version_classes: version_cls = pending_cls._version_cls + mapper = inspect(pending_cls) # Now add columns from the original table - for c in pending_cls.__table__.columns: - if not hasattr(version_cls, c.name): - setattr(version_cls, c.name, Column(c.type)) + for c in mapper.columns: + # Make sure table's column name and class's property name can be different + property_name = mapper.get_property_by_column(c).key + if not hasattr(version_cls, property_name): + setattr(version_cls, property_name, Column(c.name, c.type)) delattr(cls, '_pending_version_classes')