Skip to content

Commit

Permalink
fix(rust, python): Keep min/max and arg_min/arg_max consistent.
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Aug 24, 2023
1 parent 0d6a4bc commit 21a544b
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 73 deletions.
138 changes: 78 additions & 60 deletions crates/polars-ops/src/series/ops/arg_min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl ArgAgg for Series {
dt if dt.is_numeric() => {
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
if ca.is_empty() { // because argminmax assumes not empty
if ca.is_empty() || ca.null_count() == ca.len() { // because argminmax assumes not empty
None
} else if let Ok(vals) = ca.cont_slice() {
arg_min_numeric_slice(vals, ca.is_sorted_flag())
Expand Down Expand Up @@ -59,7 +59,7 @@ impl ArgAgg for Series {
dt if dt.is_numeric() => {
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
if ca.is_empty() { // because argminmax assumes not empty
if ca.is_empty() || ca.null_count() == ca.len(){ // because argminmax assumes not empty
None
} else if let Ok(vals) = ca.cont_slice() {
arg_max_numeric_slice(vals, ca.is_sorted_flag())
Expand All @@ -74,35 +74,50 @@ impl ArgAgg for Series {
}

pub(crate) fn arg_max_bool(ca: &BooleanChunked) -> Option<usize> {
if ca.is_empty() {
if ca.is_empty() || ca.null_count() == ca.len() {
None
} else if ca.null_count() == ca.len() {
Some(0)
}
// don't check for any; that in itself is already an argmax search
else if ca.null_count() == 0 && ca.chunks().len() == 1 {
let arr = ca.downcast_iter().next().unwrap();
let mask = arr.values();
Some(first_set_bit(mask))
} else {
let mut first_false_idx: Option<usize> = None;
ca.into_iter()
.position(|opt_val| matches!(opt_val, Some(true)))
.enumerate()
.find_map(|(idx, val)| match val {
Some(true) => Some(idx),
Some(false) if first_false_idx.is_none() => {
first_false_idx = Some(idx);
None
},
_ => None,
})
.or(first_false_idx)
}
}

fn arg_min_bool(ca: &BooleanChunked) -> Option<usize> {
if ca.is_empty() {
if ca.is_empty() || ca.null_count() == ca.len() {
None
} else if ca.null_count() == ca.len() || ca.all() {
Some(0)
} else if ca.null_count() == 0 && ca.chunks().len() == 1 {
let arr = ca.downcast_iter().next().unwrap();
let mask = arr.values();
Some(first_unset_bit(mask))
} else {
// also null as we see that as lower in ordering than a set value
let mut first_true_idx: Option<usize> = None;
ca.into_iter()
.position(|opt_val| matches!(opt_val, Some(false) | None))
.enumerate()
.find_map(|(idx, val)| match val {
Some(false) => Some(idx),
Some(true) if first_true_idx.is_none() => {
first_true_idx = Some(idx);
None
},
_ => None,
})
.or(first_true_idx)
}
}

Expand Down Expand Up @@ -197,30 +212,38 @@ fn first_unset_bit(mask: &Bitmap) -> usize {
}

fn arg_min_str(ca: &Utf8Chunked) -> Option<usize> {
if ca.is_empty() {
if ca.is_empty() || ca.null_count() == ca.len() {
return None;
}
match ca.is_sorted_flag() {
IsSorted::Ascending => Some(0),
IsSorted::Descending => Some(ca.len() - 1),
IsSorted::Ascending => ca.first_non_null(),
IsSorted::Descending => ca.last_non_null(),
IsSorted::Not => ca
.into_iter()
.enumerate()
.filter_map(|(idx, val)| match val {
Some(val) => Some((idx, val)),
_ => None,
})
.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
.map(|tpl| tpl.0),
}
}

fn arg_max_str(ca: &Utf8Chunked) -> Option<usize> {
if ca.is_empty() {
if ca.is_empty() || ca.null_count() == ca.len() {
return None;
}
match ca.is_sorted_flag() {
IsSorted::Ascending => Some(ca.len() - 1),
IsSorted::Descending => Some(0),
IsSorted::Ascending => ca.last_non_null(),
IsSorted::Descending => ca.first_non_null(),
IsSorted::Not => ca
.into_iter()
.enumerate()
.filter_map(|(idx, val)| match val {
Some(val) => Some((idx, val)),
_ => None,
})
.reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
.map(|tpl| tpl.0),
}
Expand All @@ -232,45 +255,40 @@ where
for<'b> &'b [T::Native]: ArgMinMax,
{
match ca.is_sorted_flag() {
IsSorted::Ascending => Some(0),
IsSorted::Descending => Some(ca.len() - 1),
IsSorted::Ascending => ca.first_non_null(),
IsSorted::Descending => ca.last_non_null(),
IsSorted::Not => {
ca.downcast_iter()
.fold((None, None, 0), |acc, arr| {
if arr.len() == 0 {
return acc;
}
let chunk_min_idx: Option<usize>;
let chunk_min_val: Option<T::Native>;
let chunk_min: Option<(usize, T::Native)>;
if arr.null_count() > 0 {
// When there are nulls, we should compare Option<T::Native>
chunk_min_val = None; // because None < Some(_)
chunk_min_idx = arr
chunk_min = arr
.into_iter()
.enumerate()
.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
.map(|tpl| tpl.0);
.filter_map(|(idx, val)| match val {
Some(val) => Some((idx, *val)),
_ => None,
})
.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc });
} else {
// When no nulls & array not empty => we can use fast argminmax
let min_idx: usize = arr.values().as_slice().argmin();
chunk_min_idx = Some(min_idx);
chunk_min_val = Some(arr.value(min_idx));
chunk_min = Some((min_idx, arr.value(min_idx)));
}

let new_offset: usize = acc.2 + arr.len();
match acc {
(Some(_), Some(_), offset) => {
if chunk_min_val < acc.1 {
match chunk_min_idx {
Some(idx) => (Some(idx + offset), chunk_min_val, new_offset),
None => (acc.0, acc.1, new_offset),
}
} else {
(acc.0, acc.1, new_offset)
}
(Some(_), Some(acc_v), offset) => match chunk_min {
Some((idx, val)) if val < acc_v => {
(Some(idx + offset), Some(val), new_offset)
},
_ => (acc.0, acc.1, new_offset),
},
(None, None, offset) => match chunk_min_idx {
Some(idx) => (Some(idx + offset), chunk_min_val, new_offset),
(None, None, offset) => match chunk_min {
Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
None => (None, None, new_offset),
},
_ => unreachable!(),
Expand All @@ -287,45 +305,41 @@ where
for<'b> &'b [T::Native]: ArgMinMax,
{
match ca.is_sorted_flag() {
IsSorted::Ascending => Some(ca.len() - 1),
IsSorted::Descending => Some(0),
IsSorted::Ascending => ca.last_non_null(),
IsSorted::Descending => ca.first_non_null(),
IsSorted::Not => {
ca.downcast_iter()
.fold((None, None, 0), |acc, arr| {
if arr.len() == 0 {
return acc;
}
let chunk_max_idx: Option<usize>;
let chunk_max_val: Option<T::Native>;
let chunk_max: Option<(usize, T::Native)>;
if arr.null_count() > 0 {
// When there are nulls, we should compare Option<T::Native>
chunk_max_idx = arr
chunk_max = arr
.into_iter()
.enumerate()
.reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
.map(|tpl| tpl.0);
chunk_max_val = chunk_max_idx.map(|idx| arr.value(idx));
.filter_map(|(idx, val)| match val {
Some(val) => Some((idx, *val)),
_ => None,
})
.reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc });
} else {
// When no nulls & array not empty => we can use fast argminmax
let max_idx: usize = arr.values().as_slice().argmax();
chunk_max_idx = Some(max_idx);
chunk_max_val = Some(arr.value(max_idx));
chunk_max = Some((max_idx, arr.value(max_idx)));
}

let new_offset: usize = acc.2 + arr.len();
match acc {
(Some(_), Some(_), offset) => {
if chunk_max_val > acc.1 {
match chunk_max_idx {
Some(idx) => (Some(idx + offset), chunk_max_val, new_offset),
_ => unreachable!(), // because None < Some(_)
}
} else {
(acc.0, acc.1, new_offset)
}
(Some(_), Some(acc_v), offset) => match chunk_max {
Some((idx, val)) if acc_v < val => {
(Some(idx + offset), Some(val), new_offset)
},
_ => (acc.0, acc.1, new_offset),
},
(None, None, offset) => match chunk_max_idx {
Some(idx) => (Some(idx + offset), chunk_max_val, new_offset),
(None, None, offset) => match chunk_max {
Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
None => (None, None, new_offset),
},
_ => unreachable!(),
Expand All @@ -341,7 +355,9 @@ where
for<'a> &'a [T]: ArgMinMax,
{
match is_sorted {
// all vals are not null guarded by cont_slice
IsSorted::Ascending => Some(0),
// all vals are not null guarded by cont_slice
IsSorted::Descending => Some(vals.len() - 1),
IsSorted::Not => Some(vals.argmin()), // assumes not empty
}
Expand All @@ -352,7 +368,9 @@ where
for<'a> &'a [T]: ArgMinMax,
{
match is_sorted {
// all vals are not null guarded by cont_slice
IsSorted::Ascending => Some(vals.len() - 1),
// all vals are not null guarded by cont_slice
IsSorted::Descending => Some(0),
IsSorted::Not => Some(vals.argmax()), // assumes not empty
}
Expand Down
97 changes: 84 additions & 13 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1738,40 +1738,111 @@ def test_arg_sort() -> None:


def test_arg_min_and_arg_max() -> None:
s = pl.Series("a", [5, 3, 4, 1, 2])
# numerical no null.
s = pl.Series([5, 3, 4, 1, 2])
assert s.arg_min() == 3
assert s.arg_max() == 0

s = pl.Series([None, True, False, True])
assert s.arg_min() == 0
# numerical has null.
s = pl.Series([None, 5, 1])
assert s.arg_min() == 2
assert s.arg_max() == 1
s = pl.Series([None, None], dtype=pl.Boolean)
assert s.arg_min() == 0

# numerical all null.
s = pl.Series([None, None], dtype=Int32)
assert s.arg_min() is None
assert s.arg_max() is None

# boolean no null.
s = pl.Series([True, False])
assert s.arg_min() == 1
assert s.arg_max() == 0
s = pl.Series([True, True])
assert s.arg_min() == 0
assert s.arg_max() == 0
s = pl.Series([False, False])
assert s.arg_min() == 0
assert s.arg_max() == 0

# boolean has null.
s = pl.Series([None, True, False, True])
assert s.arg_min() == 2
assert s.arg_max() == 1
s = pl.Series([None, True, True])
assert s.arg_min() == 1
assert s.arg_max() == 1
s = pl.Series([None, False, False])
assert s.arg_min() == 1
assert s.arg_max() == 1

# boolean all null.
s = pl.Series([None, None], dtype=pl.Boolean)
assert s.arg_min() is None
assert s.arg_max() is None

# utf8 no null
s = pl.Series(["a", "c", "b"])
assert s.arg_min() == 0
assert s.arg_max() == 1

# utf8 has null
s = pl.Series([None, "a", None, "b"])
assert s.arg_min() == 1
assert s.arg_max() == 3

# utf8 all null
s = pl.Series([None, None], dtype=pl.Utf8)
assert s.arg_min() is None
assert s.arg_max() is None

# test ascending and descending series
s = pl.Series("a", [1, 2, 3, 4, 5])
s = pl.Series([None, 1, 2, 3, 4, 5])
s.sort(in_place=True) # set ascending sorted flag
assert s.flags == {"SORTED_ASC": True, "SORTED_DESC": False}
assert s.arg_min() == 0
assert s.arg_max() == 4
s = pl.Series("a", [5, 4, 3, 2, 1])
assert s.arg_min() == 1
assert s.arg_max() == 5
s = pl.Series([None, 5, 4, 3, 2, 1])
s.sort(descending=True, in_place=True) # set descending sorted flag
assert s.flags == {"SORTED_ASC": False, "SORTED_DESC": True}
assert s.arg_min() == 4
assert s.arg_max() == 0
assert s.arg_min() == 5
assert s.arg_max() == 1

# test empty series
s = pl.Series("a", [])
# test ascending and descending numerical series
s = pl.Series([None, 1, 2, 3, 4, 5])
s.sort(in_place=True) # set ascending sorted flag
assert s.flags == {"SORTED_ASC": True, "SORTED_DESC": False}
assert s.arg_min() == 1
assert s.arg_max() == 5
s = pl.Series([None, 5, 4, 3, 2, 1])
s.sort(descending=True, in_place=True) # set descending sorted flag
assert s.flags == {"SORTED_ASC": False, "SORTED_DESC": True}
assert s.arg_min() == 5
assert s.arg_max() == 1

# test ascending and descending utf8 series
s = pl.Series([None, "a", "b", "c", "d", "e"])
s.sort(in_place=True) # set ascending sorted flag
assert s.flags == {"SORTED_ASC": True, "SORTED_DESC": False}
assert s.arg_min() == 1
assert s.arg_max() == 5
s = pl.Series([None, "e", "d", "c", "b", "a"])
s.sort(descending=True, in_place=True) # set descending sorted flag
assert s.flags == {"SORTED_ASC": False, "SORTED_DESC": True}
assert s.arg_min() == 5
assert s.arg_max() == 1

# test numerical empty series
s = pl.Series([], dtype=pl.Int32)
assert s.arg_min() is None
assert s.arg_max() is None

# test boolean empty series
s = pl.Series([], dtype=pl.Boolean)
assert s.arg_min() is None
assert s.arg_max() is None

# test utf8 empty series
s = pl.Series([], dtype=pl.Utf8)
assert s.arg_min() is None
assert s.arg_max() is None

Expand Down

0 comments on commit 21a544b

Please sign in to comment.