Skip to content

Commit

Permalink
fix(rust, python): use row-encoded for struct::is_sorted (#10129)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 28, 2023
1 parent fad5b13 commit 28dad4d
Showing 1 changed file with 15 additions and 20 deletions.
35 changes: 15 additions & 20 deletions polars/polars-ops/src/series/ops/various.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#[cfg(feature = "hash")]
use polars_core::export::ahash;
#[cfg(feature = "dtype-struct")]
use polars_core::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca;
use polars_core::prelude::*;
use polars_core::series::IsSorted;

Expand Down Expand Up @@ -42,6 +44,14 @@ pub trait SeriesMethods: SeriesSealed {
fn is_sorted(&self, options: SortOptions) -> PolarsResult<bool> {
let s = self.as_series();

// for struct types we row-encode and recurse
#[cfg(feature = "dtype-struct")]
if matches!(s.dtype(), DataType::Struct(_)) {
let encoded =
_get_rows_encoded_ca("", &[s.clone()], &[options.descending], options.nulls_last)?;
return encoded.into_series().is_sorted(options);
}

// fast paths
if (options.descending
&& options.nulls_last
Expand Down Expand Up @@ -70,27 +80,12 @@ pub trait SeriesMethods: SeriesSealed {
// Compare adjacent elements with no-copy slices that don't include any nulls
let offset = !options.nulls_last as i64 * nc as i64;
let (s1, s2) = (s.slice(offset, slen), s.slice(offset + 1, slen));
let cmp_op = match options.descending {
true => Series::gt_eq,
false => Series::lt_eq,
let cmp_op = if options.descending {
Series::gt_eq
} else {
Series::lt_eq
};
match s.dtype() {
// For structs compare per-field. We don't have to check any types or field names though
// since we're just comparing two offset slices of the same Series. The loop is to both
// short-circuit on false and propagate errors. Maybe there's a way with iterators?
#[cfg(feature = "dtype-struct")]
DataType::Struct(_) => {
let mut struct_cmp = true;
for (l, r) in s1.struct_()?.fields().iter().zip(s2.struct_()?.fields()) {
struct_cmp &= cmp_op(l, r)?.all();
if !struct_cmp {
break;
}
}
Ok(struct_cmp)
}
_ => Ok(cmp_op(&s1, &s2)?.all()),
}
Ok(cmp_op(&s1, &s2)?.all())
}
}

Expand Down

0 comments on commit 28dad4d

Please sign in to comment.