-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(rust, python): Keep min/max and arg_min/arg_max consistent. #10716
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -31,7 +31,7 @@ impl ArgAgg for Series { | |||||||||||||
dt if dt.is_numeric() => { | ||||||||||||||
with_match_physical_numeric_polars_type!(s.dtype(), |$T| { | ||||||||||||||
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); | ||||||||||||||
if ca.is_empty() { // because argminmax assumes not empty | ||||||||||||||
if ca.is_empty() || ca.null_count() == ca.len() { // because argminmax assumes not empty | ||||||||||||||
None | ||||||||||||||
} else if let Ok(vals) = ca.cont_slice() { | ||||||||||||||
arg_min_numeric_slice(vals, ca.is_sorted_flag()) | ||||||||||||||
|
@@ -59,7 +59,7 @@ impl ArgAgg for Series { | |||||||||||||
dt if dt.is_numeric() => { | ||||||||||||||
with_match_physical_numeric_polars_type!(s.dtype(), |$T| { | ||||||||||||||
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); | ||||||||||||||
if ca.is_empty() { // because argminmax assumes not empty | ||||||||||||||
if ca.is_empty() || ca.null_count() == ca.len(){ // because argminmax assumes not empty | ||||||||||||||
None | ||||||||||||||
} else if let Ok(vals) = ca.cont_slice() { | ||||||||||||||
arg_max_numeric_slice(vals, ca.is_sorted_flag()) | ||||||||||||||
|
@@ -74,33 +74,50 @@ impl ArgAgg for Series { | |||||||||||||
} | ||||||||||||||
|
||||||||||||||
pub(crate) fn arg_max_bool(ca: &BooleanChunked) -> Option<usize> { | ||||||||||||||
if ca.is_empty() { | ||||||||||||||
if ca.is_empty() || ca.null_count() == ca.len() { | ||||||||||||||
None | ||||||||||||||
} else if ca.null_count() == ca.len() { | ||||||||||||||
Some(0) | ||||||||||||||
} | ||||||||||||||
// don't check for any, that on itself is already an argmax search | ||||||||||||||
else if ca.null_count() == 0 && ca.chunks().len() == 1 { | ||||||||||||||
let arr = ca.downcast_iter().next().unwrap(); | ||||||||||||||
let mask = arr.values(); | ||||||||||||||
Some(first_set_bit(mask)) | ||||||||||||||
} else { | ||||||||||||||
let mut first_false_idx: Option<usize> = None; | ||||||||||||||
ca.into_iter() | ||||||||||||||
.position(|opt_val| matches!(opt_val, Some(true))) | ||||||||||||||
.enumerate() | ||||||||||||||
.find_map(|(idx, val)| match val { | ||||||||||||||
Some(true) => Some(idx), | ||||||||||||||
Some(false) if first_false_idx.is_none() => { | ||||||||||||||
first_false_idx = Some(idx); | ||||||||||||||
None | ||||||||||||||
}, | ||||||||||||||
_ => None, | ||||||||||||||
}) | ||||||||||||||
.or(first_false_idx) | ||||||||||||||
} | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
fn arg_min_bool(ca: &BooleanChunked) -> Option<usize> { | ||||||||||||||
if ca.is_empty() || ca.null_count() == ca.len() || ca.all() { | ||||||||||||||
Some(0) | ||||||||||||||
if ca.is_empty() || ca.null_count() == ca.len() { | ||||||||||||||
None | ||||||||||||||
} else if ca.null_count() == 0 && ca.chunks().len() == 1 { | ||||||||||||||
let arr = ca.downcast_iter().next().unwrap(); | ||||||||||||||
let mask = arr.values(); | ||||||||||||||
Some(first_unset_bit(mask)) | ||||||||||||||
} else { | ||||||||||||||
// also null as we see that as lower in ordering than a set value | ||||||||||||||
let mut first_true_idx: Option<usize> = None; | ||||||||||||||
ca.into_iter() | ||||||||||||||
.position(|opt_val| matches!(opt_val, Some(false) | None)) | ||||||||||||||
.enumerate() | ||||||||||||||
.find_map(|(idx, val)| match val { | ||||||||||||||
Some(false) => Some(idx), | ||||||||||||||
Some(true) if first_true_idx.is_none() => { | ||||||||||||||
first_true_idx = Some(idx); | ||||||||||||||
None | ||||||||||||||
}, | ||||||||||||||
_ => None, | ||||||||||||||
}) | ||||||||||||||
.or(first_true_idx) | ||||||||||||||
} | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
|
@@ -195,21 +212,28 @@ fn first_unset_bit(mask: &Bitmap) -> usize { | |||||||||||||
} | ||||||||||||||
|
||||||||||||||
fn arg_min_str(ca: &Utf8Chunked) -> Option<usize> { | ||||||||||||||
if ca.is_empty() || ca.null_count() == ca.len() { | ||||||||||||||
return None; | ||||||||||||||
} | ||||||||||||||
match ca.is_sorted_flag() { | ||||||||||||||
IsSorted::Ascending => Some(0), | ||||||||||||||
IsSorted::Descending => Some(ca.len() - 1), | ||||||||||||||
IsSorted::Ascending => ca.first_non_null(), | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps we can test There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think it would matter much in performance. I leave this up to you. |
||||||||||||||
IsSorted::Descending => ca.last_non_null(), | ||||||||||||||
IsSorted::Not => ca | ||||||||||||||
.into_iter() | ||||||||||||||
.enumerate() | ||||||||||||||
.flat_map(|(idx, val)| val.map(|val| (idx, val))) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Match in reduce will be faster as we have less indirection. I think we should first loop over There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Make sense, will rewrite this.
Does this means rewriting it to the same pattern of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They are similar. We could even share a same generic if you are up for that. ^^ But that is maybe a nice follow up PR. :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am trying to do this refactoring: Let them both share a same generic function as they are almost similar. The only difference is that for numerical types, we have a fast path with the following bound:
But for polars/crates/polars-ops/src/series/ops/arg_min_max.rs Lines 264 to 268 in ecb819a
I can think of some hack solution (such as a magical macro), but I don't really like that way. Do we have a good solution for this that looks simpler and cleaner 🤔. |
||||||||||||||
.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc }) | ||||||||||||||
.map(|tpl| tpl.0), | ||||||||||||||
} | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
fn arg_max_str(ca: &Utf8Chunked) -> Option<usize> { | ||||||||||||||
if ca.is_empty() || ca.null_count() == ca.len() { | ||||||||||||||
return None; | ||||||||||||||
} | ||||||||||||||
match ca.is_sorted_flag() { | ||||||||||||||
IsSorted::Ascending => Some(ca.len() - 1), | ||||||||||||||
IsSorted::Descending => Some(0), | ||||||||||||||
IsSorted::Ascending => ca.last_non_null(), | ||||||||||||||
IsSorted::Descending => ca.first_non_null(), | ||||||||||||||
IsSorted::Not => ca | ||||||||||||||
.into_iter() | ||||||||||||||
.enumerate() | ||||||||||||||
|
@@ -224,45 +248,35 @@ where | |||||||||||||
for<'b> &'b [T::Native]: ArgMinMax, | ||||||||||||||
{ | ||||||||||||||
match ca.is_sorted_flag() { | ||||||||||||||
IsSorted::Ascending => Some(0), | ||||||||||||||
IsSorted::Descending => Some(ca.len() - 1), | ||||||||||||||
IsSorted::Ascending => ca.first_non_null(), | ||||||||||||||
IsSorted::Descending => ca.last_non_null(), | ||||||||||||||
IsSorted::Not => { | ||||||||||||||
ca.downcast_iter() | ||||||||||||||
.fold((None, None, 0), |acc, arr| { | ||||||||||||||
if arr.len() == 0 { | ||||||||||||||
return acc; | ||||||||||||||
} | ||||||||||||||
let chunk_min_idx: Option<usize>; | ||||||||||||||
let chunk_min_val: Option<T::Native>; | ||||||||||||||
if arr.null_count() > 0 { | ||||||||||||||
// When there are nulls, we should compare Option<T::Native> | ||||||||||||||
chunk_min_val = None; // because None < Some(_) | ||||||||||||||
chunk_min_idx = arr | ||||||||||||||
.into_iter() | ||||||||||||||
let chunk_min: Option<(usize, T::Native)> = if arr.null_count() > 0 { | ||||||||||||||
arr.into_iter() | ||||||||||||||
.enumerate() | ||||||||||||||
.flat_map(|(idx, val)| val.map(|val| (idx, *val))) | ||||||||||||||
reswqa marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc }) | ||||||||||||||
.map(|tpl| tpl.0); | ||||||||||||||
} else { | ||||||||||||||
// When no nulls & array not empty => we can use fast argminmax | ||||||||||||||
let min_idx: usize = arr.values().as_slice().argmin(); | ||||||||||||||
chunk_min_idx = Some(min_idx); | ||||||||||||||
chunk_min_val = Some(arr.value(min_idx)); | ||||||||||||||
} | ||||||||||||||
Some((min_idx, arr.value(min_idx))) | ||||||||||||||
}; | ||||||||||||||
|
||||||||||||||
let new_offset: usize = acc.2 + arr.len(); | ||||||||||||||
match acc { | ||||||||||||||
(Some(_), Some(_), offset) => { | ||||||||||||||
if chunk_min_val < acc.1 { | ||||||||||||||
match chunk_min_idx { | ||||||||||||||
Some(idx) => (Some(idx + offset), chunk_min_val, new_offset), | ||||||||||||||
None => (acc.0, acc.1, new_offset), | ||||||||||||||
} | ||||||||||||||
} else { | ||||||||||||||
(acc.0, acc.1, new_offset) | ||||||||||||||
} | ||||||||||||||
(Some(_), Some(acc_v), offset) => match chunk_min { | ||||||||||||||
Some((idx, val)) if val < acc_v => { | ||||||||||||||
(Some(idx + offset), Some(val), new_offset) | ||||||||||||||
}, | ||||||||||||||
_ => (acc.0, acc.1, new_offset), | ||||||||||||||
}, | ||||||||||||||
(None, None, offset) => match chunk_min_idx { | ||||||||||||||
Some(idx) => (Some(idx + offset), chunk_min_val, new_offset), | ||||||||||||||
(None, None, offset) => match chunk_min { | ||||||||||||||
Some((idx, val)) => (Some(idx + offset), Some(val), new_offset), | ||||||||||||||
None => (None, None, new_offset), | ||||||||||||||
}, | ||||||||||||||
_ => unreachable!(), | ||||||||||||||
|
@@ -279,45 +293,36 @@ where | |||||||||||||
for<'b> &'b [T::Native]: ArgMinMax, | ||||||||||||||
{ | ||||||||||||||
match ca.is_sorted_flag() { | ||||||||||||||
IsSorted::Ascending => Some(ca.len() - 1), | ||||||||||||||
IsSorted::Descending => Some(0), | ||||||||||||||
IsSorted::Ascending => ca.last_non_null(), | ||||||||||||||
IsSorted::Descending => ca.first_non_null(), | ||||||||||||||
IsSorted::Not => { | ||||||||||||||
ca.downcast_iter() | ||||||||||||||
.fold((None, None, 0), |acc, arr| { | ||||||||||||||
if arr.len() == 0 { | ||||||||||||||
return acc; | ||||||||||||||
} | ||||||||||||||
let chunk_max_idx: Option<usize>; | ||||||||||||||
let chunk_max_val: Option<T::Native>; | ||||||||||||||
if arr.null_count() > 0 { | ||||||||||||||
let chunk_max: Option<(usize, T::Native)> = if arr.null_count() > 0 { | ||||||||||||||
// When there are nulls, we should compare Option<T::Native> | ||||||||||||||
chunk_max_idx = arr | ||||||||||||||
.into_iter() | ||||||||||||||
arr.into_iter() | ||||||||||||||
.enumerate() | ||||||||||||||
.flat_map(|(idx, val)| val.map(|val| (idx, *val))) | ||||||||||||||
.reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc }) | ||||||||||||||
.map(|tpl| tpl.0); | ||||||||||||||
chunk_max_val = chunk_max_idx.map(|idx| arr.value(idx)); | ||||||||||||||
} else { | ||||||||||||||
// When no nulls & array not empty => we can use fast argminmax | ||||||||||||||
let max_idx: usize = arr.values().as_slice().argmax(); | ||||||||||||||
chunk_max_idx = Some(max_idx); | ||||||||||||||
chunk_max_val = Some(arr.value(max_idx)); | ||||||||||||||
} | ||||||||||||||
Some((max_idx, arr.value(max_idx))) | ||||||||||||||
}; | ||||||||||||||
|
||||||||||||||
let new_offset: usize = acc.2 + arr.len(); | ||||||||||||||
match acc { | ||||||||||||||
(Some(_), Some(_), offset) => { | ||||||||||||||
if chunk_max_val > acc.1 { | ||||||||||||||
match chunk_max_idx { | ||||||||||||||
Some(idx) => (Some(idx + offset), chunk_max_val, new_offset), | ||||||||||||||
_ => unreachable!(), // because None < Some(_) | ||||||||||||||
} | ||||||||||||||
} else { | ||||||||||||||
(acc.0, acc.1, new_offset) | ||||||||||||||
} | ||||||||||||||
(Some(_), Some(acc_v), offset) => match chunk_max { | ||||||||||||||
Some((idx, val)) if acc_v < val => { | ||||||||||||||
(Some(idx + offset), Some(val), new_offset) | ||||||||||||||
}, | ||||||||||||||
_ => (acc.0, acc.1, new_offset), | ||||||||||||||
}, | ||||||||||||||
(None, None, offset) => match chunk_max_idx { | ||||||||||||||
Some(idx) => (Some(idx + offset), chunk_max_val, new_offset), | ||||||||||||||
(None, None, offset) => match chunk_max { | ||||||||||||||
Some((idx, val)) => (Some(idx + offset), Some(val), new_offset), | ||||||||||||||
None => (None, None, new_offset), | ||||||||||||||
}, | ||||||||||||||
_ => unreachable!(), | ||||||||||||||
|
@@ -333,7 +338,9 @@ where | |||||||||||||
for<'a> &'a [T]: ArgMinMax, | ||||||||||||||
{ | ||||||||||||||
match is_sorted { | ||||||||||||||
// all vals are not null guarded by cont_slice | ||||||||||||||
IsSorted::Ascending => Some(0), | ||||||||||||||
// all vals are not null guarded by cont_slice | ||||||||||||||
IsSorted::Descending => Some(vals.len() - 1), | ||||||||||||||
IsSorted::Not => Some(vals.argmin()), // assumes not empty | ||||||||||||||
} | ||||||||||||||
|
@@ -344,7 +351,9 @@ where | |||||||||||||
for<'a> &'a [T]: ArgMinMax, | ||||||||||||||
{ | ||||||||||||||
match is_sorted { | ||||||||||||||
// all vals are not null guarded by cont_slice | ||||||||||||||
IsSorted::Ascending => Some(vals.len() - 1), | ||||||||||||||
// all vals are not null guarded by cont_slice | ||||||||||||||
IsSorted::Descending => Some(0), | ||||||||||||||
IsSorted::Not => Some(vals.argmax()), // assumes not empty | ||||||||||||||
} | ||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic is: we first find
true
, and if it does not exist, we return the firstfalse
location. Is this exactly what we want?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@reswqa Yes, that's the correct behavior, although I think we can do a faster implementation more explicitly iterating over the bitmap. But that doesn't have to be in this PR.