Skip to content

Commit

Permalink
fix: empty product returns identity (#10842)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Sep 1, 2023
1 parent 84c5c60 commit 69e4826
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 142 deletions.
36 changes: 15 additions & 21 deletions crates/polars-core/src/chunked_array/ops/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::ops::Add;
use arrow::compute;
use arrow::types::simd::Simd;
use arrow::types::NativeType;
use num_traits::{Float, ToPrimitive, Zero};
use num_traits::{Float, One, ToPrimitive, Zero};
use polars_arrow::kernels::rolling::{compare_fn_nan_max, compare_fn_nan_min};
pub use quantile::*;
pub use var::*;
Expand Down Expand Up @@ -276,7 +276,7 @@ impl BooleanChunked {
}
}

// Needs the same trait bounds as the implementation of ChunkedArray<T> of dyn Series
// Needs the same trait bounds as the implementation of ChunkedArray<T> of dyn Series.
impl<T> ChunkAggSeries for ChunkedArray<T>
where
T: PolarsNumericType,
Expand All @@ -291,12 +291,14 @@ where
ca.rename(self.name());
ca.into_series()
}

/// Wraps the result of `ChunkAgg::max` in a one-element `Series` that
/// carries the same name as this array.
fn max_as_series(&self) -> Series {
    let maximum = ChunkAgg::max(self);
    // Build a single-value array from the aggregate.
    let mut out: ChunkedArray<T> = std::iter::once(maximum).collect();
    out.rename(self.name());
    out.into_series()
}

fn min_as_series(&self) -> Series {
let v = ChunkAgg::min(self);
let mut ca: ChunkedArray<T> = [v].iter().copied().collect();
Expand All @@ -305,15 +307,11 @@ where
}

fn prod_as_series(&self) -> Series {
let mut prod = None;
for opt_v in self.into_iter() {
match (prod, opt_v) {
(_, None) => return Self::full_null(self.name(), 1).into_series(),
(None, Some(v)) => prod = Some(v),
(Some(p), Some(v)) => prod = Some(p * v),
}
let mut prod = T::Native::one();
for opt_v in self.into_iter().flatten() {
prod = prod * opt_v;
}
Self::from_slice_options(self.name(), &[prod]).into_series()
Self::from_slice_options(self.name(), &[Some(prod)]).into_series()
}
}

Expand Down Expand Up @@ -509,15 +507,13 @@ impl BinaryChunked {
match self.is_sorted_flag() {
IsSorted::Ascending => {
self.last_non_null().and_then(|idx| {
// Safety:
// last_non_null returns in bound index
// SAFETY: last_non_null returns in bound index.
unsafe { self.get_unchecked(idx) }
})
},
IsSorted::Descending => {
self.first_non_null().and_then(|idx| {
// Safety:
// first_non_null returns in bound index
// SAFETY: first_non_null returns in bound index.
unsafe { self.get_unchecked(idx) }
})
},
Expand All @@ -535,15 +531,13 @@ impl BinaryChunked {
match self.is_sorted_flag() {
IsSorted::Ascending => {
self.first_non_null().and_then(|idx| {
// Safety:
// first_non_null returns in bound index
// SAFETY: first_non_null returns in bound index.
unsafe { self.get_unchecked(idx) }
})
},
IsSorted::Descending => {
self.last_non_null().and_then(|idx| {
// Safety:
// last_non_null returns in bound index
// SAFETY: last_non_null returns in bound index.
unsafe { self.get_unchecked(idx) }
})
},
Expand Down Expand Up @@ -606,9 +600,9 @@ mod test {

#[test]
fn test_var() {
// validated with numpy
// Note that numpy as an argument ddof which influences results. The default is ddof=0
// we chose ddof=1, which is standard in statistics
// Validated with numpy. Note that numpy takes a `ddof` argument, which
// influences the results. The default is ddof=0; we chose ddof=1, which is
// standard in statistics.
let ca1 = Int32Chunked::new("", &[5, 8, 9, 5, 0]);
let ca2 = Int32Chunked::new(
"",
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use arrow::types::NativeType;
pub use dtype::*;
pub use field::*;
pub use from_values::ArrayFromElementIter;
use num_traits::{Bounded, FromPrimitive, Num, NumCast, Zero};
use num_traits::{Bounded, FromPrimitive, Num, NumCast, One, Zero};
use polars_arrow::data_types::IsFloat;
#[cfg(feature = "serde")]
use serde::de::{EnumAccess, Error, Unexpected, VariantAccess, Visitor};
Expand Down Expand Up @@ -190,6 +190,7 @@ pub trait NumericNative:
+ Num
+ NumCast
+ Zero
+ One
+ Simd
+ Simd8
+ std::iter::Sum<Self>
Expand Down
Loading

0 comments on commit 69e4826

Please sign in to comment.