diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs
index 6267a8e42354..74b4c2c73c4e 100644
--- a/crates/polars-plan/src/dsl/mod.rs
+++ b/crates/polars-plan/src/dsl/mod.rs
@@ -1234,10 +1234,9 @@ impl Expr {
                         ComputeError: "`weights` is not supported in 'rolling by' expression"
                     );
                     let (by, tz) = match by.dtype() {
-                        DataType::Datetime(_, tz) => (
-                            by.cast(&DataType::Datetime(TimeUnit::Microseconds, None))?,
-                            tz,
-                        ),
+                        DataType::Datetime(tu, tz) => {
+                            (by.cast(&DataType::Datetime(*tu, None))?, tz)
+                        },
                         _ => (by.clone(), &None),
                     };
                     let by = by.datetime().unwrap();
diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py
index e9b6e616f5ca..6eef6580017d 100644
--- a/py-polars/tests/unit/operations/rolling/test_rolling.py
+++ b/py-polars/tests/unit/operations/rolling/test_rolling.py
@@ -21,7 +21,7 @@
 from polars.testing import assert_frame_equal, assert_series_equal
 
 if TYPE_CHECKING:
-    from polars.type_aliases import ClosedInterval
+    from polars.type_aliases import ClosedInterval, TimeUnit
 
 
 @pytest.fixture()
@@ -784,11 +784,12 @@ def test_rolling_cov_corr() -> None:
     assert res["corr"][:2] == [None] * 2
 
 
-def test_rolling_empty_window_9406() -> None:
+@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
+def test_rolling_empty_window_9406(time_unit: TimeUnit) -> None:
     datecol = pl.Series(
         "d",
         [datetime(2019, 1, x) for x in [16, 17, 18, 22, 23]],
-        dtype=pl.Datetime(time_unit="us", time_zone=None),
+        dtype=pl.Datetime(time_unit=time_unit, time_zone=None),
     )
     rawdata = pl.Series("x", [1.1, 1.2, 1.3, 1.15, 1.25], dtype=pl.Float64)
     rmin = pl.Series("x", [None, 1.1, 1.1, None, 1.15], dtype=pl.Float64)
@@ -900,3 +901,22 @@ def test_rolling() -> None:
         a.rolling_sum(3),
         pl.Series("a", [None, None, 22.0, nan, nan]),
     )
+
+
+def test_rolling_nanoseconds_11003() -> None:
+    df = pl.DataFrame(
+        {
+            "dt": [
+                "2020-01-01T00:00:00.000000000",
+                "2020-01-01T00:00:00.000000100",
+                "2020-01-01T00:00:00.000000200",
+            ],
+            "val": [1, 2, 3],
+        }
+    )
+    df = df.with_columns(pl.col("dt").str.to_datetime(time_unit="ns"))
+    result = df.with_columns(
+        pl.col("val").rolling_sum("500ns", by="dt", closed="right")
+    )
+    expected = df.with_columns(val=pl.Series([1, 3, 6]))
+    assert_frame_equal(result, expected)
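
Context for the Rust hunk: the old code always cast the `by` column to `Datetime(TimeUnit::Microseconds, None)`, which truncates nanosecond timestamps to microsecond precision, so a temporal window such as "500ns" could not distinguish rows that are only nanoseconds apart. The new code keeps the original time unit and only drops the time zone. The sketch below is not part of the patch; it mirrors the new test_rolling_nanoseconds_11003 regression test and assumes the rolling_sum(window, by=..., closed=...) signature used by that test in the polars version this patch targets.

import polars as pl

# Three rows spaced 100ns apart; the "by" column keeps its nanosecond time unit.
df = pl.DataFrame(
    {
        "dt": [
            "2020-01-01T00:00:00.000000000",
            "2020-01-01T00:00:00.000000100",
            "2020-01-01T00:00:00.000000200",
        ],
        "val": [1, 2, 3],
    }
).with_columns(pl.col("dt").str.to_datetime(time_unit="ns"))

# With the time unit preserved, a right-closed 500ns window covers 1, 2, then 3 rows.
out = df.with_columns(pl.col("val").rolling_sum("500ns", by="dt", closed="right"))
print(out["val"].to_list())  # [1, 3, 6], matching the expected values in the test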