Skip to content

Commit

Permalink
perf(rust, python): speedup mode on sorted data
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 26, 2023
1 parent c52e70c commit 5c34959
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 22 deletions.
57 changes: 36 additions & 21 deletions polars/polars-core/src/chunked_array/ops/unique/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,36 +78,51 @@ where
}

/// Returns the first-row index of every group whose size equals the largest group size.
///
/// Each returned index points at one occurrence of a modal value. When several groups tie
/// for the largest size, one index per tied group is returned. An empty `groups` yields an
/// empty vector.
///
/// * `GroupsProxy::Idx` — groups are `(first, all_indices)` pairs; they are sorted by size
///   so the largest groups can be read off the tail.
/// * `GroupsProxy::Slice` — groups are `[first, len]` pairs. Their order is not related to
///   their length, so the maximum length is computed explicitly instead of being read from
///   the last slice.
#[cfg(feature = "mode")]
fn mode_indices(groups: GroupsProxy) -> Vec<IdxSize> {
    match groups {
        GroupsProxy::Idx(groups) => {
            let mut groups = groups.into_iter().collect_trusted::<Vec<_>>();
            // Ascending sort by group size: the largest groups end up at the tail.
            groups.sort_unstable_by_key(|k| k.1.len());
            let max_occur = match groups.last() {
                Some(last) => last.1.len(),
                None => return Vec::new(),
            };
            groups
                .iter()
                .rev()
                .take_while(|v| v.1.len() == max_occur)
                .map(|v| v.0)
                .collect()
        }
        GroupsProxy::Slice { groups, .. } => {
            // The last slice is not necessarily the longest one (e.g. `[1, 1, 1, 2]` gives
            // `[0, 3], [3, 1]`), so scan for the real maximum length first.
            let max_occur = match groups.iter().map(|g| g[1]).max() {
                Some(len) => len,
                None => return Vec::new(),
            };
            // Walk back-to-front so ties come out last-group-first, as before.
            groups
                .iter()
                .rev()
                .filter(|g| g[1] == max_occur)
                .map(|g| g[0])
                .collect()
        }
    }
}

/// Returns the most frequently occurring value(s) of `ca`.
///
/// An empty array is returned unchanged (as a clone). Otherwise the values are grouped,
/// `mode_indices` picks one row index per group of maximal size, and those rows are
/// gathered from `ca`. When several values tie for the highest count, all of them are
/// returned.
///
/// # Panics
/// Panics if `group_tuples` returns an error.
#[cfg(feature = "mode")]
fn mode<T: PolarsDataType>(ca: &ChunkedArray<T>) -> ChunkedArray<T>
where
    ChunkedArray<T>: IntoGroupsProxy + ChunkTake,
{
    if ca.is_empty() {
        return ca.clone();
    }
    // Keep the groups in whatever representation `group_tuples` produced; `mode_indices`
    // handles both variants, so no conversion via `into_idx` is needed here.
    let groups = ca.group_tuples(true, false).unwrap();
    let idx = mode_indices(groups);

    // SAFETY: group indices are in bounds
    unsafe { ca.take_unchecked(idx.as_slice().into()) }
}

macro_rules! arg_unique_ca {
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/frame/groupby/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ impl GroupsProxy {
match self {
GroupsProxy::Idx(groups) => groups,
GroupsProxy::Slice { groups, .. } => {
eprintln!("Had to reallocate groups, missed an optimization opportunity. Please open an issue.");
polars_warn!("Had to reallocate groups, missed an optimization opportunity. Please open an issue.");
groups
.iter()
.map(|&[first, len]| (first, (first..first + len).collect_trusted::<Vec<_>>()))
Expand Down
3 changes: 3 additions & 0 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,9 @@ def test_mode() -> None:
)
assert pl.Series([1.0, 2.0, 3.0, 2.0]).mode().item() == 2.0

# sorted data
assert pl.int_range(0, 3, eager=True).mode().to_list() == [2, 1, 0]


def test_rank() -> None:
s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0])
Expand Down

0 comments on commit 5c34959

Please sign in to comment.