Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust, python): Keep min/max and arg_min/arg_max consistent. #10716

Merged
merged 2 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 70 additions & 61 deletions crates/polars-ops/src/series/ops/arg_min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl ArgAgg for Series {
dt if dt.is_numeric() => {
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
if ca.is_empty() { // because argminmax assumes not empty
if ca.is_empty() || ca.null_count() == ca.len() { // because argminmax assumes not empty
None
} else if let Ok(vals) = ca.cont_slice() {
arg_min_numeric_slice(vals, ca.is_sorted_flag())
Expand Down Expand Up @@ -59,7 +59,7 @@ impl ArgAgg for Series {
dt if dt.is_numeric() => {
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
if ca.is_empty() { // because argminmax assumes not empty
if ca.is_empty() || ca.null_count() == ca.len(){ // because argminmax assumes not empty
None
} else if let Ok(vals) = ca.cont_slice() {
arg_max_numeric_slice(vals, ca.is_sorted_flag())
Expand All @@ -74,33 +74,50 @@ impl ArgAgg for Series {
}

pub(crate) fn arg_max_bool(ca: &BooleanChunked) -> Option<usize> {
if ca.is_empty() {
if ca.is_empty() || ca.null_count() == ca.len() {
None
} else if ca.null_count() == ca.len() {
Some(0)
}
// don't check for any, that on itself is already an argmax search
else if ca.null_count() == 0 && ca.chunks().len() == 1 {
let arr = ca.downcast_iter().next().unwrap();
let mask = arr.values();
Some(first_set_bit(mask))
} else {
let mut first_false_idx: Option<usize> = None;
ca.into_iter()
.position(|opt_val| matches!(opt_val, Some(true)))
.enumerate()
.find_map(|(idx, val)| match val {
Copy link
Collaborator Author

@reswqa reswqa Aug 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic is: we first find true, and if it does not exist, we return the first false location. Is this exactly what we want?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@reswqa Yes, that's the correct behavior, although I think we can write a faster implementation that iterates over the bitmap more explicitly. But that doesn't have to be in this PR.

Some(true) => Some(idx),
Some(false) if first_false_idx.is_none() => {
first_false_idx = Some(idx);
None
},
_ => None,
})
.or(first_false_idx)
}
}

fn arg_min_bool(ca: &BooleanChunked) -> Option<usize> {
if ca.is_empty() || ca.null_count() == ca.len() || ca.all() {
Some(0)
if ca.is_empty() || ca.null_count() == ca.len() {
None
} else if ca.null_count() == 0 && ca.chunks().len() == 1 {
let arr = ca.downcast_iter().next().unwrap();
let mask = arr.values();
Some(first_unset_bit(mask))
} else {
// also null as we see that as lower in ordering than a set value
let mut first_true_idx: Option<usize> = None;
ca.into_iter()
.position(|opt_val| matches!(opt_val, Some(false) | None))
.enumerate()
.find_map(|(idx, val)| match val {
Some(false) => Some(idx),
Some(true) if first_true_idx.is_none() => {
first_true_idx = Some(idx);
None
},
_ => None,
})
.or(first_true_idx)
}
}

Expand Down Expand Up @@ -195,21 +212,28 @@ fn first_unset_bit(mask: &Bitmap) -> usize {
}

fn arg_min_str(ca: &Utf8Chunked) -> Option<usize> {
if ca.is_empty() || ca.null_count() == ca.len() {
return None;
}
match ca.is_sorted_flag() {
IsSorted::Ascending => Some(0),
IsSorted::Descending => Some(ca.len() - 1),
IsSorted::Ascending => ca.first_non_null(),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we can check ca.null_count() == 0 first: if there are nulls, we take this path; otherwise, we keep following the previous logic?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it would matter much for performance. I'll leave this up to you.

IsSorted::Descending => ca.last_non_null(),
IsSorted::Not => ca
.into_iter()
.enumerate()
.flat_map(|(idx, val)| val.map(|val| (idx, val)))
Copy link
Collaborator Author

@reswqa reswqa Aug 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None will always be considered the minimum value, as None < Some(_). But I'm not quite sure whether it's better to use flat_map here or to match directly in reduce.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Match in reduce will be faster as we have less indirection. I think we should first loop over downcast_iter and then loop over the array.

Copy link
Collaborator Author

@reswqa reswqa Aug 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Match in reduce will be faster as we have less indirection.

Makes sense, will rewrite this.

I think we should first loop over downcast_iter and then loop over the array.

Does this mean rewriting it to follow the same pattern as arg_max_numeric? 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are similar. They could even share the same generic function, if you are up for that. ^^

But that is maybe a nice follow up PR. :)

Copy link
Collaborator Author

@reswqa reswqa Aug 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am trying to do this refactoring: let them both share the same generic function, as they are almost identical. The only difference is that for numerical types, we have a fast path with the following bound:

for<'b> &'b [T::Native]: ArgMinMax,

But for Utf8ChunkedArray we do not have this bound, and we also don't have this branch:

} else {
// When no nulls & array not empty => we can use fast argminmax
let min_idx: usize = arr.values().as_slice().argmin();
Some((min_idx, arr.value(min_idx)))
};

I can think of some hacky solutions (such as a magical macro), but I don't really like that approach. Do we have a good solution for this that is simpler and cleaner? 🤔

.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
.map(|tpl| tpl.0),
}
}

fn arg_max_str(ca: &Utf8Chunked) -> Option<usize> {
if ca.is_empty() || ca.null_count() == ca.len() {
return None;
}
match ca.is_sorted_flag() {
IsSorted::Ascending => Some(ca.len() - 1),
IsSorted::Descending => Some(0),
IsSorted::Ascending => ca.last_non_null(),
IsSorted::Descending => ca.first_non_null(),
IsSorted::Not => ca
.into_iter()
.enumerate()
Expand All @@ -224,45 +248,35 @@ where
for<'b> &'b [T::Native]: ArgMinMax,
{
match ca.is_sorted_flag() {
IsSorted::Ascending => Some(0),
IsSorted::Descending => Some(ca.len() - 1),
IsSorted::Ascending => ca.first_non_null(),
IsSorted::Descending => ca.last_non_null(),
IsSorted::Not => {
ca.downcast_iter()
.fold((None, None, 0), |acc, arr| {
if arr.len() == 0 {
return acc;
}
let chunk_min_idx: Option<usize>;
let chunk_min_val: Option<T::Native>;
if arr.null_count() > 0 {
// When there are nulls, we should compare Option<T::Native>
chunk_min_val = None; // because None < Some(_)
chunk_min_idx = arr
.into_iter()
let chunk_min: Option<(usize, T::Native)> = if arr.null_count() > 0 {
arr.into_iter()
.enumerate()
.flat_map(|(idx, val)| val.map(|val| (idx, *val)))
reswqa marked this conversation as resolved.
Show resolved Hide resolved
.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
.map(|tpl| tpl.0);
} else {
// When no nulls & array not empty => we can use fast argminmax
let min_idx: usize = arr.values().as_slice().argmin();
chunk_min_idx = Some(min_idx);
chunk_min_val = Some(arr.value(min_idx));
}
Some((min_idx, arr.value(min_idx)))
};

let new_offset: usize = acc.2 + arr.len();
match acc {
(Some(_), Some(_), offset) => {
if chunk_min_val < acc.1 {
match chunk_min_idx {
Some(idx) => (Some(idx + offset), chunk_min_val, new_offset),
None => (acc.0, acc.1, new_offset),
}
} else {
(acc.0, acc.1, new_offset)
}
(Some(_), Some(acc_v), offset) => match chunk_min {
Some((idx, val)) if val < acc_v => {
(Some(idx + offset), Some(val), new_offset)
},
_ => (acc.0, acc.1, new_offset),
},
(None, None, offset) => match chunk_min_idx {
Some(idx) => (Some(idx + offset), chunk_min_val, new_offset),
(None, None, offset) => match chunk_min {
Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
None => (None, None, new_offset),
},
_ => unreachable!(),
Expand All @@ -279,45 +293,36 @@ where
for<'b> &'b [T::Native]: ArgMinMax,
{
match ca.is_sorted_flag() {
IsSorted::Ascending => Some(ca.len() - 1),
IsSorted::Descending => Some(0),
IsSorted::Ascending => ca.last_non_null(),
IsSorted::Descending => ca.first_non_null(),
IsSorted::Not => {
ca.downcast_iter()
.fold((None, None, 0), |acc, arr| {
if arr.len() == 0 {
return acc;
}
let chunk_max_idx: Option<usize>;
let chunk_max_val: Option<T::Native>;
if arr.null_count() > 0 {
let chunk_max: Option<(usize, T::Native)> = if arr.null_count() > 0 {
// When there are nulls, we should compare Option<T::Native>
chunk_max_idx = arr
.into_iter()
arr.into_iter()
.enumerate()
.flat_map(|(idx, val)| val.map(|val| (idx, *val)))
.reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
.map(|tpl| tpl.0);
chunk_max_val = chunk_max_idx.map(|idx| arr.value(idx));
} else {
// When no nulls & array not empty => we can use fast argminmax
let max_idx: usize = arr.values().as_slice().argmax();
chunk_max_idx = Some(max_idx);
chunk_max_val = Some(arr.value(max_idx));
}
Some((max_idx, arr.value(max_idx)))
};

let new_offset: usize = acc.2 + arr.len();
match acc {
(Some(_), Some(_), offset) => {
if chunk_max_val > acc.1 {
match chunk_max_idx {
Some(idx) => (Some(idx + offset), chunk_max_val, new_offset),
_ => unreachable!(), // because None < Some(_)
}
} else {
(acc.0, acc.1, new_offset)
}
(Some(_), Some(acc_v), offset) => match chunk_max {
Some((idx, val)) if acc_v < val => {
(Some(idx + offset), Some(val), new_offset)
},
_ => (acc.0, acc.1, new_offset),
},
(None, None, offset) => match chunk_max_idx {
Some(idx) => (Some(idx + offset), chunk_max_val, new_offset),
(None, None, offset) => match chunk_max {
Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
None => (None, None, new_offset),
},
_ => unreachable!(),
Expand All @@ -333,7 +338,9 @@ where
for<'a> &'a [T]: ArgMinMax,
{
match is_sorted {
// all vals are not null guarded by cont_slice
IsSorted::Ascending => Some(0),
// all vals are not null guarded by cont_slice
IsSorted::Descending => Some(vals.len() - 1),
IsSorted::Not => Some(vals.argmin()), // assumes not empty
}
Expand All @@ -344,7 +351,9 @@ where
for<'a> &'a [T]: ArgMinMax,
{
match is_sorted {
// all vals are not null guarded by cont_slice
IsSorted::Ascending => Some(vals.len() - 1),
// all vals are not null guarded by cont_slice
IsSorted::Descending => Some(0),
IsSorted::Not => Some(vals.argmax()), // assumes not empty
}
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/unit/namespaces/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,3 +544,16 @@ def test_list_take_oob_10079() -> None:
)
with pytest.raises(pl.ComputeError, match="take indices are out of bounds"):
df.select(pl.col("a").take(999))


def test_utf8_empty_series_arg_min_max_10703() -> None:
    """Check ``list.arg_min`` / ``list.arg_max`` on a list column of strings.

    The column holds one single-element sub-list and one empty sub-list.
    The expected indices are 0 for ``["a"]`` and null for ``[]``.
    (The numeric suffix in the test name presumably refers to the
    originating issue.)
    """
    source = pl.select(pl.lit(pl.Series("list", [["a"], []])))

    # Build the two index expressions up front so the call below stays short.
    arg_min_expr = pl.all().list.arg_min().alias("arg_min")
    arg_max_expr = pl.all().list.arg_max().alias("arg_max")
    res = source.with_columns(pl.all(), arg_min_expr, arg_max_expr)

    expected = {
        "list": [["a"], []],
        "arg_min": [0, None],
        "arg_max": [0, None],
    }
    assert res.to_dict(False) == expected
97 changes: 84 additions & 13 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1738,40 +1738,111 @@ def test_arg_sort() -> None:


def test_arg_min_and_arg_max() -> None:
s = pl.Series("a", [5, 3, 4, 1, 2])
# numerical no null.
s = pl.Series([5, 3, 4, 1, 2])
assert s.arg_min() == 3
assert s.arg_max() == 0

s = pl.Series([None, True, False, True])
assert s.arg_min() == 0
# numerical has null.
s = pl.Series([None, 5, 1])
assert s.arg_min() == 2
assert s.arg_max() == 1
s = pl.Series([None, None], dtype=pl.Boolean)
assert s.arg_min() == 0

# numerical all null.
s = pl.Series([None, None], dtype=Int32)
assert s.arg_min() is None
assert s.arg_max() is None

# boolean no null.
s = pl.Series([True, False])
assert s.arg_min() == 1
assert s.arg_max() == 0
s = pl.Series([True, True])
assert s.arg_min() == 0
assert s.arg_max() == 0
s = pl.Series([False, False])
assert s.arg_min() == 0
assert s.arg_max() == 0

# boolean has null.
s = pl.Series([None, True, False, True])
assert s.arg_min() == 2
assert s.arg_max() == 1
s = pl.Series([None, True, True])
assert s.arg_min() == 1
assert s.arg_max() == 1
s = pl.Series([None, False, False])
assert s.arg_min() == 1
assert s.arg_max() == 1

# boolean all null.
s = pl.Series([None, None], dtype=pl.Boolean)
assert s.arg_min() is None
assert s.arg_max() is None

# utf8 no null
s = pl.Series(["a", "c", "b"])
assert s.arg_min() == 0
assert s.arg_max() == 1

# utf8 has null
s = pl.Series([None, "a", None, "b"])
assert s.arg_min() == 1
assert s.arg_max() == 3

# utf8 all null
s = pl.Series([None, None], dtype=pl.Utf8)
assert s.arg_min() is None
assert s.arg_max() is None

# test ascending and descending series
s = pl.Series("a", [1, 2, 3, 4, 5])
s = pl.Series([None, 1, 2, 3, 4, 5])
s.sort(in_place=True) # set ascending sorted flag
assert s.flags == {"SORTED_ASC": True, "SORTED_DESC": False}
assert s.arg_min() == 0
assert s.arg_max() == 4
s = pl.Series("a", [5, 4, 3, 2, 1])
assert s.arg_min() == 1
assert s.arg_max() == 5
s = pl.Series([None, 5, 4, 3, 2, 1])
s.sort(descending=True, in_place=True) # set descing sorted flag
assert s.flags == {"SORTED_ASC": False, "SORTED_DESC": True}
assert s.arg_min() == 4
assert s.arg_max() == 0
assert s.arg_min() == 5
assert s.arg_max() == 1

# test empty series
s = pl.Series("a", [])
# test ascending and descending numerical series
s = pl.Series([None, 1, 2, 3, 4, 5])
s.sort(in_place=True) # set ascending sorted flag
assert s.flags == {"SORTED_ASC": True, "SORTED_DESC": False}
assert s.arg_min() == 1
assert s.arg_max() == 5
s = pl.Series([None, 5, 4, 3, 2, 1])
s.sort(descending=True, in_place=True) # set descing sorted flag
assert s.flags == {"SORTED_ASC": False, "SORTED_DESC": True}
assert s.arg_min() == 5
assert s.arg_max() == 1

# test ascending and descending utf8 series
s = pl.Series([None, "a", "b", "c", "d", "e"])
s.sort(in_place=True) # set ascending sorted flag
assert s.flags == {"SORTED_ASC": True, "SORTED_DESC": False}
assert s.arg_min() == 1
assert s.arg_max() == 5
s = pl.Series([None, "e", "d", "c", "b", "a"])
s.sort(descending=True, in_place=True) # set descing sorted flag
assert s.flags == {"SORTED_ASC": False, "SORTED_DESC": True}
assert s.arg_min() == 5
assert s.arg_max() == 1

# test numerical empty series
s = pl.Series([], dtype=pl.Int32)
assert s.arg_min() is None
assert s.arg_max() is None

# test boolean empty series
s = pl.Series([], dtype=pl.Boolean)
assert s.arg_min() is None
assert s.arg_max() is None

# test utf8 empty series
s = pl.Series([], dtype=pl.Utf8)
assert s.arg_min() is None
assert s.arg_max() is None

Expand Down