diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py
index 1868ae180ee44..03160dc2c0f58 100644
--- a/py-polars/polars/functions/lazy.py
+++ b/py-polars/polars/functions/lazy.py
@@ -1710,6 +1710,31 @@ def collect_all_async(
     May be useful if you use gevent or asyncio and want to release control to other
     greenlets/tasks while LazyFrames are being collected.
 
+    Parameters
+    ----------
+    lazy_frames
+        A list of LazyFrames to collect.
+    gevent
+        Return a wrapper to `gevent.event.AsyncResult` instead of an Awaitable.
+    type_coercion
+        Do type coercion optimization.
+    predicate_pushdown
+        Do predicate pushdown optimization.
+    projection_pushdown
+        Do projection pushdown optimization.
+    simplify_expression
+        Run simplify expressions optimization.
+    no_optimization
+        Turn off (certain) optimizations.
+    slice_pushdown
+        Slice pushdown optimization.
+    comm_subplan_elim
+        Will try to cache branching subplans that occur on self-joins or unions.
+    comm_subexpr_elim
+        Common subexpressions will be cached and reused.
+    streaming
+        Run parts of the query in a streaming fashion (this is in an alpha state).
+
     Notes
     -----
     In case of error `set_exception` is used on
diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py
index ca217cbc224df..e610de2fb7e98 100644
--- a/py-polars/polars/lazyframe/frame.py
+++ b/py-polars/polars/lazyframe/frame.py
@@ -1751,6 +1751,29 @@ def collect_async(
         May be useful if you use gevent or asyncio and want to release control to other
         greenlets/tasks while LazyFrames are being collected.
 
+        Parameters
+        ----------
+        gevent
+            Return a wrapper to `gevent.event.AsyncResult` instead of an Awaitable.
+        type_coercion
+            Do type coercion optimization.
+        predicate_pushdown
+            Do predicate pushdown optimization.
+        projection_pushdown
+            Do projection pushdown optimization.
+        simplify_expression
+            Run simplify expressions optimization.
+        no_optimization
+            Turn off (certain) optimizations.
+        slice_pushdown
+            Slice pushdown optimization.
+        comm_subplan_elim
+            Will try to cache branching subplans that occur on self-joins or unions.
+        comm_subexpr_elim
+            Common subexpressions will be cached and reused.
+        streaming
+            Run parts of the query in a streaming fashion (this is in an alpha state).
+
         Notes
         -----
         In case of error `set_exception` is used on
diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt
index ed9ce720a1a9c..6a8edbb9b3886 100644
--- a/py-polars/requirements-dev.txt
+++ b/py-polars/requirements-dev.txt
@@ -20,6 +20,7 @@ adbc_driver_sqlite; python_version >= '3.9' and platform_system != 'Windows'
 connectorx
 cloudpickle
 fsspec
+gevent
 
 # Tooling
 hypothesis==6.82.6
diff --git a/py-polars/tests/unit/test_async.py b/py-polars/tests/unit/test_async.py
new file mode 100644
index 0000000000000..9e874c96ad5b1
--- /dev/null
+++ b/py-polars/tests/unit/test_async.py
@@ -0,0 +1,179 @@
+from __future__ import annotations
+
+import asyncio
+import time
+from functools import partial
+from typing import Any, Callable
+
+import gevent  # type: ignore[import]
+import pytest
+
+import polars as pl
+
+
+def _grouped_lf(raises: bool = False) -> pl.LazyFrame:
+    """Return the grouped query used by every test in this module."""
+    lf = (
+        pl.LazyFrame(
+            {
+                "a": ["a", "b", "a", "b", "b", "c"],
+                "b": [1, 2, 3, 4, 5, 6],
+                "c": [6, 5, 4, 3, 2, 1],
+            }
+        )
+        .group_by("a", maintain_order=True)
+        .agg(pl.all().sum())
+    )
+    if raises:
+        # The failing parametrizations expect `pl.ColumnNotFoundError` from this.
+        lf = lf.select(pl.col("foo_bar"))
+    return lf
+
+
+async def _aio_collect_async(raises: bool = False) -> pl.DataFrame:
+    """Await a single `LazyFrame.collect_async` call."""
+    return await _grouped_lf(raises).collect_async()
+
+
+async def _aio_collect_all_async(raises: bool = False) -> list[pl.DataFrame]:
+    """Await `pl.collect_all_async` on two LazyFrames."""
+    lf2 = pl.LazyFrame({"a": [1, 2], "b": [1, 2]}).group_by("a").sum()
+    return await pl.collect_all_async([_grouped_lf(raises), lf2])
+
+
+_aio_collect = pytest.mark.parametrize(
+    ("collect", "raises"),
+    [
+        (_aio_collect_async, None),
+        (_aio_collect_all_async, None),
+        (partial(_aio_collect_async, True), pl.ColumnNotFoundError),
+        (partial(_aio_collect_all_async, True), pl.ColumnNotFoundError),
+    ],
+)
+
+
+def _aio_run(coroutine: Any, raises: type[Exception] | None = None) -> None:
+    """Run `coroutine` with asyncio; expect `raises` or a non-empty result."""
+    if raises is not None:
+        with pytest.raises(raises):
+            asyncio.run(coroutine)
+    else:
+        assert len(asyncio.run(coroutine)) > 0
+
+
+@_aio_collect
+def test_collect_async_switch(
+    collect: Callable[[], Any],
+    raises: type[Exception] | None,
+) -> None:
+    async def main() -> Any:
+        df = collect()
+        # Yield to the event loop before awaiting the result.
+        await asyncio.sleep(0.3)
+        return await df
+
+    _aio_run(main(), raises)
+
+
+@_aio_collect
+def test_collect_async_task(
+    collect: Callable[[], Any], raises: type[Exception] | None
+) -> None:
+    async def main() -> Any:
+        df = asyncio.create_task(collect())
+        await asyncio.sleep(0.3)
+        return await df
+
+    _aio_run(main(), raises)
+
+
+def _gevent_collect_async(raises: bool = False) -> Any:
+    """Start `LazyFrame.collect_async(gevent=True)` and return its wrapper."""
+    return _grouped_lf(raises).collect_async(gevent=True)
+
+
+def _gevent_collect_all_async(raises: bool = False) -> Any:
+    """Start `pl.collect_all_async(..., gevent=True)` and return its wrapper."""
+    return pl.collect_all_async([_grouped_lf(raises)], gevent=True)
+
+
+_gevent_collect = pytest.mark.parametrize(
+    ("get_result", "raises"),
+    [
+        (_gevent_collect_async, None),
+        (_gevent_collect_all_async, None),
+        (partial(_gevent_collect_async, True), pl.ColumnNotFoundError),
+        (partial(_gevent_collect_all_async, True), pl.ColumnNotFoundError),
+    ],
+)
+
+
+def _gevent_run(
+    callback: Callable[[], Any], raises: type[Exception] | None = None
+) -> None:
+    """Call `callback`; expect `raises` or a non-empty result."""
+    if raises is not None:
+        with pytest.raises(raises):
+            callback()
+    else:
+        assert len(callback()) > 0
+
+
+@_gevent_collect
+def test_gevent_collect_async_without_hub(
+    get_result: Callable[[], Any], raises: type[Exception] | None
+) -> None:
+    def main() -> Any:
+        return get_result().get()
+
+    _gevent_run(main, raises)
+
+
+@_gevent_collect
+def test_gevent_collect_async_with_hub(
+    get_result: Callable[[], Any], raises: type[Exception] | None
+) -> None:
+    # Touch the hub before collecting; `_hub` itself is intentionally unused.
+    _hub = gevent.get_hub()
+
+    def main() -> Any:
+        return get_result().get()
+
+    _gevent_run(main, raises)
+
+
+@_gevent_collect
+def test_gevent_collect_async_switch(
+    get_result: Callable[[], Any], raises: type[Exception] | None
+) -> None:
+    def main() -> Any:
+        result = get_result()
+        gevent.sleep(0.1)
+        return result.get(block=False, timeout=3)
+
+    _gevent_run(main, raises)
+
+
+@_gevent_collect
+def test_gevent_collect_async_no_switch(
+    get_result: Callable[[], Any], raises: type[Exception] | None
+) -> None:
+    def main() -> Any:
+        result = get_result()
+        # Plain `time.sleep` (not `gevent.sleep`): the "no switch" scenario.
+        time.sleep(1)
+        return result.get(block=False, timeout=None)
+
+    _gevent_run(main, raises)
+
+
+@_gevent_collect
+def test_gevent_collect_async_spawn(
+    get_result: Callable[[], Any], raises: type[Exception] | None
+) -> None:
+    def main() -> Any:
+        result_greenlet = gevent.spawn(get_result)
+        gevent.spawn(gevent.sleep, 0.1)
+        return result_greenlet.get().get()
+
+    _gevent_run(main, raises)