Skip to content

Commit

Permalink
fix(rust, python): fix sorted fast path in streaming groupby wrt nulls
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Aug 4, 2023
1 parent 3683344 commit 989ac8a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
16 changes: 12 additions & 4 deletions crates/polars-pipe/src/executors/sinks/groupby/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::Mutex;

use hashbrown::hash_map::RawEntryMut;
use num_traits::NumCast;
use polars_arrow::is_valid::IsValid;
use polars_arrow::kernels::sort_partition::partition_to_groups_amortized;
use polars_core::export::ahash::RandomState;
use polars_core::frame::row::AnyValueBuffer;
Expand Down Expand Up @@ -221,15 +222,22 @@ where
partition_to_groups_amortized(values, 0, false, 0, &mut self.sort_partitions);

let pre_agg_len = self.pre_agg_partitions.len();
let null: Option<K::Native> = None;
let null_hash = self.hb.hash_one(null);

for group in &self.sort_partitions {
let [offset, length] = group;
let first_g_value = unsafe { *values.get_unchecked_release(*offset as usize) };
let h = self.hb.hash_one(first_g_value);
let (opt_v, h) = if unsafe { arr.is_valid_unchecked(*offset as usize) } {
let first_g_value = unsafe { *values.get_unchecked_release(*offset as usize) };
let h = self.hb.hash_one(first_g_value);
(Some(first_g_value), h)
} else {
(null, null_hash)
};

let agg_idx = insert_and_get(
h,
Some(first_g_value),
opt_v,
pre_agg_len,
&mut self.pre_agg_partitions,
&mut self.aggregators,
Expand Down Expand Up @@ -351,7 +359,7 @@ where
let ca: &ChunkedArray<K> = s.as_ref().as_ref();

// sorted fast path
if matches!(ca.is_sorted_flag(), IsSorted::Ascending) && ca.null_count() == 0 {
if matches!(ca.is_sorted_flag(), IsSorted::Ascending) {
return self.sink_sorted(ca, chunk);
}

Expand Down
23 changes: 23 additions & 0 deletions py-polars/tests/unit/streaming/test_streaming_groupby.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

import polars as pl


@pytest.mark.slow()
def test_streaming_groupby_sorted_fast_path_nulls_10273() -> None:
    """Regression test for issue 10273.

    Builds a single column "x" holding the keys 0..3 (100 rows each) followed
    by 100 nulls, flags the column as sorted, and checks that a streaming
    groupby/count reports every key -- including the null key -- with a
    count of 100.
    """
    rows_per_key = 100

    # Each integer key repeated `rows_per_key` times, in ascending key order.
    int_keys = tuple(key for key in range(4) for _ in range(rows_per_key))
    # The null group is appended after the integer keys.
    null_keys = tuple(None for _ in range(rows_per_key))

    df = pl.Series(name="x", values=int_keys + null_keys).to_frame()

    # Flag the column as sorted before running the streaming engine.
    query = df.set_sorted("x").lazy().groupby("x").agg(pl.count())
    result = query.collect(streaming=True).sort("x")

    expected = {"x": [None, 0, 1, 2, 3], "count": [100, 100, 100, 100, 100]}
    assert result.to_dict(False) == expected

0 comments on commit 989ac8a

Please sign in to comment.