Skip to content

Commit

Permalink
feat: add tests for LazyFrame.collect_async and pl.collect_all_async,…
Browse files Browse the repository at this point in the history
… add parameters to documentation
  • Loading branch information
Object905 committed Sep 5, 2023
1 parent ef111de commit 9b256f3
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 0 deletions.
25 changes: 25 additions & 0 deletions py-polars/polars/functions/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1710,6 +1710,31 @@ def collect_all_async(
May be useful if you use gevent or asyncio and want to release control to other
greenlets/tasks while LazyFrames are being collected.
Parameters
----------
lazy_frames
A list of LazyFrames to collect.
gevent
Return a wrapper around `gevent.event.AsyncResult` instead of an Awaitable.
type_coercion
Do type coercion optimization.
predicate_pushdown
Do predicate pushdown optimization.
projection_pushdown
Do projection pushdown optimization.
simplify_expression
Run simplify expressions optimization.
no_optimization
Turn off (certain) optimizations.
slice_pushdown
Slice pushdown optimization.
comm_subplan_elim
Will try to cache branching subplans that occur on self-joins or unions.
comm_subexpr_elim
Common subexpressions will be cached and reused.
streaming
Run parts of the query in a streaming fashion (this is in an alpha state).
Notes
-----
In case of error `set_exception` is used on
Expand Down
23 changes: 23 additions & 0 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1751,6 +1751,29 @@ def collect_async(
May be useful if you use gevent or asyncio and want to release control to other
greenlets/tasks while LazyFrames are being collected.
Parameters
----------
gevent
Return a wrapper around `gevent.event.AsyncResult` instead of an Awaitable.
type_coercion
Do type coercion optimization.
predicate_pushdown
Do predicate pushdown optimization.
projection_pushdown
Do projection pushdown optimization.
simplify_expression
Run simplify expressions optimization.
no_optimization
Turn off (certain) optimizations.
slice_pushdown
Slice pushdown optimization.
comm_subplan_elim
Will try to cache branching subplans that occur on self-joins or unions.
comm_subexpr_elim
Common subexpressions will be cached and reused.
streaming
Run parts of the query in a streaming fashion (this is in an alpha state).
Notes
-----
In case of error `set_exception` is used on
Expand Down
1 change: 1 addition & 0 deletions py-polars/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ adbc_driver_sqlite; python_version >= '3.9' and platform_system != 'Windows'
connectorx
cloudpickle
fsspec
gevent

# Tooling
hypothesis==6.82.6
Expand Down
203 changes: 203 additions & 0 deletions py-polars/tests/unit/test_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
from __future__ import annotations

import asyncio
import time
from functools import partial
from typing import Any, Callable

import gevent # type: ignore[import]
import pytest

import polars as pl


async def _aio_collect_async(raises: bool = False) -> pl.DataFrame:
    """
    Collect a small grouped aggregation through ``LazyFrame.collect_async``.

    Parameters
    ----------
    raises
        When true, a selection of column ``"foo_bar"`` is appended to the query.
        The frame only has columns "a", "b" and "c"; the parametrized tests in
        this module pair this mode with ``pl.ColumnNotFoundError``.

    """
    data = {
        "a": ["a", "b", "a", "b", "b", "c"],
        "b": [1, 2, 3, 4, 5, 6],
        "c": [6, 5, 4, 3, 2, 1],
    }
    grouped = pl.LazyFrame(data).group_by("a", maintain_order=True)
    query = grouped.agg(pl.all().sum())
    if raises:
        query = query.select(pl.col("foo_bar"))
    frame = await query.collect_async()
    return frame


async def _aio_collect_all_async(raises: bool = False) -> list[pl.DataFrame]:
    """
    Collect two lazy queries at once through ``pl.collect_all_async``.

    Parameters
    ----------
    raises
        When true, the first query additionally selects column ``"foo_bar"``,
        which is not one of its columns ("a", "b", "c"); the parametrized tests
        in this module pair this mode with ``pl.ColumnNotFoundError``.

    """
    data = {
        "a": ["a", "b", "a", "b", "b", "c"],
        "b": [1, 2, 3, 4, 5, 6],
        "c": [6, 5, 4, 3, 2, 1],
    }
    first = (
        pl.LazyFrame(data).group_by("a", maintain_order=True).agg(pl.all().sum())
    )
    if raises:
        first = first.select(pl.col("foo_bar"))

    # The second query is always well-formed.
    second = pl.LazyFrame({"a": [1, 2], "b": [1, 2]}).group_by("a").sum()

    frames = await pl.collect_all_async([first, second])
    return frames


# Shared parametrization for the asyncio tests: each case is a zero-argument
# coroutine factory plus the exception class expected when awaiting it
# (``None`` when the collection should succeed).
_aio_collect = pytest.mark.parametrize(
    ("collect", "raises"),
    [
        # Successful collection: the helpers are used as-is.
        *((fn, None) for fn in (_aio_collect_async, _aio_collect_all_async)),
        # Failing collection: ``raises=True`` is bound positionally.
        *(
            (partial(fn, True), pl.ColumnNotFoundError)
            for fn in (_aio_collect_async, _aio_collect_all_async)
        ),
    ],
)


def _aio_run(coroutine: Any, raises: Exception | None = None) -> None:
    """
    Drive ``coroutine`` to completion with ``asyncio.run`` and check the outcome.

    Parameters
    ----------
    coroutine
        The coroutine object to run.
    raises
        Value handed to ``pytest.raises``; the tests in this module pass an
        exception class here. When ``None``, the coroutine must succeed and its
        result must have a non-zero length.

    """
    if raises is None:
        outcome = asyncio.run(coroutine)
        assert len(outcome) > 0
        return

    with pytest.raises(raises):  # type: ignore[call-overload]
        asyncio.run(coroutine)


@_aio_collect
def test_collect_async_switch(
    collect: Callable[[], Any], raises: Exception | None
) -> None:
    """Create the collect coroutine, sleep on the event loop, then await it."""

    async def sleep_then_await() -> Any:
        # The coroutine object is only created here; it is not awaited until
        # after the sleep below.
        pending = collect()
        await asyncio.sleep(0.3)
        result = await pending
        return result

    _aio_run(sleep_then_await(), raises)


@_aio_collect
def test_collect_async_task(
    collect: Callable[[], Any],
    raises: Exception | None,
) -> None:
    """Wrap the collect coroutine in a task, sleep, then await the task."""

    async def schedule_then_await() -> Any:
        # Unlike a bare coroutine, the task is handed to the event loop
        # before the sleep below.
        task = asyncio.create_task(collect())
        await asyncio.sleep(0.3)
        result = await task
        return result

    _aio_run(schedule_then_await(), raises)


def _gevent_collect_async(raises: bool = False) -> Any:
    """
    Start ``LazyFrame.collect_async(gevent=True)`` on a small grouped aggregation.

    Parameters
    ----------
    raises
        When true, a selection of column ``"foo_bar"`` is appended to the query.
        The frame only has columns "a", "b" and "c"; the parametrized tests in
        this module pair this mode with ``pl.ColumnNotFoundError``.

    Returns
    -------
    The object returned by ``collect_async(gevent=True)``; the tests in this
    module call ``.get()`` on it.

    """
    data = {
        "a": ["a", "b", "a", "b", "b", "c"],
        "b": [1, 2, 3, 4, 5, 6],
        "c": [6, 5, 4, 3, 2, 1],
    }
    grouped = pl.LazyFrame(data).group_by("a", maintain_order=True)
    query = grouped.agg(pl.all().sum())
    if raises:
        query = query.select(pl.col("foo_bar"))
    async_result = query.collect_async(gevent=True)
    return async_result


def _gevent_collect_all_async(raises: bool = False) -> Any:
    """
    Start ``pl.collect_all_async(..., gevent=True)`` on a single lazy query.

    Parameters
    ----------
    raises
        When true, a selection of column ``"foo_bar"`` is appended to the query.
        The frame only has columns "a", "b" and "c"; the parametrized tests in
        this module pair this mode with ``pl.ColumnNotFoundError``.

    Returns
    -------
    The object returned by ``collect_all_async(gevent=True)``; the tests in
    this module call ``.get()`` on it.

    """
    data = {
        "a": ["a", "b", "a", "b", "b", "c"],
        "b": [1, 2, 3, 4, 5, 6],
        "c": [6, 5, 4, 3, 2, 1],
    }
    grouped = pl.LazyFrame(data).group_by("a", maintain_order=True)
    query = grouped.agg(pl.all().sum())
    if raises:
        query = query.select(pl.col("foo_bar"))
    async_result = pl.collect_all_async([query], gevent=True)
    return async_result


# Shared parametrization for the gevent tests: each case is a zero-argument
# callable that starts a collection plus the exception class expected when
# its result is fetched (``None`` when the collection should succeed).
_gevent_collect = pytest.mark.parametrize(
    ("get_result", "raises"),
    [
        # Successful collection: the helpers are used as-is.
        *((fn, None) for fn in (_gevent_collect_async, _gevent_collect_all_async)),
        # Failing collection: ``raises=True`` is bound positionally.
        *(
            (partial(fn, True), pl.ColumnNotFoundError)
            for fn in (_gevent_collect_async, _gevent_collect_all_async)
        ),
    ],
)


def _gevent_run(callback: Callable[[], Any], raises: Exception | None = None) -> None:
    """
    Invoke ``callback`` and check the outcome.

    Parameters
    ----------
    callback
        Zero-argument callable to invoke.
    raises
        Value handed to ``pytest.raises``; the tests in this module pass an
        exception class here. When ``None``, the callback must succeed and its
        return value must have a non-zero length.

    """
    if raises is None:
        outcome = callback()
        assert len(outcome) > 0
        return

    with pytest.raises(raises):  # type: ignore[call-overload]
        callback()


@_gevent_collect
def test_gevent_collect_async_without_hub(
    get_result: Callable[[], Any],
    raises: Exception | None,
) -> None:
    """Fetch the result with a plain ``.get()``, without touching the hub first."""

    def fetch() -> Any:
        async_result = get_result()
        return async_result.get()

    _gevent_run(fetch, raises)


@_gevent_collect
def test_gevent_collect_async_with_hub(
    get_result: Callable[[], Any],
    raises: Exception | None,
) -> None:
    """Fetch the result with ``.get()`` after ``gevent.get_hub()`` was called."""
    # The hub object itself is unused; the call is made before the collection
    # is started.
    _hub = gevent.get_hub()

    def fetch() -> Any:
        async_result = get_result()
        return async_result.get()

    _gevent_run(fetch, raises)


@_gevent_collect
def test_gevent_collect_async_switch(
    get_result: Callable[[], Any],
    raises: Exception | None,
) -> None:
    """Start the collection, yield via ``gevent.sleep``, then fetch non-blocking."""

    def fetch_after_gevent_sleep() -> Any:
        async_result = get_result()
        # Cooperative sleep: gives other greenlets a chance to run.
        gevent.sleep(0.1)
        return async_result.get(block=False, timeout=3)

    _gevent_run(fetch_after_gevent_sleep, raises)


@_gevent_collect
def test_gevent_collect_async_no_switch(
    get_result: Callable[[], Any],
    raises: Exception | None,
) -> None:
    """Start the collection, block with ``time.sleep``, then fetch non-blocking."""

    def fetch_after_blocking_sleep() -> Any:
        async_result = get_result()
        # Deliberately ``time.sleep`` rather than ``gevent.sleep`` (compare
        # with the "switch" test in this module).
        time.sleep(1)
        return async_result.get(block=False, timeout=None)

    _gevent_run(fetch_after_blocking_sleep, raises)


@_gevent_collect
def test_gevent_collect_async_spawn(
    get_result: Callable[[], Any],
    raises: Exception | None,
) -> None:
    """Start the collection from a spawned greenlet alongside a sleeping one."""

    def fetch_from_greenlet() -> Any:
        collector = gevent.spawn(get_result)
        gevent.spawn(gevent.sleep, 0.1)
        # First ``.get()`` yields what ``get_result`` returned inside the
        # greenlet; the second ``.get()`` is called on that returned object.
        async_result = collector.get()
        return async_result.get()

    _gevent_run(fetch_from_greenlet, raises)

0 comments on commit 9b256f3

Please sign in to comment.