Skip to content

Commit

Permalink
feat(python): Support use of SQLAlchemy "Connectable" in `write_datab…
Browse files Browse the repository at this point in the history
…ase` (pola-rs#17470)
  • Loading branch information
phi-friday authored Jul 8, 2024
1 parent 2b54214 commit e713045
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 5 deletions.
14 changes: 9 additions & 5 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3878,15 +3878,19 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
min_err_prefix="pandas >= 2.2 requires",
)
# note: the catalog (database) should be a part of the connection string
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Connectable, create_engine
from sqlalchemy.orm import Session

sa_object: Connectable
if isinstance(connection, str):
engine_sa = create_engine(connection)
sa_object = create_engine(connection)
elif isinstance(connection, Session):
engine_sa = connection.connection().engine
sa_object = connection.connection()
elif isinstance(connection, Connectable):
sa_object = connection
else:
engine_sa = connection.engine # type: ignore[union-attr]
error_msg = f"unexpected connection type {type(connection)}"
raise TypeError(error_msg)

catalog, db_schema, unpacked_table_name = unpack_table_name(table_name)
if catalog:
Expand All @@ -3900,7 +3904,7 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
).to_sql(
name=unpacked_table_name,
schema=db_schema,
con=engine_sa,
con=sa_object,
if_exists=if_table_exists,
index=False,
**(engine_options or {}),
Expand Down
59 changes: 59 additions & 0 deletions py-polars/tests/unit/io/database/test_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,62 @@ def test_write_database_using_sa_session(tmp_path: str) -> None:
)

assert_frame_equal(result, df)


@pytest.mark.write_disk()
@pytest.mark.parametrize("pass_connection", [True, False])
def test_write_database_sa_rollback(tmp_path: str, pass_connection: bool) -> None:
df = pl.DataFrame(
{
"key": ["xx", "yy", "zz"],
"value": [123, None, 789],
"other": [5.5, 7.0, None],
}
)
table_name = "test_sa_rollback"
test_db_uri = f"sqlite:///{tmp_path}/test_sa_rollback.db"
engine = create_engine(test_db_uri, poolclass=NullPool)
with Session(engine) as session:
if pass_connection:
conn = session.connection()
df.write_database(table_name, conn)
else:
df.write_database(table_name, session)
session.rollback()

with Session(engine) as session:
count = pl.read_database(
query=f"select count(*) from {table_name}", connection=session
).item(0, 0)

assert isinstance(count, int)
assert count == 0


@pytest.mark.write_disk()
@pytest.mark.parametrize("pass_connection", [True, False])
def test_write_database_sa_commit(tmp_path: str, pass_connection: bool) -> None:
df = pl.DataFrame(
{
"key": ["xx", "yy", "zz"],
"value": [123, None, 789],
"other": [5.5, 7.0, None],
}
)
table_name = "test_sa_commit"
test_db_uri = f"sqlite:///{tmp_path}/test_sa_commit.db"
engine = create_engine(test_db_uri, poolclass=NullPool)
with Session(engine) as session:
if pass_connection:
conn = session.connection()
df.write_database(table_name, conn)
else:
df.write_database(table_name, session)
session.commit()

with Session(engine) as session:
result = pl.read_database(
query=f"select * from {table_name}", connection=session
)

assert_frame_equal(result, df)

0 comments on commit e713045

Please sign in to comment.