Skip to content

Commit

Permalink
fix: handle aggregation for all-nan groups in group_by (pola-rs#12304)
Browse files (browse the repository at this point in the history)
  • Loading branch information
rob-sil authored Nov 30, 2023
1 parent 4027ced commit 41fee44
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 87 deletions.
34 changes: 10 additions & 24 deletions crates/polars-arrow/src/legacy/kernels/take_agg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,14 @@ pub unsafe fn take_agg_no_null_primitive_iter_unchecked<
arr: &PrimitiveArray<T>,
indices: I,
f: F,
init: TOut,
) -> TOut {
) -> Option<TOut> {
debug_assert!(arr.null_count() == 0);
let array_values = arr.values().as_slice();

indices.into_iter().fold(init, |acc, idx| {
f(
acc,
NumCast::from(*array_values.get_unchecked(idx)).unwrap_unchecked(),
)
})
indices
.into_iter()
.map(|idx| TOut::from(*array_values.get_unchecked(idx)).unwrap_unchecked())
.reduce(f)
}

/// Take kernel for single chunk and an iterator as index.
Expand All @@ -48,26 +45,15 @@ pub unsafe fn take_agg_primitive_iter_unchecked<
arr: &PrimitiveArray<T>,
indices: I,
f: F,
init: T,
len: IdxSize,
) -> Option<T> {
let array_values = arr.values().as_slice();
let validity = arr.validity().unwrap();
let mut null_count = 0 as IdxSize;

let out = indices.into_iter().fold(init, |acc, idx| {
if validity.get_bit_unchecked(idx) {
f(acc, *array_values.get_unchecked(idx))
} else {
null_count += 1;
acc
}
});
if null_count == len {
None
} else {
Some(out)
}
indices
.into_iter()
.filter(|&idx| validity.get_bit_unchecked(idx))
.map(|idx| *array_values.get_unchecked(idx))
.reduce(f)
}

/// Take kernel for single chunk and an iterator as index.
Expand Down
58 changes: 15 additions & 43 deletions crates/polars-core/src/frame/group_by/aggregations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,20 +479,13 @@ where
} else if idx.len() == 1 {
arr.get(first as usize)
} else if no_nulls {
Some(take_agg_no_null_primitive_iter_unchecked(
take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>(
arr,
idx2usize(idx),
|a, b| a.take_min(b),
T::Native::max_value(),
))
} else {
take_agg_primitive_iter_unchecked::<T::Native, _, _>(
arr,
idx2usize(idx),
|a, b| a.take_min(b),
T::Native::max_value(),
idx.len() as IdxSize,
)
} else {
take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a.take_min(b))
}
})
},
Expand Down Expand Up @@ -560,22 +553,13 @@ where
} else if idx.len() == 1 {
arr.get(first as usize)
} else if no_nulls {
Some({
take_agg_no_null_primitive_iter_unchecked(
arr,
idx2usize(idx),
|a, b| a.take_max(b),
T::Native::min_value(),
)
})
} else {
take_agg_primitive_iter_unchecked::<T::Native, _, _>(
take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>(
arr,
idx2usize(idx),
|a, b| a.take_max(b),
T::Native::min_value(),
idx.len() as IdxSize,
)
} else {
take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a.take_max(b))
}
})
},
Expand Down Expand Up @@ -632,21 +616,11 @@ where
} else if idx.len() == 1 {
arr.get(first as usize).unwrap_or(T::Native::zero())
} else if no_nulls {
take_agg_no_null_primitive_iter_unchecked(
arr,
idx2usize(idx),
|a, b| a + b,
T::Native::zero(),
)
take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a + b)
.unwrap_or(T::Native::zero())
} else {
take_agg_primitive_iter_unchecked::<T::Native, _, _>(
arr,
idx2usize(idx),
|a, b| a + b,
T::Native::zero(),
idx.len() as IdxSize,
)
.unwrap_or(T::Native::zero())
take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a + b)
.unwrap_or(T::Native::zero())
}
})
},
Expand Down Expand Up @@ -716,12 +690,12 @@ where
} else if idx.len() == 1 {
arr.get(first as usize).map(|sum| sum.to_f64().unwrap())
} else if no_nulls {
take_agg_no_null_primitive_iter_unchecked(
take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>(
arr,
idx2usize(idx),
|a, b| a + b,
T::Native::zero(),
)
.unwrap()
.to_f64()
.map(|sum| sum / idx.len() as f64)
} else {
Expand Down Expand Up @@ -955,15 +929,13 @@ where
} else {
match (self.has_validity(), self.chunks.len()) {
(false, 1) => {
take_agg_no_null_primitive_iter_unchecked(
take_agg_no_null_primitive_iter_unchecked::<_, f64, _, _>(
self.downcast_iter().next().unwrap(),
idx2usize(idx),
|a, b| a + b,
0.0f64,
)
}
.to_f64()
.map(|sum| sum / idx.len() as f64),
.map(|sum| sum / idx.len() as f64)
},
(_, 1) => {
{
take_agg_primitive_iter_unchecked_count_nulls::<
Expand Down
27 changes: 9 additions & 18 deletions crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use arrow::legacy::kernels::rolling::no_nulls::{MaxWindow, MinWindow};
use arrow::legacy::kernels::take_agg::{
take_agg_no_null_primitive_iter_unchecked, take_agg_primitive_iter_unchecked,
};
use polars_core::export::num::Bounded;
use polars_core::frame::group_by::aggregations::{
_agg_helper_idx, _agg_helper_slice, _rolling_apply_agg_window_no_nulls,
_rolling_apply_agg_window_nulls, _slice_from_offsets, _use_rolling_kernels,
Expand Down Expand Up @@ -99,20 +98,15 @@ where
ca.get(first as usize)
} else {
match (ca.has_validity(), ca.chunks().len()) {
(false, 1) => Some({
take_agg_no_null_primitive_iter_unchecked(
ca.downcast_iter().next().unwrap(),
idx.iter().map(|i| *i as usize),
nan_max,
T::Native::min_value(),
)
}),
(_, 1) => take_agg_primitive_iter_unchecked::<T::Native, _, _>(
(false, 1) => take_agg_no_null_primitive_iter_unchecked(
ca.downcast_iter().next().unwrap(),
idx.iter().map(|i| *i as usize),
nan_max,
),
(_, 1) => take_agg_primitive_iter_unchecked(
ca.downcast_iter().next().unwrap(),
idx.iter().map(|i| *i as usize),
nan_max,
T::Native::min_value(),
idx.len() as IdxSize,
),
_ => {
let take = { ca.take_unchecked(idx) };
Expand Down Expand Up @@ -173,18 +167,15 @@ where
ca.get(first as usize)
} else {
match (ca.has_validity(), ca.chunks().len()) {
(false, 1) => Some(take_agg_no_null_primitive_iter_unchecked(
(false, 1) => take_agg_no_null_primitive_iter_unchecked(
ca.downcast_iter().next().unwrap(),
idx.iter().map(|i| *i as usize),
nan_min,
T::Native::max_value(),
)),
(_, 1) => take_agg_primitive_iter_unchecked::<T::Native, _, _>(
),
(_, 1) => take_agg_primitive_iter_unchecked(
ca.downcast_iter().next().unwrap(),
idx.iter().map(|i| *i as usize),
nan_min,
T::Native::max_value(),
idx.len() as IdxSize,
),
_ => {
let take = { ca.take_unchecked(idx) };
Expand Down
43 changes: 41 additions & 2 deletions py-polars/tests/unit/operations/test_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def test_mean_null_simd() -> None:
.select(pl.when(pl.col("a") > 40).then(pl.col("a")))
)

s = df["a"]
assert s.mean() == s.to_pandas().mean()
s = df["a"]
assert s.mean() == s.to_pandas().mean()


def test_literal_group_agg_chunked_7968() -> None:
Expand Down Expand Up @@ -328,3 +328,42 @@ def test_binary_op_agg_context_no_simplify_expr_12423() -> None:
.agg(y=pl.lit(1) * pl.lit(1))
.collect(simplify_expression=simplify_expression),
)


def test_nan_inf_aggregation() -> None:
    """Check grouped min/max/mean for groups mixing NaN, inf and null.

    Builds a two-column frame ("group", "value") holding two rows per
    group, aggregates ``value`` per group while keeping the groups in
    first-appearance order, and compares the result with a hand-written
    expected frame. The expectations encoded below are:

    * a group holding only NaN yields NaN for min, max and mean;
    * NaN next to a finite number yields that number for min and max,
      and NaN for the mean;
    * a null next to NaN or inf does not change the min, max or mean;
    * a group holding only nulls yields null for all three aggregations.
    """
    # Each label names the pair of values stored for that group.
    rows = [
        ("both nan", np.nan),
        ("both nan", np.nan),
        ("nan and 5", np.nan),
        ("nan and 5", 5),
        ("nan and null", np.nan),
        ("nan and null", None),
        ("both none", None),
        ("both none", None),
        ("both inf", np.inf),
        ("both inf", np.inf),
        ("inf and null", np.inf),
        ("inf and null", None),
    ]
    frame = pl.DataFrame(rows, schema=["group", "value"])

    value = pl.col("value")
    observed = frame.group_by("group", maintain_order=True).agg(
        min=value.min(),
        max=value.max(),
        mean=value.mean(),
    )

    # One row per group: (group, min, max, mean).
    expected_rows = [
        ("both nan", np.nan, np.nan, np.nan),
        ("nan and 5", 5, 5, np.nan),
        ("nan and null", np.nan, np.nan, np.nan),
        ("both none", None, None, None),
        ("both inf", np.inf, np.inf, np.inf),
        ("inf and null", np.inf, np.inf, np.inf),
    ]
    expected = pl.DataFrame(
        expected_rows,
        schema=["group", "min", "max", "mean"],
    )

    assert_frame_equal(observed, expected)

0 comments on commit 41fee44

Please sign in to comment.