@given(
    window_size=st.timedeltas(min_value=timedelta(microseconds=0)).map(
        _timedelta_to_pl_duration
    ),
    closed=strategy_closed,
    data=st.data(),
    time_unit=strategy_time_unit,
    aggregation=st.sampled_from(
        [
            "min",
            "max",
            "mean",
            "sum",
            # "std", blocked by https://github.com/pola-rs/polars/issues/11140
            # "var", blocked by https://github.com/pola-rs/polars/issues/11140
            "median",
        ]
    ),
)
def test_rolling_aggs(
    window_size: str,
    closed: ClosedInterval,
    data: st.DataObject,
    time_unit: TimeUnit,
    aggregation: str,
) -> None:
    """
    Compare ``rolling_<aggregation>`` over "ts" with a per-row reference.

    A frame with a ``Datetime(time_unit)`` column "ts" and an ``Int64`` column
    "value" is drawn and sorted by "ts". The rolling expression is evaluated
    with ``window_size``, ``by="ts"`` and ``closed``. The reference result is
    built one row at a time: the frame is filtered to the rows whose "ts" lies
    between ``ts - window_size`` and ``ts`` (bounds governed by ``closed``), and
    the Series method named ``aggregation`` is applied to the "value" column of
    that window, or null is used when the window is empty.

    Examples with an empty ``window_size`` string are discarded up front.
    Examples for which the rolling call raises a ``PolarsPanicError`` carrying
    one of the known overflow messages are rejected; any other panic message
    fails the assertion.
    """
    assume(window_size != "")

    ts_dtype = pl.Datetime(time_unit)
    drawn = data.draw(
        dataframes(
            [
                column("ts", dtype=ts_dtype),
                column("value", dtype=pl.Int64),
            ],
        )
    )
    df = drawn.sort("ts")

    rolling_method = getattr(pl.col("value"), f"rolling_{aggregation}")
    try:
        result = df.with_columns(
            rolling_method(window_size=window_size, by="ts", closed=closed)
        )
    except pl.exceptions.PolarsPanicError as exc:
        panic_text = str(exc)
        known_overflows = (
            "attempt to multiply with overflow",
            "attempt to add with overflow",
        )
        assert any(text in panic_text for text in known_overflows)  # noqa: PT017
        # Known overflow panics are treated as invalid examples, not failures.
        reject()

    # Brute-force reference: one filter + aggregation per row of the frame.
    negative_window = f"-{window_size}"
    timestamps: list[object] = []
    values: list[object] = []
    for ts, _value in df.iter_rows():
        upper = pl.lit(ts, dtype=ts_dtype)
        lower = upper.dt.offset_by(negative_window)
        window = df.filter(pl.col("ts").is_between(lower, upper, closed=closed))
        timestamps.append(ts)
        values.append(
            None if window.is_empty() else getattr(window["value"], aggregation)()
        )

    # Cast the reference to the dtypes produced by the rolling expression so
    # the frame comparison is not tripped up by inferred dtypes.
    expected = pl.DataFrame({"ts": timestamps, "value": values}).select(
        pl.col("ts").cast(ts_dtype),
        pl.col("value").cast(result["value"].dtype),
    )
    assert_frame_equal(result, expected)