perf(python): optimise read_database Databricks queries made using SQLAlchemy connections (also fixes an issue with 'iter_batches' prematurely closing cursor)
alexander-beedie committed Oct 20, 2023
1 parent 28a99f6 commit f699130
Showing 3 changed files with 168 additions and 79 deletions.
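
For orientation, the call pattern this commit optimises looks roughly like the following sketch (the engine URL and table name are illustrative, not from the commit). With a databricks-sql-python-backed SQLAlchemy engine, read_database now drops down to the raw DBAPI cursor for native Arrow fetches, and iter_batches=True no longer closes the cursor after the first batch:

import polars as pl
from sqlalchemy import create_engine

# illustrative URL; any databricks-sql-python-backed engine applies
engine = create_engine("databricks://token:<token>@<host>?http_path=<path>")

# streams Arrow-native batches from Databricks; previously the cursor
# could be closed prematurely when iter_batches=True
for frame in pl.read_database(
    query="SELECT * FROM some_table",  # hypothetical table
    connection=engine,
    iter_batches=True,
    batch_size=10_000,
):
    print(frame.shape)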
8 changes: 4 additions & 4 deletions py-polars/polars/dataframe/frame.py
@@ -3970,8 +3970,8 @@ def filter(
Provide multiple filters using `*args` syntax:
>>> df.filter(
... pl.col("foo") == 1,
... pl.col("ham") == "a",
... pl.col("foo") <= 2,
... ~pl.col("ham").is_in(["b", "c"]),
... )
shape: (1, 3)
┌─────┬─────┬─────┐
@@ -3984,14 +3984,14 @@
Provide multiple filters using `**kwargs` syntax:
>>> df.filter(foo=1, ham="a")
>>> df.filter(foo=2, ham="b")
shape: (1, 3)
┌─────┬─────┬─────┐
│ foo ┆ bar ┆ ham │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ str │
╞═════╪═════╪═════╡
│ 1   ┆ 6   ┆ a   │
│ 2   ┆ 7   ┆ b   │
└─────┴─────┴─────┘
"""
83 changes: 60 additions & 23 deletions py-polars/polars/io/database.py
@@ -31,42 +31,49 @@
Selectable: TypeAlias = Any # type: ignore[no-redef]


class _DriverProperties_(TypedDict):
fetch_all: str
fetch_batches: str | None
exact_batch_size: bool | None
class _ArrowDriverProperties_(TypedDict):
fetch_all: str # name of the method that fetches all arrow data
fetch_batches: str | None # name of the method that fetches arrow data in batches
exact_batch_size: bool | None # whether indicated batch size is respected exactly
repeat_batch_calls: bool # repeat batch calls (if False, batch call is generator)


_ARROW_DRIVER_REGISTRY_: dict[str, _DriverProperties_] = {
_ARROW_DRIVER_REGISTRY_: dict[str, _ArrowDriverProperties_] = {
"adbc_.*": {
"fetch_all": "fetch_arrow_table",
"fetch_batches": None,
"exact_batch_size": None,
"repeat_batch_calls": False,
},
"arrow_odbc_proxy": {
"fetch_all": "fetch_record_batches",
"fetch_batches": "fetch_record_batches",
"exact_batch_size": True,
"repeat_batch_calls": False,
},
"databricks": {
"fetch_all": "fetchall_arrow",
"fetch_batches": "fetchmany_arrow",
"exact_batch_size": True,
"repeat_batch_calls": True,
},
"duckdb": {
"fetch_all": "fetch_arrow_table",
"fetch_batches": "fetch_record_batch",
"exact_batch_size": True,
"repeat_batch_calls": False,
},
"snowflake": {
"fetch_all": "fetch_arrow_all",
"fetch_batches": "fetch_arrow_batches",
"exact_batch_size": False,
"repeat_batch_calls": False,
},
"turbodbc": {
"fetch_all": "fetchallarrow",
"fetch_batches": "fetcharrowbatches",
"exact_batch_size": False,
"repeat_batch_calls": False,
},
}
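
For context, the registry keys are regular expressions matched against the connection's module name, so the single "adbc_.*" entry covers every ADBC driver (e.g. adbc_driver_postgresql). A minimal sketch of the lookup, mirroring the re.match(f"^{driver}$", ...) call in _from_arrow below (the helper name is hypothetical):

import re

def _lookup_driver(driver_name: str) -> _ArrowDriverProperties_ | None:
    # hypothetical helper: anchored regex match against each registry key
    for pattern, props in _ARROW_DRIVER_REGISTRY_.items():
        if re.match(f"^{pattern}$", driver_name):
            return props
    return None

# e.g. _lookup_driver("adbc_driver_postgresql") returns the "adbc_.*" entry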

@@ -121,10 +128,9 @@ def fetch_record_batches(
class ConnectionExecutor:
"""Abstraction for querying databases with user-supplied connection objects."""

# indicate that we acquired a cursor (and are therefore responsible for closing
# it on scope-exit). note that we should never close the underlying connection,
# or a user-supplied cursor.
acquired_cursor: bool = False
# indicate if we can/should close the cursor on scope exit. note that we
# should never close the underlying connection, or a user-supplied cursor.
can_close_cursor: bool = False

def __init__(self, connection: ConnectionOrCursor) -> None:
self.driver_name = (
@@ -144,24 +150,57 @@ def __exit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
# if we created it, close the cursor (NOT the connection)
if self.acquired_cursor:
# if we created it and are finished with it, we can
# close the cursor (but NOT the connection)
if self.can_close_cursor:
self.cursor.close()

def __repr__(self) -> str:
return f"<{type(self).__name__} module={self.driver_name!r}>"

def _arrow_batches(
self,
driver_properties: _ArrowDriverProperties_,
*,
batch_size: int | None,
iter_batches: bool,
) -> Iterable[pa.RecordBatch]:
"""Yield Arrow data in batches, or as a single 'fetchall' batch."""
fetch_batches = driver_properties["fetch_batches"]
if not iter_batches or fetch_batches is None:
fetch_method = driver_properties["fetch_all"]
yield getattr(self.result, fetch_method)()
else:
size = batch_size if driver_properties["exact_batch_size"] else None
repeat_batch_calls = driver_properties["repeat_batch_calls"]
fetchmany_arrow = getattr(self.result, fetch_batches)
if not repeat_batch_calls:
yield from fetchmany_arrow(size)
else:
while True:
arrow = fetchmany_arrow(size)
if not arrow:
break
yield arrow
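
The repeat_batch_calls branch is the Databricks-specific piece: fetchmany_arrow returns one Arrow table per call (an empty table once exhausted) rather than a generator, so it has to be polled in a loop. A self-contained sketch of that shape, using a hypothetical stand-in for the Databricks result object:

import pyarrow as pa

class FakeDatabricksResult:
    # hypothetical stand-in mimicking databricks' fetchmany_arrow, which
    # returns one Arrow table per call and an empty table when drained
    def __init__(self, table: pa.Table, size: int) -> None:
        self._table, self._size, self._pos = table, size, 0

    def fetchmany_arrow(self, size: int | None = None) -> pa.Table:
        chunk = self._table.slice(self._pos, self._size)
        self._pos += self._size
        return chunk

result = FakeDatabricksResult(pa.table({"x": list(range(5))}), size=2)
while len(batch := result.fetchmany_arrow()):
    print(batch.num_rows)  # 2, 2, 1 -> then an empty table ends the loop

An empty pa.Table has len() == 0 and is therefore falsy, which is what the `if not arrow: break` test above relies on.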

def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor:
"""Normalise a connection object such that we have the query executor."""
if self.driver_name == "sqlalchemy" and type(conn).__name__ == "Engine":
# sqlalchemy engine; direct use is deprecated, so prefer the connection
self.acquired_cursor = True
return conn.connect() # type: ignore[union-attr]
self.can_close_cursor = True
if conn.driver == "databricks-sql-python": # type: ignore[union-attr]
# take advantage of the raw connection to get arrow integration
self.driver_name = "databricks"
return conn.raw_connection().cursor() # type: ignore[union-attr]
else:
# sqlalchemy engine; direct use is deprecated, so prefer the connection
return conn.connect() # type: ignore[union-attr]

elif hasattr(conn, "cursor"):
# connection has a dedicated cursor; prefer over direct execute
cursor = cursor() if callable(cursor := conn.cursor) else cursor
self.acquired_cursor = True
self.can_close_cursor = True
return cursor

elif hasattr(conn, "execute"):
# can execute directly (given cursor, sqlalchemy connection, etc)
return conn # type: ignore[return-value]
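
The new Databricks branch relies on SQLAlchemy's raw_connection() to reach the underlying DBAPI connection, whose cursor exposes fetchall_arrow/fetchmany_arrow directly instead of going through SQLAlchemy's row-wise result machinery. The equivalent manual unwrapping, as a sketch (the engine URL is illustrative):

from sqlalchemy import create_engine

engine = create_engine("databricks://token:<token>@<host>?http_path=<path>")

if engine.driver == "databricks-sql-python":
    # raw DBAPI cursor: exposes the Arrow fetch methods registered
    # under the "databricks" entry in _ARROW_DRIVER_REGISTRY_
    cursor = engine.raw_connection().cursor()
else:
    # direct Engine use is deprecated, so prefer a Connection
    conn = engine.connect()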
@@ -206,22 +245,20 @@ def _from_arrow(
try:
for driver, driver_properties in _ARROW_DRIVER_REGISTRY_.items():
if re.match(f"^{driver}$", self.driver_name):
size = batch_size if driver_properties["exact_batch_size"] else None
fetch_batches = driver_properties["fetch_batches"]
self.can_close_cursor = fetch_batches is None or not iter_batches
frames = (
from_arrow(batch, schema_overrides=schema_overrides)
for batch in (
getattr(self.result, fetch_batches)(size)
if (iter_batches and fetch_batches is not None)
else [
getattr(self.result, driver_properties["fetch_all"])()
]
for batch in self._arrow_batches(
driver_properties,
iter_batches=iter_batches,
batch_size=batch_size,
)
)
return frames if iter_batches else next(frames) # type: ignore[arg-type,return-value]
except Exception as err:
# eg: valid turbodbc/snowflake connection, but no arrow support
# available in the underlying driver or this connection
# compiled into the underlying driver (or on this connection)
arrow_not_supported = (
"does not support Apache Arrow",
"Apache Arrow format is not supported",
156 changes: 104 additions & 52 deletions py-polars/tests/unit/io/test_database_read.py
@@ -15,9 +15,12 @@

import polars as pl
from polars.exceptions import UnsuitableSQLError
from polars.io.database import _ARROW_DRIVER_REGISTRY_
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
import pyarrow as pa

from polars.type_aliases import DbReadEngine, SchemaDefinition, SchemaDict


@@ -84,6 +87,77 @@ class ExceptionTestParams(NamedTuple):
kwargs: dict[str, Any] | None = None


class MockConnection:
"""Mock connection class for databases we can't test in CI."""

def __init__(
self,
driver: str,
batch_size: int | None,
test_data: pa.Table,
repeat_batch_calls: bool,
) -> None:
self.__class__.__module__ = driver
self._cursor = MockCursor(
repeat_batch_calls=repeat_batch_calls,
batched=(batch_size is not None),
test_data=test_data,
)

def close(self) -> None: # noqa: D102
pass

def cursor(self) -> Any: # noqa: D102
return self._cursor


class MockCursor:
"""Mock cursor class for databases we can't test in CI."""

def __init__(
self,
batched: bool,
test_data: pa.Table,
repeat_batch_calls: bool,
) -> None:
self.resultset = MockResultSet(test_data, batched, repeat_batch_calls)
self.called: list[str] = []
self.batched = batched
self.n_calls = 1

def __getattr__(self, item: str) -> Any:
if "fetch" in item:
self.called.append(item)
return self.resultset
super().__getattr__(item) # type: ignore[misc]

def close(self) -> Any: # noqa: D102
pass

def execute(self, query: str) -> Any: # noqa: D102
return self


class MockResultSet:
"""Mock resultset class for databases we can't test in CI."""

def __init__(
self, test_data: pa.Table, batched: bool, repeat_batch_calls: bool = False
):
self.test_data = test_data
self.repeat_batched_calls = repeat_batch_calls
self.batched = batched
self.n_calls = 1

def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: D102
if self.repeat_batched_calls:
res = self.test_data[: None if self.n_calls else 0]
self.n_calls -= 1
else:
res = iter((self.test_data,))
return res
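
MockResultSet.__call__ reproduces both batch semantics: with repeat_batch_calls it hands back the full table on the first call and an empty slice on the second (so the polling loop in _arrow_batches terminates), otherwise a one-shot iterator. A quick sanity check of that behaviour, assuming the classes above and pyarrow are importable:

import pyarrow as pa

tbl = pa.table({"x": [1, 2, 3]})
rs = MockResultSet(tbl, batched=True, repeat_batch_calls=True)
assert len(rs()) == 3  # first call: the full table
assert len(rs()) == 0  # second call: empty slice, so the loop exits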


@pytest.mark.write_disk()
@pytest.mark.parametrize(
(
@@ -307,45 +381,9 @@ def test_read_database_parameterisd(tmp_path: Path) -> None:
)


def test_read_database_mocked() -> None:
arr = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow()

class MockConnection:
def __init__(self, driver: str, batch_size: int | None = None) -> None:
self.__class__.__module__ = driver
self._cursor = MockCursor(batched=batch_size is not None)

def close(self) -> None:
pass

def cursor(self) -> Any:
return self._cursor

class MockCursor:
def __init__(self, batched: bool) -> None:
self.called: list[str] = []
self.batched = batched

def __getattr__(self, item: str) -> Any:
if "fetch" in item:
res = (
(lambda *args, **kwargs: (arr for _ in range(1)))
if self.batched
else (lambda *args, **kwargs: arr)
)
self.called.append(item)
return res
super().__getattr__(item) # type: ignore[misc]

def close(self) -> Any:
pass

def execute(self, query: str) -> Any:
return self

# since we don't have access to snowflake/databricks/etc from CI we
# mock them so we can check that we're calling the expected methods
for driver, batch_size, iter_batches, expected_call in (
@pytest.mark.parametrize(
("driver", "batch_size", "iter_batches", "expected_call"),
[
("snowflake", None, False, "fetch_arrow_all"),
("snowflake", 10_000, False, "fetch_arrow_all"),
("snowflake", 10_000, True, "fetch_arrow_batches"),
@@ -358,20 +396,34 @@ def execute(self, query: str) -> Any:
("adbc_driver_postgresql", None, False, "fetch_arrow_table"),
("adbc_driver_postgresql", 75_000, False, "fetch_arrow_table"),
("adbc_driver_postgresql", 75_000, True, "fetch_arrow_table"),
):
mc = MockConnection(driver, batch_size)
res = pl.read_database( # type: ignore[call-overload]
query="SELECT * FROM test_data",
connection=mc,
iter_batches=iter_batches,
batch_size=batch_size,
)
assert expected_call in mc.cursor().called
if iter_batches:
assert isinstance(res, GeneratorType)
res = pl.concat(res)
],
)
def test_read_database_mocked(
driver: str, batch_size: int | None, iter_batches: bool, expected_call: str
) -> None:
# since we don't have access to snowflake/databricks/etc from CI we
# mock them so we can check that we're calling the expected methods
arrow = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow()
mc = MockConnection(
driver,
batch_size,
test_data=arrow,
repeat_batch_calls=_ARROW_DRIVER_REGISTRY_.get(driver, {}).get( # type: ignore[call-overload]
"repeat_batch_calls", False
),
)
res = pl.read_database( # type: ignore[call-overload]
query="SELECT * FROM test_data",
connection=mc,
iter_batches=iter_batches,
batch_size=batch_size,
)
if iter_batches:
assert isinstance(res, GeneratorType)
res = pl.concat(res)

assert res.rows() == [(1, "aa"), (2, "bb"), (3, "cc")]
assert expected_call in mc.cursor().called
assert res.rows() == [(1, "aa"), (2, "bb"), (3, "cc")]
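
To run just this parametrised test locally (file path as in the listing above):

pytest py-polars/tests/unit/io/test_database_read.py -k test_read_database_mocked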


@pytest.mark.parametrize(
