diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py
index 572d9bfdebaa..c17bc1563e65 100644
--- a/py-polars/polars/dataframe/frame.py
+++ b/py-polars/polars/dataframe/frame.py
@@ -3970,8 +3970,8 @@ def filter(
         Provide multiple filters using `*args` syntax:
 
         >>> df.filter(
-        ...     pl.col("foo") == 1,
-        ...     pl.col("ham") == "a",
+        ...     pl.col("foo") <= 2,
+        ...     ~pl.col("ham").is_in(["b", "c"]),
         ... )
         shape: (1, 3)
         ┌─────┬─────┬─────┐
         │ foo ┆ bar ┆ ham │
         │ --- ┆ --- ┆ --- │
         │ i64 ┆ i64 ┆ str │
         ╞═════╪═════╪═════╡
         │ 1   ┆ 6   ┆ a   │
         └─────┴─────┴─────┘
@@ -3984,14 +3984,14 @@
 
         Provide multiple filters using `**kwargs` syntax:
 
-        >>> df.filter(foo=1, ham="a")
+        >>> df.filter(foo=2, ham="b")
         shape: (1, 3)
         ┌─────┬─────┬─────┐
         │ foo ┆ bar ┆ ham │
         │ --- ┆ --- ┆ --- │
         │ i64 ┆ i64 ┆ str │
         ╞═════╪═════╪═════╡
-        │ 1   ┆ 6   ┆ a   │
+        │ 2   ┆ 7   ┆ b   │
         └─────┴─────┴─────┘
 
         """
diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py
index ed3bfa858d5a..1f0795b3df0d 100644
--- a/py-polars/polars/io/database.py
+++ b/py-polars/polars/io/database.py
@@ -31,42 +31,49 @@
     Selectable: TypeAlias = Any  # type: ignore[no-redef]
 
 
-class _DriverProperties_(TypedDict):
-    fetch_all: str
-    fetch_batches: str | None
-    exact_batch_size: bool | None
+class _ArrowDriverProperties_(TypedDict):
+    fetch_all: str  # name of the method that fetches all arrow data
+    fetch_batches: str | None  # name of the method that fetches arrow data in batches
+    exact_batch_size: bool | None  # whether indicated batch size is respected exactly
+    repeat_batch_calls: bool  # repeat batch calls (if False, batch call is generator)
 
 
-_ARROW_DRIVER_REGISTRY_: dict[str, _DriverProperties_] = {
+_ARROW_DRIVER_REGISTRY_: dict[str, _ArrowDriverProperties_] = {
     "adbc_.*": {
         "fetch_all": "fetch_arrow_table",
         "fetch_batches": None,
         "exact_batch_size": None,
+        "repeat_batch_calls": False,
     },
     "arrow_odbc_proxy": {
         "fetch_all": "fetch_record_batches",
         "fetch_batches": "fetch_record_batches",
         "exact_batch_size": True,
+        "repeat_batch_calls": False,
     },
     "databricks": {
         "fetch_all": "fetchall_arrow",
         "fetch_batches": "fetchmany_arrow",
        "exact_batch_size": True,
+        "repeat_batch_calls": True,
     },
     "duckdb": {
         "fetch_all": "fetch_arrow_table",
         "fetch_batches": "fetch_record_batch",
         "exact_batch_size": True,
+        "repeat_batch_calls": False,
     },
     "snowflake": {
         "fetch_all": "fetch_arrow_all",
         "fetch_batches": "fetch_arrow_batches",
         "exact_batch_size": False,
+        "repeat_batch_calls": False,
     },
     "turbodbc": {
         "fetch_all": "fetchallarrow",
         "fetch_batches": "fetcharrowbatches",
         "exact_batch_size": False,
+        "repeat_batch_calls": False,
     },
 }
 
@@ -121,10 +128,9 @@ def fetch_record_batches(
 class ConnectionExecutor:
     """Abstraction for querying databases with user-supplied connection objects."""
 
-    # indicate that we acquired a cursor (and are therefore responsible for closing
-    # it on scope-exit). note that we should never close the underlying connection,
-    # or a user-supplied cursor.
-    acquired_cursor: bool = False
+    # indicate if we can/should close the cursor on scope exit. note that we
+    # should never close the underlying connection, or a user-supplied cursor.
+    can_close_cursor: bool = False
 
     def __init__(self, connection: ConnectionOrCursor) -> None:
         self.driver_name = (
@@ -144,24 +150,57 @@ def __exit__(
         exc_val: BaseException | None,
         exc_tb: TracebackType | None,
     ) -> None:
-        # iif we created it, close the cursor (NOT the connection)
-        if self.acquired_cursor:
+        # if we created it and are finished with it, we can
+        # close the cursor (but NOT the connection)
+        if self.can_close_cursor:
             self.cursor.close()
 
     def __repr__(self) -> str:
         return f"<{type(self).__name__} module={self.driver_name!r}>"
 
+    def _arrow_batches(
+        self,
+        driver_properties: _ArrowDriverProperties_,
+        *,
+        batch_size: int | None,
+        iter_batches: bool,
+    ) -> Iterable[pa.RecordBatch]:
+        """Yield Arrow data in batches, or as a single 'fetchall' batch."""
+        fetch_batches = driver_properties["fetch_batches"]
+        if not iter_batches or fetch_batches is None:
+            fetch_method = driver_properties["fetch_all"]
+            yield getattr(self.result, fetch_method)()
+        else:
+            size = batch_size if driver_properties["exact_batch_size"] else None
+            repeat_batch_calls = driver_properties["repeat_batch_calls"]
+            fetchmany_arrow = getattr(self.result, fetch_batches)
+            if not repeat_batch_calls:
+                yield from fetchmany_arrow(size)
+            else:
+                while True:
+                    arrow = fetchmany_arrow(size)
+                    if not arrow:
+                        break
+                    yield arrow
+
     def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor:
         """Normalise a connection object such that we have the query executor."""
         if self.driver_name == "sqlalchemy" and type(conn).__name__ == "Engine":
-            # sqlalchemy engine; direct use is deprecated, so prefer the connection
-            self.acquired_cursor = True
-            return conn.connect()  # type: ignore[union-attr]
+            self.can_close_cursor = True
+            if conn.driver == "databricks-sql-python":  # type: ignore[union-attr]
+                # take advantage of the raw connection to get arrow integration
+                self.driver_name = "databricks"
+                return conn.raw_connection().cursor()  # type: ignore[union-attr]
+            else:
+                # sqlalchemy engine; direct use is deprecated, so prefer the connection
+                return conn.connect()  # type: ignore[union-attr]
+
         elif hasattr(conn, "cursor"):
             # connection has a dedicated cursor; prefer over direct execute
             cursor = cursor() if callable(cursor := conn.cursor) else cursor
-            self.acquired_cursor = True
+            self.can_close_cursor = True
             return cursor
+
         elif hasattr(conn, "execute"):
             # can execute directly (given cursor, sqlalchemy connection, etc)
             return conn  # type: ignore[return-value]
@@ -206,22 +245,20 @@ def _from_arrow(
         try:
             for driver, driver_properties in _ARROW_DRIVER_REGISTRY_.items():
                 if re.match(f"^{driver}$", self.driver_name):
-                    size = batch_size if driver_properties["exact_batch_size"] else None
                     fetch_batches = driver_properties["fetch_batches"]
+                    self.can_close_cursor = fetch_batches is None or not iter_batches
                     frames = (
                         from_arrow(batch, schema_overrides=schema_overrides)
-                        for batch in (
-                            getattr(self.result, fetch_batches)(size)
-                            if (iter_batches and fetch_batches is not None)
-                            else [
-                                getattr(self.result, driver_properties["fetch_all"])()
-                            ]
+                        for batch in self._arrow_batches(
+                            driver_properties,
+                            iter_batches=iter_batches,
+                            batch_size=batch_size,
                         )
                     )
                     return frames if iter_batches else next(frames)  # type: ignore[arg-type,return-value]
         except Exception as err:
             # eg: valid turbodbc/snowflake connection, but no arrow support
-            # available in the underlying driver or this connection
+            # compiled into the underlying driver (or on this connection)
             arrow_not_supported = (
                 "does not support Apache Arrow",
                 "Apache Arrow format is not supported",
format is not supported", diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index aa872c00a85b..824ffb989fd2 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -15,9 +15,12 @@ import polars as pl from polars.exceptions import UnsuitableSQLError +from polars.io.database import _ARROW_DRIVER_REGISTRY_ from polars.testing import assert_frame_equal if TYPE_CHECKING: + import pyarrow as pa + from polars.type_aliases import DbReadEngine, SchemaDefinition, SchemaDict @@ -84,6 +87,77 @@ class ExceptionTestParams(NamedTuple): kwargs: dict[str, Any] | None = None +class MockConnection: + """Mock connection class for databases we can't test in CI.""" + + def __init__( + self, + driver: str, + batch_size: int | None, + test_data: pa.Table, + repeat_batch_calls: bool, + ) -> None: + self.__class__.__module__ = driver + self._cursor = MockCursor( + repeat_batch_calls=repeat_batch_calls, + batched=(batch_size is not None), + test_data=test_data, + ) + + def close(self) -> None: # noqa: D102 + pass + + def cursor(self) -> Any: # noqa: D102 + return self._cursor + + +class MockCursor: + """Mock cursor class for databases we can't test in CI.""" + + def __init__( + self, + batched: bool, + test_data: pa.Table, + repeat_batch_calls: bool, + ) -> None: + self.resultset = MockResultSet(test_data, batched, repeat_batch_calls) + self.called: list[str] = [] + self.batched = batched + self.n_calls = 1 + + def __getattr__(self, item: str) -> Any: + if "fetch" in item: + self.called.append(item) + return self.resultset + super().__getattr__(item) # type: ignore[misc] + + def close(self) -> Any: # noqa: D102 + pass + + def execute(self, query: str) -> Any: # noqa: D102 + return self + + +class MockResultSet: + """Mock resultset class for databases we can't test in CI.""" + + def __init__( + self, test_data: pa.Table, batched: bool, repeat_batch_calls: bool = False + ): + self.test_data = test_data + self.repeat_batched_calls = repeat_batch_calls + self.batched = batched + self.n_calls = 1 + + def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: D102 + if self.repeat_batched_calls: + res = self.test_data[: None if self.n_calls else 0] + self.n_calls -= 1 + else: + res = iter((self.test_data,)) + return res + + @pytest.mark.write_disk() @pytest.mark.parametrize( ( @@ -307,45 +381,9 @@ def test_read_database_parameterisd(tmp_path: Path) -> None: ) -def test_read_database_mocked() -> None: - arr = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow() - - class MockConnection: - def __init__(self, driver: str, batch_size: int | None = None) -> None: - self.__class__.__module__ = driver - self._cursor = MockCursor(batched=batch_size is not None) - - def close(self) -> None: - pass - - def cursor(self) -> Any: - return self._cursor - - class MockCursor: - def __init__(self, batched: bool) -> None: - self.called: list[str] = [] - self.batched = batched - - def __getattr__(self, item: str) -> Any: - if "fetch" in item: - res = ( - (lambda *args, **kwargs: (arr for _ in range(1))) - if self.batched - else (lambda *args, **kwargs: arr) - ) - self.called.append(item) - return res - super().__getattr__(item) # type: ignore[misc] - - def close(self) -> Any: - pass - - def execute(self, query: str) -> Any: - return self - - # since we don't have access to snowflake/databricks/etc from CI we - # mock them so we can check that we're calling the expected methods - for driver, batch_size, 
+@pytest.mark.parametrize(
+    ("driver", "batch_size", "iter_batches", "expected_call"),
+    [
         ("snowflake", None, False, "fetch_arrow_all"),
         ("snowflake", 10_000, False, "fetch_arrow_all"),
         ("snowflake", 10_000, True, "fetch_arrow_batches"),
@@ -358,20 +396,34 @@ def execute(self, query: str) -> Any:
         ("databricks", None, False, "fetchall_arrow"),
         ("databricks", 25_000, False, "fetchall_arrow"),
         ("databricks", 25_000, True, "fetchmany_arrow"),
         ("turbodbc", None, False, "fetchallarrow"),
         ("turbodbc", 50_000, False, "fetchallarrow"),
         ("turbodbc", 50_000, True, "fetcharrowbatches"),
         ("adbc_driver_postgresql", None, False, "fetch_arrow_table"),
         ("adbc_driver_postgresql", 75_000, False, "fetch_arrow_table"),
         ("adbc_driver_postgresql", 75_000, True, "fetch_arrow_table"),
-    ):
-        mc = MockConnection(driver, batch_size)
-        res = pl.read_database(  # type: ignore[call-overload]
-            query="SELECT * FROM test_data",
-            connection=mc,
-            iter_batches=iter_batches,
-            batch_size=batch_size,
-        )
-        assert expected_call in mc.cursor().called
-        if iter_batches:
-            assert isinstance(res, GeneratorType)
-            res = pl.concat(res)
+    ],
+)
+def test_read_database_mocked(
+    driver: str, batch_size: int | None, iter_batches: bool, expected_call: str
+) -> None:
+    # since we don't have access to snowflake/databricks/etc from CI, we
+    # mock them so we can check that we're calling the expected methods
+    arrow = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow()
+    mc = MockConnection(
+        driver,
+        batch_size,
+        test_data=arrow,
+        repeat_batch_calls=_ARROW_DRIVER_REGISTRY_.get(driver, {}).get(  # type: ignore[call-overload]
+            "repeat_batch_calls", False
+        ),
+    )
+    res = pl.read_database(  # type: ignore[call-overload]
+        query="SELECT * FROM test_data",
+        connection=mc,
+        iter_batches=iter_batches,
+        batch_size=batch_size,
+    )
+    if iter_batches:
+        assert isinstance(res, GeneratorType)
+        res = pl.concat(res)
 
-        assert res.rows() == [(1, "aa"), (2, "bb"), (3, "cc")]
+    assert expected_call in mc.cursor().called
+    assert res.rows() == [(1, "aa"), (2, "bb"), (3, "cc")]
 
 
 @pytest.mark.parametrize(