From 6b23f79bf4fd6591824393af4a56a132e6a4bf8d Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Fri, 8 Mar 2024 17:33:30 +0000 Subject: [PATCH] fix(rust, python): std when ddof>=n_values returns None even in rolling context (#11750) --- .../legacy/kernels/rolling/no_nulls/mean.rs | 6 ++-- .../kernels/rolling/no_nulls/min_max.rs | 10 +++--- .../legacy/kernels/rolling/no_nulls/mod.rs | 34 +++++++++++-------- .../kernels/rolling/no_nulls/quantile.rs | 12 +++---- .../legacy/kernels/rolling/no_nulls/sum.rs | 13 +++++-- .../kernels/rolling/no_nulls/variance.rs | 30 +++++++--------- .../src/frame/group_by/aggregations/mod.rs | 34 ++++++++++++++++--- .../rolling_kernels/no_nulls.rs | 2 +- py-polars/polars/expr/expr.py | 8 ++--- .../tests/parametric/test_groupby_rolling.py | 4 +-- .../unit/operations/rolling/test_rolling.py | 25 ++++++++------ py-polars/tests/unit/test_lazy.py | 4 +-- 12 files changed, 110 insertions(+), 72 deletions(-) diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mean.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mean.rs index b5da98336178..f74f88248b2f 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mean.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mean.rs @@ -17,9 +17,9 @@ impl< } } - unsafe fn update(&mut self, start: usize, end: usize) -> T { - let sum = self.sum.update(start, end); - sum / NumCast::from(end - start).unwrap() + unsafe fn update(&mut self, start: usize, end: usize) -> Option { + let sum = self.sum.update(start, end).unwrap_unchecked(); + Some(sum / NumCast::from(end - start).unwrap()) } } diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/min_max.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/min_max.rs index d7368d130c00..54fe9a927dde 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/min_max.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/min_max.rs @@ -148,7 +148,7 @@ macro_rules! minmax_window { } } - unsafe fn update(&mut self, start: usize, end: usize) -> T { + unsafe fn update(&mut self, start: usize, end: usize) -> Option { //For details see: https://github.com/pola-rs/polars/pull/9277#issuecomment-1581401692 self.last_start = start; // Don't care where the last one started let old_last_end = self.last_end; // But we need this @@ -168,10 +168,10 @@ macro_rules! minmax_window { if entering.map(|em| $new_is_m(&self.m, em.1) || empty_overlap) == Some(true) { // The entering extremum "beats" the previous extremum so we can ignore the overlap self.update_m_and_m_idx(entering.unwrap()); - return self.m; + return Some(self.m); } else if self.m_idx >= start || empty_overlap { // The previous extremum didn't drop off. Keep it - return self.m; + return Some(self.m); } // Otherwise get the min of the overlapping window and the entering min match ( @@ -191,7 +191,7 @@ macro_rules! minmax_window { (None, None) => unreachable!(), } - self.m + Some(self.m) } } }; @@ -241,7 +241,7 @@ macro_rules! rolling_minmax_func { _params: DynArgs, ) -> PolarsResult where - T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul, + T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul + Num, { let offset_fn = match center { true => det_offsets_center, diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs index 298d9a2a0283..ffe04bcfd598 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs @@ -3,12 +3,11 @@ mod min_max; mod quantile; mod sum; mod variance; - use std::fmt::Debug; pub use mean::*; pub use min_max::*; -use num_traits::{Float, NumCast}; +use num_traits::{Float, Num, NumCast}; pub use quantile::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -28,7 +27,7 @@ pub trait RollingAggWindowNoNulls<'a, T: NativeType> { /// /// # Safety /// `start` and `end` must be within the windows bounds - unsafe fn update(&mut self, start: usize, end: usize) -> T; + unsafe fn update(&mut self, start: usize, end: usize) -> Option; } // Use an aggregation window that maintains the state @@ -42,27 +41,34 @@ pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>( where Fo: Fn(Idx, WindowSize, Len) -> (Start, End), Agg: RollingAggWindowNoNulls<'a, T>, - T: Debug + NativeType, + T: Debug + NativeType + Num, { let len = values.len(); let (start, end) = det_offsets_fn(0, window_size, len); let mut agg_window = Agg::new(values, start, end, params); + if let Some(validity) = create_validity(min_periods, len, window_size, &det_offsets_fn) { + if validity.iter().all(|x| !x) { + return Ok(Box::new(PrimitiveArray::::new_null( + T::PRIMITIVE.into(), + len, + ))); + } + } let out = (0..len) .map(|idx| { let (start, end) = det_offsets_fn(idx, window_size, len); - // SAFETY: - // we are in bounds - unsafe { agg_window.update(start, end) } + if end - start < min_periods { + None + } else { + // SAFETY: + // we are in bounds + unsafe { agg_window.update(start, end) } + } }) .collect_trusted::>(); - - let validity = create_validity(min_periods, len, window_size, det_offsets_fn); - Ok(Box::new(PrimitiveArray::new( - T::PRIMITIVE.into(), - out.into(), - validity.map(|b| b.into()), - ))) + let arr = PrimitiveArray::from(out); + Ok(Box::new(arr)) } #[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash)] diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs index a4d590eca931..50b7702bdbdd 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs @@ -35,7 +35,7 @@ impl< } } - unsafe fn update(&mut self, start: usize, end: usize) -> T { + unsafe fn update(&mut self, start: usize, end: usize) -> Option { let vals = self.sorted.update(start, end); let length = vals.len(); @@ -48,13 +48,13 @@ impl< let float_idx_top = (length_f - 1.0) * self.prob; let top_idx = float_idx_top.ceil() as usize; return if idx == top_idx { - unsafe { *vals.get_unchecked_release(idx) } + Some(unsafe { *vals.get_unchecked_release(idx) }) } else { let proportion = T::from(float_idx_top - idx as f64).unwrap(); let vi = unsafe { *vals.get_unchecked_release(idx) }; let vj = unsafe { *vals.get_unchecked_release(top_idx) }; - proportion * (vj - vi) + vi + Some(proportion * (vj - vi) + vi) }; }, Midpoint => { @@ -66,7 +66,7 @@ impl< return if top_idx == idx { // SAFETY: // we are in bounds - unsafe { *vals.get_unchecked_release(idx) } + Some(unsafe { *vals.get_unchecked_release(idx) }) } else { // SAFETY: // we are in bounds @@ -77,7 +77,7 @@ impl< ) }; - (mid + mid_plus_1) / (T::one() + T::one()) + Some((mid + mid_plus_1) / (T::one() + T::one())) }; }, Nearest => { @@ -93,7 +93,7 @@ impl< // SAFETY: // we are in bounds - unsafe { *vals.get_unchecked_release(idx) } + Some(unsafe { *vals.get_unchecked_release(idx) }) } } diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/sum.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/sum.rs index 5c35c3df5840..b66a3a4fc5ed 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/sum.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/sum.rs @@ -20,7 +20,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign> } } - unsafe fn update(&mut self, start: usize, end: usize) -> T { + unsafe fn update(&mut self, start: usize, end: usize) -> Option { // if we exceed the end, we have a completely new window // so we recompute let recompute_sum = if start >= self.last_end { @@ -60,7 +60,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign> } } self.last_end = end; - self.sum + Some(self.sum) } } @@ -73,7 +73,14 @@ pub fn rolling_sum( _params: DynArgs, ) -> PolarsResult where - T: NativeType + std::iter::Sum + NumCast + Mul + AddAssign + SubAssign + IsFloat, + T: NativeType + + std::iter::Sum + + NumCast + + Mul + + AddAssign + + SubAssign + + IsFloat + + Num, { match (center, weights) { (true, None) => rolling_apply_agg_window::, _, _>( diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/variance.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/variance.rs index 564f43642d22..4e3de45cfeff 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/variance.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/variance.rs @@ -26,7 +26,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul< } } - unsafe fn update(&mut self, start: usize, end: usize) -> T { + unsafe fn update(&mut self, start: usize, end: usize) -> Option { // if we exceed the end, we have a completely new window // so we recompute let recompute_sum = if start >= self.last_end || self.last_recompute > 128 { @@ -68,7 +68,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul< } } self.last_end = end; - self.sum_of_squares + Some(self.sum_of_squares) } } @@ -108,25 +108,24 @@ impl< } } - unsafe fn update(&mut self, start: usize, end: usize) -> T { + unsafe fn update(&mut self, start: usize, end: usize) -> Option { let count: T = NumCast::from(end - start).unwrap(); - let sum_of_squares = self.sum_of_squares.update(start, end); - let mean = self.mean.update(start, end); + let sum_of_squares = self.sum_of_squares.update(start, end).unwrap_unchecked(); + let mean = self.mean.update(start, end).unwrap_unchecked(); let denom = count - NumCast::from(self.ddof).unwrap(); - if end - start == 1 { - T::zero() - } else if denom <= T::zero() { - //ddof would be greater than # of observations - T::infinity() + if denom <= T::zero() { + None + } else if end - start == 1 { + Some(T::zero()) } else { let out = (sum_of_squares - count * mean * mean) / denom; // variance cannot be negative. // if it is negative it is due to numeric instability if out < T::zero() { - T::zero() + Some(T::zero()) } else { - out + Some(out) } } } @@ -208,14 +207,11 @@ mod test { let out = rolling_var(values, 2, 1, false, None, None).unwrap(); let out = out.as_any().downcast_ref::>().unwrap(); - let out = out - .into_iter() - .map(|v| v.copied().unwrap()) - .collect::>(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); // we cannot compare nans, so we compare the string values assert_eq!( format!("{:?}", out.as_slice()), - format!("{:?}", &[0.0, 8.0, 2.0, 0.5]) + format!("{:?}", &[None, Some(8.0), Some(2.0), Some(0.5)]) ); // test nan handling. let values = &[-10.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0]; diff --git a/crates/polars-core/src/frame/group_by/aggregations/mod.rs b/crates/polars-core/src/frame/group_by/aggregations/mod.rs index 4ed2e4bab9a5..46dc3261d680 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/mod.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/mod.rs @@ -142,7 +142,7 @@ where None } else { // SAFETY: we are in bounds. - Some(unsafe { agg_window.update(start as usize, end as usize) }) + unsafe { agg_window.update(start as usize, end as usize) } } }) .collect::>() @@ -799,7 +799,13 @@ where debug_assert!(len <= self.len() as IdxSize); match len { 0 => None, - 1 => NumCast::from(0), + 1 => { + if ddof == 0 { + NumCast::from(0) + } else { + None + } + }, _ => { let arr_group = _slice_from_offsets(self, first, len); arr_group.var(ddof).map(|flt| NumCast::from(flt).unwrap()) @@ -861,7 +867,13 @@ where debug_assert!(len <= self.len() as IdxSize); match len { 0 => None, - 1 => NumCast::from(0), + 1 => { + if ddof == 0 { + NumCast::from(0) + } else { + None + } + }, _ => { let arr_group = _slice_from_offsets(self, first, len); arr_group.std(ddof).map(|flt| NumCast::from(flt).unwrap()) @@ -1012,7 +1024,13 @@ where debug_assert!(first + len <= self.len() as IdxSize); match len { 0 => None, - 1 => NumCast::from(0), + 1 => { + if ddof == 0 { + NumCast::from(0) + } else { + None + } + }, _ => { let arr_group = _slice_from_offsets(self, first, len); arr_group.var(ddof) @@ -1054,7 +1072,13 @@ where debug_assert!(first + len <= self.len() as IdxSize); match len { 0 => None, - 1 => NumCast::from(0), + 1 => { + if ddof == 0 { + NumCast::from(0) + } else { + None + } + }, _ => { let arr_group = _slice_from_offsets(self, first, len); arr_group.std(ddof) diff --git a/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs b/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs index 8c02b38625a7..abd4eadffc79 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs @@ -41,7 +41,7 @@ where } else { // SAFETY: // we are in bounds - Some(unsafe { agg_window.update(start as usize, end as usize) }) + unsafe { agg_window.update(start as usize, end as usize) } } }) }) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 398283e914d3..e0487c6ee99c 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -6834,7 +6834,7 @@ def rolling_std( │ u32 ┆ datetime[μs] ┆ f64 │ ╞═══════╪═════════════════════╪═════════════════╡ │ 0 ┆ 2001-01-01 00:00:00 ┆ null │ - │ 1 ┆ 2001-01-01 01:00:00 ┆ 0.0 │ + │ 1 ┆ 2001-01-01 01:00:00 ┆ null │ │ 2 ┆ 2001-01-01 02:00:00 ┆ 0.707107 │ │ 3 ┆ 2001-01-01 03:00:00 ┆ 0.707107 │ │ 4 ┆ 2001-01-01 04:00:00 ┆ 0.707107 │ @@ -6859,7 +6859,7 @@ def rolling_std( │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ f64 │ ╞═══════╪═════════════════════╪═════════════════╡ - │ 0 ┆ 2001-01-01 00:00:00 ┆ 0.0 │ + │ 0 ┆ 2001-01-01 00:00:00 ┆ null │ │ 1 ┆ 2001-01-01 01:00:00 ┆ 0.707107 │ │ 2 ┆ 2001-01-01 02:00:00 ┆ 1.0 │ │ 3 ┆ 2001-01-01 03:00:00 ┆ 1.0 │ @@ -7081,7 +7081,7 @@ def rolling_var( │ u32 ┆ datetime[μs] ┆ f64 │ ╞═══════╪═════════════════════╪═════════════════╡ │ 0 ┆ 2001-01-01 00:00:00 ┆ null │ - │ 1 ┆ 2001-01-01 01:00:00 ┆ 0.0 │ + │ 1 ┆ 2001-01-01 01:00:00 ┆ null │ │ 2 ┆ 2001-01-01 02:00:00 ┆ 0.5 │ │ 3 ┆ 2001-01-01 03:00:00 ┆ 0.5 │ │ 4 ┆ 2001-01-01 04:00:00 ┆ 0.5 │ @@ -7106,7 +7106,7 @@ def rolling_var( │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ f64 │ ╞═══════╪═════════════════════╪═════════════════╡ - │ 0 ┆ 2001-01-01 00:00:00 ┆ 0.0 │ + │ 0 ┆ 2001-01-01 00:00:00 ┆ null │ │ 1 ┆ 2001-01-01 01:00:00 ┆ 0.5 │ │ 2 ┆ 2001-01-01 02:00:00 ┆ 1.0 │ │ 3 ┆ 2001-01-01 03:00:00 ┆ 1.0 │ diff --git a/py-polars/tests/parametric/test_groupby_rolling.py b/py-polars/tests/parametric/test_groupby_rolling.py index d6356167c7a8..39836c388014 100644 --- a/py-polars/tests/parametric/test_groupby_rolling.py +++ b/py-polars/tests/parametric/test_groupby_rolling.py @@ -95,8 +95,8 @@ def test_rolling( "max", "mean", "sum", - # "std", blocked by https://github.com/pola-rs/polars/issues/11140 - # "var", blocked by https://github.com/pola-rs/polars/issues/11140 + "std", + "var", "median", ] ), diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index e30cc160f505..6ac4a5937a12 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -12,7 +12,7 @@ from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: - from polars.type_aliases import ClosedInterval, TimeUnit + from polars.type_aliases import ClosedInterval, PolarsDataType, TimeUnit @pytest.fixture() @@ -188,18 +188,21 @@ def test_rolling_skew() -> None: @pytest.mark.parametrize("time_zone", [None, "US/Central"]) @pytest.mark.parametrize( - ("rolling_fn", "expected_values"), + ("rolling_fn", "expected_values", "expected_dtype"), [ - ("rolling_mean", [None, 1.0, 2.0, 3.0, 4.0, 5.0]), - ("rolling_sum", [None, 1, 2, 3, 4, 5]), - ("rolling_min", [None, 1, 2, 3, 4, 5]), - ("rolling_max", [None, 1, 2, 3, 4, 5]), - ("rolling_std", [None, 0.0, 0.0, 0.0, 0.0, 0.0]), - ("rolling_var", [None, 0.0, 0.0, 0.0, 0.0, 0.0]), + ("rolling_mean", [None, 1.0, 2.0, 3.0, 4.0, 5.0], pl.Float64), + ("rolling_sum", [None, 1, 2, 3, 4, 5], pl.Int64), + ("rolling_min", [None, 1, 2, 3, 4, 5], pl.Int64), + ("rolling_max", [None, 1, 2, 3, 4, 5], pl.Int64), + ("rolling_std", [None, None, None, None, None, None], pl.Float64), + ("rolling_var", [None, None, None, None, None, None], pl.Float64), ], ) def test_rolling_crossing_dst( - time_zone: str | None, rolling_fn: str, expected_values: list[int | None | float] + time_zone: str | None, + rolling_fn: str, + expected_values: list[int | None | float], + expected_dtype: PolarsDataType, ) -> None: ts = pl.datetime_range( datetime(2021, 11, 5), datetime(2021, 11, 10), "1d", time_zone="UTC", eager=True @@ -208,7 +211,9 @@ def test_rolling_crossing_dst( result = df.with_columns( getattr(pl.col("value"), rolling_fn)("1d", by="ts", closed="left") ) - expected = pl.DataFrame({"ts": ts, "value": expected_values}) + expected = pl.DataFrame( + {"ts": ts, "value": expected_values}, schema_overrides={"value": expected_dtype} + ) assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index dcc387d060fd..383294250109 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -738,8 +738,8 @@ def test_rolling(fruits_cars: pl.DataFrame) -> None: ] ).collect() - assert cast(float, out_single_val_variance[0, "std"]) == 0.0 - assert cast(float, out_single_val_variance[0, "var"]) == 0.0 + assert cast(float, out_single_val_variance[0, "std"]) is None + assert cast(float, out_single_val_variance[0, "var"]) is None def test_arr_namespace(fruits_cars: pl.DataFrame) -> None: