Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SqlAlchemy Session.execute() to 2.0 style #32857

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,9 +1222,8 @@ def _create_table_as(
)
else:
# Postgres and SQLite both support the same "CREATE TABLE a AS SELECT ..." syntax
session.execute(
f"CREATE TABLE {target_table_name} AS {source_query.selectable.compile(bind=session.get_bind())}"
)
select_table = source_query.selectable.compile(bind=session.get_bind())
session.execute(text(f"CREATE TABLE {target_table_name} AS {select_table}"))


def _move_dangling_data_to_new_table(
Expand Down
3 changes: 2 additions & 1 deletion tests/utils/test_db_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import pendulum
import pytest
from pytest import param
from sqlalchemy import text
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.declarative import DeclarativeMeta

Expand Down Expand Up @@ -211,7 +212,7 @@ def test__build_query(self, table_name, date_add_kwargs, expected_to_delete, ext
)
stmt = CreateTableAs(target_table_name, query.selectable)
session.execute(stmt)
res = session.execute(f"SELECT COUNT(1) FROM {target_table_name}")
res = session.execute(text(f"SELECT COUNT(1) FROM {target_table_name}"))
for row in res:
assert row[0] == expected_to_delete

Expand Down
13 changes: 7 additions & 6 deletions tests/utils/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import pytest
from kubernetes.client import models as k8s
from pytest import param
from sqlalchemy import text
from sqlalchemy.exc import StatementError

from airflow import settings
Expand Down Expand Up @@ -54,7 +55,7 @@ def setup_method(self):
# make sure NOT to run in UTC. Only postgres supports storing
# timezone information in the datetime field
if session.bind.dialect.name == "postgresql":
session.execute("SET timezone='Europe/Amsterdam'")
session.execute(text("SET timezone='Europe/Amsterdam'"))

self.session = session

Expand Down Expand Up @@ -208,17 +209,17 @@ def test_with_row_locks(

def test_prohibit_commit(self):
with prohibit_commit(self.session) as guard:
self.session.execute("SELECT 1")
self.session.execute(text("SELECT 1"))
with pytest.raises(RuntimeError):
self.session.commit()
self.session.rollback()

self.session.execute("SELECT 1")
self.session.execute(text("SELECT 1"))
guard.commit()

# Check the expected_commit is reset
with pytest.raises(RuntimeError):
self.session.execute("SELECT 1")
self.session.execute(text("SELECT 1"))
self.session.commit()

def test_prohibit_commit_specific_session_only(self):
Expand All @@ -233,12 +234,12 @@ def test_prohibit_commit_specific_session_only(self):
assert other_session is not self.session

with prohibit_commit(self.session):
self.session.execute("SELECT 1")
self.session.execute(text("SELECT 1"))
with pytest.raises(RuntimeError):
self.session.commit()
self.session.rollback()

other_session.execute("SELECT 1")
other_session.execute(text("SELECT 1"))
other_session.commit()

def teardown_method(self):
Expand Down