Skip to content

Commit

Permalink
fix(rust, python): fix nan aggregation in groupby (#10287)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Aug 4, 2023
1 parent 8874a17 commit 1bbe101
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 41 deletions.
2 changes: 2 additions & 0 deletions crates/polars-arrow/src/kernels/rolling/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub fn compare_fn_nan_min<T>(a: &T, b: &T) -> Ordering
where
T: PartialOrd + IsFloat,
{
// this branch should be opimized away for integers
if T::is_float() {
match (a.is_nan(), b.is_nan()) {
// safety: we checked nans
Expand All @@ -51,6 +52,7 @@ pub fn compare_fn_nan_max<T>(a: &T, b: &T) -> Ordering
where
T: PartialOrd + IsFloat,
{
// this branch should be opimized away for integers
if T::is_float() {
match (a.is_nan(), b.is_nan()) {
// safety: we checked nans
Expand Down
104 changes: 83 additions & 21 deletions crates/polars-core/src/frame/groupby/aggregations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ mod boolean;
mod dispatch;
mod utf8;

use std::cmp::Ordering;

pub use agg_list::*;
use arrow::bitmap::{Bitmap, MutableBitmap};
use arrow::types::simd::Simd;
Expand All @@ -15,7 +17,9 @@ use polars_arrow::kernels::rolling::no_nulls::{
MaxWindow, MeanWindow, MinWindow, QuantileWindow, RollingAggWindowNoNulls, SumWindow, VarWindow,
};
use polars_arrow::kernels::rolling::nulls::RollingAggWindowNulls;
use polars_arrow::kernels::rolling::{DynArgs, RollingQuantileParams, RollingVarParams};
use polars_arrow::kernels::rolling::{
compare_fn_nan_max, compare_fn_nan_min, DynArgs, RollingQuantileParams, RollingVarParams,
};
use polars_arrow::kernels::take_agg::*;
use polars_arrow::prelude::QuantileInterpolOptions;
use polars_arrow::trusted_len::TrustedLenPush;
Expand Down Expand Up @@ -211,24 +215,75 @@ where
ca.into_inner().into_series()
}

#[inline(always)]
fn take_min<T: PartialOrd>(a: T, b: T) -> T {
if a < b {
a
} else {
b
}
pub trait TakeExtremum {
fn take_min(self, other: Self) -> Self;

fn take_max(self, other: Self) -> Self;
}

#[inline(always)]
fn take_max<T: PartialOrd>(a: T, b: T) -> T {
if a > b {
a
} else {
b
}
macro_rules! impl_take_extremum {
($tp:ty) => {
impl TakeExtremum for $tp {
#[inline(always)]
fn take_min(self, other: Self) -> Self {
if self < other {
self
} else {
other
}
}

#[inline(always)]
fn take_max(self, other: Self) -> Self {
if self > other {
self
} else {
other
}
}
}
};

(float: $tp:ty) => {
impl TakeExtremum for $tp {
#[inline(always)]
fn take_min(self, other: Self) -> Self {
if matches!(compare_fn_nan_max(&self, &other), Ordering::Less) {
self
} else {
other
}
}

#[inline(always)]
fn take_max(self, other: Self) -> Self {
if matches!(compare_fn_nan_min(&self, &other), Ordering::Greater) {
self
} else {
other
}
}
}
};
}

#[cfg(feature = "dtype-u8")]
impl_take_extremum!(u8);
#[cfg(feature = "dtype-u16")]
impl_take_extremum!(u16);
impl_take_extremum!(u32);
impl_take_extremum!(u64);
#[cfg(feature = "dtype-i8")]
impl_take_extremum!(i8);
#[cfg(feature = "dtype-i16")]
impl_take_extremum!(i16);
impl_take_extremum!(i32);
impl_take_extremum!(i64);
#[cfg(feature = "dtype-decimal")]
impl_take_extremum!(i128);
impl_take_extremum!(float: f32);
impl_take_extremum!(float: f64);

/// Intermediate helper trait so we can have a single generic implementation
/// This trait will ensure the specific dispatch works without complicating
/// the trait bounds.
Expand Down Expand Up @@ -394,8 +449,15 @@ where
impl<T> ChunkedArray<T>
where
T: PolarsNumericType + Sync,
T::Native:
NativeType + PartialOrd + Num + NumCast + Zero + Simd + Bounded + std::iter::Sum<T::Native>,
T::Native: NativeType
+ PartialOrd
+ Num
+ NumCast
+ Zero
+ Simd
+ Bounded
+ std::iter::Sum<T::Native>
+ TakeExtremum,
<T::Native as Simd>::Simd: std::ops::Add<Output = <T::Native as Simd>::Simd>
+ arrow::compute::aggregate::Sum<T::Native>
+ arrow::compute::aggregate::SimdOrd<T::Native>,
Expand Down Expand Up @@ -427,14 +489,14 @@ where
Some(take_agg_no_null_primitive_iter_unchecked(
arr,
idx2usize(idx),
take_min,
|a, b| a.take_min(b),
T::Native::max_value(),
))
} else {
take_agg_primitive_iter_unchecked::<T::Native, _, _>(
arr,
idx2usize(idx),
take_min,
|a, b| a.take_min(b),
T::Native::max_value(),
idx.len() as IdxSize,
)
Expand Down Expand Up @@ -509,15 +571,15 @@ where
take_agg_no_null_primitive_iter_unchecked(
arr,
idx2usize(idx),
take_max,
|a, b| a.take_max(b),
T::Native::min_value(),
)
})
} else {
take_agg_primitive_iter_unchecked::<T::Native, _, _>(
arr,
idx2usize(idx),
take_max,
|a, b| a.take_max(b),
T::Native::min_value(),
idx.len() as IdxSize,
)
Expand Down
34 changes: 34 additions & 0 deletions py-polars/tests/unit/datatypes/test_float.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import polars as pl


def test_nan_in_groupby_agg() -> None:
df = pl.DataFrame(
{
"key": ["a", "a", "a", "a"],
"value": [18.58, 18.78, float("nan"), 18.63],
"bar": [0, 0, 0, 0],
}
)

assert df.groupby("bar", "key").agg(pl.col("value").max())["value"].item() == 18.78
assert df.groupby("bar", "key").agg(pl.col("value").min())["value"].item() == 18.58


def test_nan_aggregations() -> None:
df = pl.DataFrame({"a": [1.0, float("nan"), 2.0, 3.0], "b": [1, 1, 1, 1]})

aggs = [
pl.col("a").max().alias("max"),
pl.col("a").min().alias("min"),
pl.col("a").nan_max().alias("nan_max"),
pl.col("a").nan_min().alias("nan_min"),
]

assert (
str(df.select(aggs).to_dict(False))
== "{'max': [3.0], 'min': [1.0], 'nan_max': [nan], 'nan_min': [nan]}"
)
assert (
str(df.groupby("b").agg(aggs).to_dict(False))
== "{'b': [1], 'max': [3.0], 'min': [1.0], 'nan_max': [nan], 'nan_min': [nan]}"
)
20 changes: 0 additions & 20 deletions py-polars/tests/unit/functions/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,26 +271,6 @@ def test_align_frames_duplicate_key() -> None:
]


def test_nan_aggregations() -> None:
df = pl.DataFrame({"a": [1.0, float("nan"), 2.0, 3.0], "b": [1, 1, 1, 1]})

aggs = [
pl.col("a").max().alias("max"),
pl.col("a").min().alias("min"),
pl.col("a").nan_max().alias("nan_max"),
pl.col("a").nan_min().alias("nan_min"),
]

assert (
str(df.select(aggs).to_dict(False))
== "{'max': [3.0], 'min': [1.0], 'nan_max': [nan], 'nan_min': [nan]}"
)
assert (
str(df.groupby("b").agg(aggs).to_dict(False))
== "{'b': [1], 'max': [3.0], 'min': [2.0], 'nan_max': [nan], 'nan_min': [nan]}"
)


def test_coalesce() -> None:
df = pl.DataFrame(
{
Expand Down

0 comments on commit 1bbe101

Please sign in to comment.