Skip to content

Commit

Permalink
fix(rust, python): fix rolling_* functions when "by" has nanosecond r…
Browse files Browse the repository at this point in the history
…esolution (pola-rs#11005)
  • Loading branch information
MarcoGorelli authored Sep 11, 2023
1 parent 8d5de95 commit 8a82137
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
7 changes: 3 additions & 4 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1234,10 +1234,9 @@ impl Expr {
ComputeError: "`weights` is not supported in 'rolling by' expression"
);
let (by, tz) = match by.dtype() {
DataType::Datetime(_, tz) => (
by.cast(&DataType::Datetime(TimeUnit::Microseconds, None))?,
tz,
),
DataType::Datetime(tu, tz) => {
(by.cast(&DataType::Datetime(*tu, None))?, tz)
},
_ => (by.clone(), &None),
};
let by = by.datetime().unwrap();
Expand Down
26 changes: 23 additions & 3 deletions py-polars/tests/unit/operations/rolling/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
from polars.type_aliases import ClosedInterval
from polars.type_aliases import ClosedInterval, TimeUnit


@pytest.fixture()
Expand Down Expand Up @@ -784,11 +784,12 @@ def test_rolling_cov_corr() -> None:
assert res["corr"][:2] == [None] * 2


def test_rolling_empty_window_9406() -> None:
@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
def test_rolling_empty_window_9406(time_unit: TimeUnit) -> None:
datecol = pl.Series(
"d",
[datetime(2019, 1, x) for x in [16, 17, 18, 22, 23]],
dtype=pl.Datetime(time_unit="us", time_zone=None),
dtype=pl.Datetime(time_unit=time_unit, time_zone=None),
)
rawdata = pl.Series("x", [1.1, 1.2, 1.3, 1.15, 1.25], dtype=pl.Float64)
rmin = pl.Series("x", [None, 1.1, 1.1, None, 1.15], dtype=pl.Float64)
Expand Down Expand Up @@ -900,3 +901,22 @@ def test_rolling() -> None:
a.rolling_sum(3),
pl.Series("a", [None, None, 22.0, nan, nan]),
)


def test_rolling_nanoseconds_11003() -> None:
df = pl.DataFrame(
{
"dt": [
"2020-01-01T00:00:00.000000000",
"2020-01-01T00:00:00.000000100",
"2020-01-01T00:00:00.000000200",
],
"val": [1, 2, 3],
}
)
df = df.with_columns(pl.col("dt").str.to_datetime(time_unit="ns"))
result = df.with_columns(
pl.col("val").rolling_sum("500ns", by="dt", closed="right")
)
expected = df.with_columns(val=pl.Series([1, 3, 6]))
assert_frame_equal(result, expected)

0 comments on commit 8a82137

Please sign in to comment.