fix(python): Address read_database issue with batched reads from Snowflake #17688

Merged · 1 commit · Jul 17, 2024
24 changes: 16 additions & 8 deletions py-polars/polars/io/database/_executor.py
@@ -162,14 +162,14 @@ def _fetch_arrow(
    fetch_method = driver_properties["fetch_all"]
    yield getattr(self.result, fetch_method)()
else:
-   size = batch_size if driver_properties["exact_batch_size"] else None
+   size = [batch_size] if driver_properties["exact_batch_size"] else []
    repeat_batch_calls = driver_properties["repeat_batch_calls"]
    fetchmany_arrow = getattr(self.result, fetch_batches)
    if not repeat_batch_calls:
-       yield from fetchmany_arrow(size)
+       yield from fetchmany_arrow(*size)
    else:
        while True:
-           arrow = fetchmany_arrow(size)
+           arrow = fetchmany_arrow(*size)
            if not arrow:
                break
            yield arrow
@@ -213,6 +213,13 @@ def _from_arrow(
if re.match(f"^{driver}$", self.driver_name):
    if ver := driver_properties["minimum_version"]:
        self._check_module_version(self.driver_name, ver)
+
+   if iter_batches and (
+       driver_properties["exact_batch_size"] and not batch_size
+   ):
+       msg = f"Cannot set `iter_batches` for {self.driver_name} without also setting a non-zero `batch_size`"
+       raise ValueError(msg)  # noqa: TRY301
+
    frames = (
        self._apply_overrides(batch, (schema_overrides or {}))
        if isinstance(batch, DataFrame)
@@ -247,6 +254,12 @@ def _from_rows(
"""Return resultset data row-wise for frame init."""
from polars import DataFrame

if iter_batches and not batch_size:
msg = (
"Cannot set `iter_batches` without also setting a non-zero `batch_size`"
)
raise ValueError(msg)

if is_async := isinstance(original_result := self.result, Coroutine):
self.result = _run_async(self.result)
try:
@@ -506,11 +519,6 @@ def to_polars(
    if self.result is None:
        msg = "Cannot return a frame before executing a query"
        raise RuntimeError(msg)
-   elif iter_batches and not batch_size:
-       msg = (
-           "Cannot set `iter_batches` without also setting a non-zero `batch_size`"
-       )
-       raise ValueError(msg)

    can_close = self.can_close_cursor
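The heart of the fix is the list-splat pattern in `_fetch_arrow`: the batch size is wrapped in a one-element list when the driver accepts an explicit size, and an empty list otherwise, so the same `fetchmany_arrow(*size)` call site degrades to a zero-argument call for backends like Snowflake. A minimal standalone sketch of the idea — the two fake result classes below are illustrative stand-ins, not polars code:

```python
# Standalone sketch of the list-splat pattern from `_fetch_arrow` above.
# The two fake result classes are illustrative stand-ins, not polars code.
from collections.abc import Iterator
from typing import Any


class ExactSizeResult:
    """Backend whose batch-fetch method accepts an explicit size."""

    def fetch_batches(self, size: int) -> Iterator[list[int]]:
        data = list(range(10))
        for i in range(0, len(data), size):
            yield data[i : i + size]


class BackendSizedResult:
    """Backend (like Snowflake) whose batch-fetch method takes no size."""

    def fetch_batches(self) -> Iterator[list[int]]:
        yield list(range(10))  # the backend decides the batch size


def fetch_arrow(result: Any, *, exact_batch_size: bool, batch_size: int | None) -> Iterator[Any]:
    # `size` is either [batch_size] or []; unpacking it with `*` makes the
    # same call site work whether or not the method takes a size argument.
    size = [batch_size] if exact_batch_size else []
    yield from result.fetch_batches(*size)


print(list(fetch_arrow(ExactSizeResult(), exact_batch_size=True, batch_size=4)))
# -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(list(fetch_arrow(BackendSizedResult(), exact_batch_size=False, batch_size=None)))
# -> [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
```

Before this change, passing `size` positionally raised a `TypeError` for backends whose fetch method takes no arguments, which is why batched Snowflake reads failed.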
13 changes: 7 additions & 6 deletions py-polars/polars/io/database/functions.py
@@ -98,16 +98,17 @@ def read_database(
    data returned by the query; this can be useful for processing large resultsets
    in a memory-efficient manner. If supported by the backend, this value is passed
    to the underlying query execution method (note that very low values will
-   typically result in poor performance as it will result in many round-trips to
-   the database as the data is returned). If the backend does not support changing
+   typically result in poor performance as it will cause many round-trips to the
+   database as the data is returned). If the backend does not support changing
    the batch size then a single DataFrame is yielded from the iterator.
batch_size
    Indicate the size of each batch when `iter_batches` is True (note that you can
    still set this when `iter_batches` is False, in which case the resulting
    DataFrame is constructed internally using batched return before being returned
-   to you. Note that some backends may support batched operation but not allow for
-   an explicit size; in this case you will still receive batches, but their exact
-   size will be determined by the backend (so may not equal the value set here).
+   to you. Note that some backends (such as Snowflake) may support batch operation
+   but not allow for an explicit size to be set; in this case you will still
+   receive batches but their size is determined by the backend (in which case any
+   value set here will be ignored).
schema_overrides
    A dictionary mapping column names to dtypes, used to override the schema
    inferred from the query cursor or given by the incoming Arrow data (depending
@@ -242,7 +243,7 @@ def read_database(
    connection = ODBCCursorProxy(connection)
elif "://" in connection:
    # otherwise looks like a mistaken call to read_database_uri
-   msg = "Use of string URI is invalid here; call `read_database_uri` instead"
+   msg = "use of string URI is invalid here; call `read_database_uri` instead"
    raise ValueError(msg)
else:
    msg = "unable to identify string connection as valid ODBC (no driver)"
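For context, this is what the fixed behaviour enables from the user's side. A hedged sketch, assuming a Snowflake connection created with `snowflake.connector.connect` (the account details are placeholders); since Snowflake does not accept an explicit batch size, the yielded frames arrive at whatever size the backend chooses:

```python
import polars as pl
import snowflake.connector  # assumes snowflake-connector-python is installed

# Placeholder credentials -- substitute real account details.
conn = snowflake.connector.connect(
    account="my_account", user="me", password="***", database="my_db"
)

# With this fix, `iter_batches=True` works for Snowflake: the backend does
# not take an explicit size, so each yielded DataFrame has a
# backend-determined height and any `batch_size` hint is ignored.
for frame in pl.read_database(
    query="SELECT * FROM my_table",
    connection=conn,
    iter_batches=True,
):
    print(frame.height)  # process one batch at a time, memory-efficiently
```

Previously the blanket check in `to_polars` rejected `iter_batches=True` without a `batch_size`, even for backends where the size cannot be set at all; the per-path checks above only require `batch_size` where it is actually meaningful.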
30 changes: 25 additions & 5 deletions py-polars/tests/unit/io/database/test_read.py
@@ -46,12 +46,14 @@ def __init__(
    self,
    driver: str,
    batch_size: int | None,
+   exact_batch_size: bool,
    test_data: pa.Table,
    repeat_batch_calls: bool,
) -> None:
    self.__class__.__module__ = driver
    self._cursor = MockCursor(
        repeat_batch_calls=repeat_batch_calls,
+       exact_batch_size=exact_batch_size,
        batched=(batch_size is not None),
        test_data=test_data,
    )
@@ -69,10 +71,17 @@ class MockCursor:
def __init__(
    self,
    batched: bool,
+   exact_batch_size: bool,
    test_data: pa.Table,
    repeat_batch_calls: bool,
) -> None:
-   self.resultset = MockResultSet(test_data, batched, repeat_batch_calls)
+   self.resultset = MockResultSet(
+       test_data=test_data,
+       batched=batched,
+       exact_batch_size=exact_batch_size,
+       repeat_batch_calls=repeat_batch_calls,
+   )
+   self.exact_batch_size = exact_batch_size
    self.called: list[str] = []
    self.batched = batched
    self.n_calls = 1
@@ -94,14 +103,21 @@ class MockResultSet:
"""Mock resultset class for databases we can't test in CI."""

def __init__(
self, test_data: pa.Table, batched: bool, repeat_batch_calls: bool = False
self,
test_data: pa.Table,
batched: bool,
exact_batch_size: bool,
repeat_batch_calls: bool = False,
):
self.test_data = test_data
self.repeat_batched_calls = repeat_batch_calls
self.exact_batch_size = exact_batch_size
self.batched = batched
self.n_calls = 1

def __call__(self, *args: Any, **kwargs: Any) -> Any:
if not self.exact_batch_size:
assert len(args) == 0
if self.repeat_batched_calls:
res = self.test_data[: None if self.n_calls else 0]
self.n_calls -= 1
@@ -478,13 +494,17 @@ def test_read_database_mocked(
    # since we don't have access to snowflake/databricks/etc from CI we
    # mock them so we can check that we're calling the expected methods
    arrow = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow()
+
+   reg = ARROW_DRIVER_REGISTRY.get(driver, {})  # type: ignore[var-annotated]
+   exact_batch_size = reg.get("exact_batch_size", False)
+   repeat_batch_calls = reg.get("repeat_batch_calls", False)
+
    mc = MockConnection(
        driver,
        batch_size,
        test_data=arrow,
-       repeat_batch_calls=ARROW_DRIVER_REGISTRY.get(driver, {}).get(  # type: ignore[call-overload]
-           "repeat_batch_calls", False
-       ),
+       repeat_batch_calls=repeat_batch_calls,
+       exact_batch_size=exact_batch_size,  # type: ignore[arg-type]
    )
    res = pl.read_database(
        query="SELECT * FROM test_data",
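The test now derives `exact_batch_size` and `repeat_batch_calls` from `ARROW_DRIVER_REGISTRY` instead of hard-coding them per call. As a rough sketch of the shape those registry entries take — the key names come from the diff above, while the concrete Snowflake values are assumptions for illustration, not the actual polars registry:

```python
# Sketch of the per-driver properties the executor and tests consult.
# Key names come from the diff; the Snowflake values are assumptions.
from typing import TypedDict


class DriverProperties(TypedDict):
    fetch_all: str             # cursor method returning the whole result as Arrow
    fetch_batches: str         # cursor method returning Arrow batches
    exact_batch_size: bool     # does fetch_batches accept an explicit size?
    repeat_batch_calls: bool   # must fetch_batches be called once per batch?
    minimum_version: str | None


ARROW_DRIVER_REGISTRY_SKETCH: dict[str, DriverProperties] = {
    "snowflake": {
        "fetch_all": "fetch_arrow_all",
        "fetch_batches": "fetch_arrow_batches",
        "exact_batch_size": False,   # Snowflake ignores any requested size
        "repeat_batch_calls": False,
        "minimum_version": None,
    },
}
```

With `exact_batch_size=False`, the updated `MockResultSet.__call__` asserts that the fetch method receives no positional size argument, which is exactly the calling convention the `*size` splat in `_fetch_arrow` guarantees.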