From 0d6a4bc4fe563f571828a5af550caf23a15bdcf9 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Thu, 24 Aug 2023 13:38:19 +0800 Subject: [PATCH 1/2] fix(rust, python): arg_min & arg_max return null for empty series --- crates/polars-ops/src/series/ops/arg_min_max.rs | 10 +++++++++- py-polars/tests/unit/namespaces/test_list.py | 13 +++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/crates/polars-ops/src/series/ops/arg_min_max.rs b/crates/polars-ops/src/series/ops/arg_min_max.rs index 734a6c5165e8..4e2c1d79aecd 100644 --- a/crates/polars-ops/src/series/ops/arg_min_max.rs +++ b/crates/polars-ops/src/series/ops/arg_min_max.rs @@ -91,7 +91,9 @@ pub(crate) fn arg_max_bool(ca: &BooleanChunked) -> Option { } fn arg_min_bool(ca: &BooleanChunked) -> Option { - if ca.is_empty() || ca.null_count() == ca.len() || ca.all() { + if ca.is_empty() { + None + } else if ca.null_count() == ca.len() || ca.all() { Some(0) } else if ca.null_count() == 0 && ca.chunks().len() == 1 { let arr = ca.downcast_iter().next().unwrap(); @@ -195,6 +197,9 @@ fn first_unset_bit(mask: &Bitmap) -> usize { } fn arg_min_str(ca: &Utf8Chunked) -> Option { + if ca.is_empty() { + return None; + } match ca.is_sorted_flag() { IsSorted::Ascending => Some(0), IsSorted::Descending => Some(ca.len() - 1), @@ -207,6 +212,9 @@ fn arg_min_str(ca: &Utf8Chunked) -> Option { } fn arg_max_str(ca: &Utf8Chunked) -> Option { + if ca.is_empty() { + return None; + } match ca.is_sorted_flag() { IsSorted::Ascending => Some(ca.len() - 1), IsSorted::Descending => Some(0), diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py index 786e2f6290de..6e2c8518fa59 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/test_list.py @@ -544,3 +544,16 @@ def test_list_take_oob_10079() -> None: ) with pytest.raises(pl.ComputeError, match="take indices are out of bounds"): df.select(pl.col("a").take(999)) + + +def 
test_utf8_empty_series_arg_min_max_10703() -> None: + res = pl.select(pl.lit(pl.Series("list", [["a"], []]))).with_columns( + pl.all(), + pl.all().list.arg_min().alias("arg_min"), + pl.all().list.arg_max().alias("arg_max"), + ) + assert res.to_dict(False) == { + "list": [["a"], []], + "arg_min": [0, None], + "arg_max": [0, None], + } From aedaed3597b74c49abdb3cf97661ce62b5ef9b7f Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Fri, 25 Aug 2023 01:23:18 +0800 Subject: [PATCH 2/2] fix(rust, python): Keep min/max and arg_min/arg_max consistent. --- .../polars-ops/src/series/ops/arg_min_max.rs | 129 +++++++++--------- py-polars/tests/unit/series/test_series.py | 97 +++++++++++-- 2 files changed, 149 insertions(+), 77 deletions(-) diff --git a/crates/polars-ops/src/series/ops/arg_min_max.rs b/crates/polars-ops/src/series/ops/arg_min_max.rs index 4e2c1d79aecd..613f833ab497 100644 --- a/crates/polars-ops/src/series/ops/arg_min_max.rs +++ b/crates/polars-ops/src/series/ops/arg_min_max.rs @@ -31,7 +31,7 @@ impl ArgAgg for Series { dt if dt.is_numeric() => { with_match_physical_numeric_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); - if ca.is_empty() { // because argminmax assumes not empty + if ca.is_empty() || ca.null_count() == ca.len() { // because argminmax assumes not empty None } else if let Ok(vals) = ca.cont_slice() { arg_min_numeric_slice(vals, ca.is_sorted_flag()) @@ -59,7 +59,7 @@ impl ArgAgg for Series { dt if dt.is_numeric() => { with_match_physical_numeric_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); - if ca.is_empty() { // because argminmax assumes not empty + if ca.is_empty() || ca.null_count() == ca.len(){ // because argminmax assumes not empty None } else if let Ok(vals) = ca.cont_slice() { arg_max_numeric_slice(vals, ca.is_sorted_flag()) @@ -74,10 +74,8 @@ impl ArgAgg for Series { } pub(crate) fn arg_max_bool(ca: &BooleanChunked) -> Option { - if ca.is_empty() { + if 
ca.is_empty() || ca.null_count() == ca.len() { None - } else if ca.null_count() == ca.len() { - Some(0) } // don't check for any, that on itself is already an argmax search else if ca.null_count() == 0 && ca.chunks().len() == 1 { @@ -85,24 +83,41 @@ pub(crate) fn arg_max_bool(ca: &BooleanChunked) -> Option { let mask = arr.values(); Some(first_set_bit(mask)) } else { + let mut first_false_idx: Option = None; ca.into_iter() - .position(|opt_val| matches!(opt_val, Some(true))) + .enumerate() + .find_map(|(idx, val)| match val { + Some(true) => Some(idx), + Some(false) if first_false_idx.is_none() => { + first_false_idx = Some(idx); + None + }, + _ => None, + }) + .or(first_false_idx) } } fn arg_min_bool(ca: &BooleanChunked) -> Option { - if ca.is_empty() { + if ca.is_empty() || ca.null_count() == ca.len() { None - } else if ca.null_count() == ca.len() || ca.all() { - Some(0) } else if ca.null_count() == 0 && ca.chunks().len() == 1 { let arr = ca.downcast_iter().next().unwrap(); let mask = arr.values(); Some(first_unset_bit(mask)) } else { - // also null as we see that as lower in ordering than a set value + let mut first_true_idx: Option = None; ca.into_iter() - .position(|opt_val| matches!(opt_val, Some(false) | None)) + .enumerate() + .find_map(|(idx, val)| match val { + Some(false) => Some(idx), + Some(true) if first_true_idx.is_none() => { + first_true_idx = Some(idx); + None + }, + _ => None, + }) + .or(first_true_idx) } } @@ -197,27 +212,28 @@ fn first_unset_bit(mask: &Bitmap) -> usize { } fn arg_min_str(ca: &Utf8Chunked) -> Option { - if ca.is_empty() { + if ca.is_empty() || ca.null_count() == ca.len() { return None; } match ca.is_sorted_flag() { - IsSorted::Ascending => Some(0), - IsSorted::Descending => Some(ca.len() - 1), + IsSorted::Ascending => ca.first_non_null(), + IsSorted::Descending => ca.last_non_null(), IsSorted::Not => ca .into_iter() .enumerate() + .flat_map(|(idx, val)| val.map(|val| (idx, val))) .reduce(|acc, (idx, val)| if acc.1 > val { (idx, 
val) } else { acc }) .map(|tpl| tpl.0), } } fn arg_max_str(ca: &Utf8Chunked) -> Option { - if ca.is_empty() { + if ca.is_empty() || ca.null_count() == ca.len() { return None; } match ca.is_sorted_flag() { - IsSorted::Ascending => Some(ca.len() - 1), - IsSorted::Descending => Some(0), + IsSorted::Ascending => ca.last_non_null(), + IsSorted::Descending => ca.first_non_null(), IsSorted::Not => ca .into_iter() .enumerate() @@ -232,45 +248,35 @@ where for<'b> &'b [T::Native]: ArgMinMax, { match ca.is_sorted_flag() { - IsSorted::Ascending => Some(0), - IsSorted::Descending => Some(ca.len() - 1), + IsSorted::Ascending => ca.first_non_null(), + IsSorted::Descending => ca.last_non_null(), IsSorted::Not => { ca.downcast_iter() .fold((None, None, 0), |acc, arr| { if arr.len() == 0 { return acc; } - let chunk_min_idx: Option; - let chunk_min_val: Option; - if arr.null_count() > 0 { - // When there are nulls, we should compare Option - chunk_min_val = None; // because None < Some(_) - chunk_min_idx = arr - .into_iter() + let chunk_min: Option<(usize, T::Native)> = if arr.null_count() > 0 { + arr.into_iter() .enumerate() + .flat_map(|(idx, val)| val.map(|val| (idx, *val))) .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc }) - .map(|tpl| tpl.0); } else { // When no nulls & array not empty => we can use fast argminmax let min_idx: usize = arr.values().as_slice().argmin(); - chunk_min_idx = Some(min_idx); - chunk_min_val = Some(arr.value(min_idx)); - } + Some((min_idx, arr.value(min_idx))) + }; let new_offset: usize = acc.2 + arr.len(); match acc { - (Some(_), Some(_), offset) => { - if chunk_min_val < acc.1 { - match chunk_min_idx { - Some(idx) => (Some(idx + offset), chunk_min_val, new_offset), - None => (acc.0, acc.1, new_offset), - } - } else { - (acc.0, acc.1, new_offset) - } + (Some(_), Some(acc_v), offset) => match chunk_min { + Some((idx, val)) if val < acc_v => { + (Some(idx + offset), Some(val), new_offset) + }, + _ => (acc.0, acc.1, new_offset), }, - 
(None, None, offset) => match chunk_min_idx { - Some(idx) => (Some(idx + offset), chunk_min_val, new_offset), + (None, None, offset) => match chunk_min { + Some((idx, val)) => (Some(idx + offset), Some(val), new_offset), None => (None, None, new_offset), }, _ => unreachable!(), @@ -287,45 +293,36 @@ where for<'b> &'b [T::Native]: ArgMinMax, { match ca.is_sorted_flag() { - IsSorted::Ascending => Some(ca.len() - 1), - IsSorted::Descending => Some(0), + IsSorted::Ascending => ca.last_non_null(), + IsSorted::Descending => ca.first_non_null(), IsSorted::Not => { ca.downcast_iter() .fold((None, None, 0), |acc, arr| { if arr.len() == 0 { return acc; } - let chunk_max_idx: Option; - let chunk_max_val: Option; - if arr.null_count() > 0 { + let chunk_max: Option<(usize, T::Native)> = if arr.null_count() > 0 { // When there are nulls, we should compare Option - chunk_max_idx = arr - .into_iter() + arr.into_iter() .enumerate() + .flat_map(|(idx, val)| val.map(|val| (idx, *val))) .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc }) - .map(|tpl| tpl.0); - chunk_max_val = chunk_max_idx.map(|idx| arr.value(idx)); } else { // When no nulls & array not empty => we can use fast argminmax let max_idx: usize = arr.values().as_slice().argmax(); - chunk_max_idx = Some(max_idx); - chunk_max_val = Some(arr.value(max_idx)); - } + Some((max_idx, arr.value(max_idx))) + }; let new_offset: usize = acc.2 + arr.len(); match acc { - (Some(_), Some(_), offset) => { - if chunk_max_val > acc.1 { - match chunk_max_idx { - Some(idx) => (Some(idx + offset), chunk_max_val, new_offset), - _ => unreachable!(), // because None < Some(_) - } - } else { - (acc.0, acc.1, new_offset) - } + (Some(_), Some(acc_v), offset) => match chunk_max { + Some((idx, val)) if acc_v < val => { + (Some(idx + offset), Some(val), new_offset) + }, + _ => (acc.0, acc.1, new_offset), }, - (None, None, offset) => match chunk_max_idx { - Some(idx) => (Some(idx + offset), chunk_max_val, new_offset), + (None, None, 
offset) => match chunk_max { + Some((idx, val)) => (Some(idx + offset), Some(val), new_offset), None => (None, None, new_offset), }, _ => unreachable!(), @@ -341,7 +338,9 @@ where for<'a> &'a [T]: ArgMinMax, { match is_sorted { + // all vals are not null guarded by cont_slice IsSorted::Ascending => Some(0), + // all vals are not null guarded by cont_slice IsSorted::Descending => Some(vals.len() - 1), IsSorted::Not => Some(vals.argmin()), // assumes not empty } @@ -352,7 +351,9 @@ where for<'a> &'a [T]: ArgMinMax, { match is_sorted { + // all vals are not null guarded by cont_slice IsSorted::Ascending => Some(vals.len() - 1), + // all vals are not null guarded by cont_slice IsSorted::Descending => Some(0), IsSorted::Not => Some(vals.argmax()), // assumes not empty } diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index c72bd054f8ce..71d1f9e78db6 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -1738,15 +1738,24 @@ def test_arg_sort() -> None: def test_arg_min_and_arg_max() -> None: - s = pl.Series("a", [5, 3, 4, 1, 2]) + # numerical no null. + s = pl.Series([5, 3, 4, 1, 2]) assert s.arg_min() == 3 assert s.arg_max() == 0 - s = pl.Series([None, True, False, True]) - assert s.arg_min() == 0 + # numerical has null. + s = pl.Series([None, 5, 1]) + assert s.arg_min() == 2 assert s.arg_max() == 1 - s = pl.Series([None, None], dtype=pl.Boolean) - assert s.arg_min() == 0 + + # numerical all null. + s = pl.Series([None, None], dtype=Int32) + assert s.arg_min() is None + assert s.arg_max() is None + + # boolean no null. + s = pl.Series([True, False]) + assert s.arg_min() == 1 assert s.arg_max() == 0 s = pl.Series([True, True]) assert s.arg_min() == 0 @@ -1754,24 +1763,86 @@ def test_arg_min_and_arg_max() -> None: s = pl.Series([False, False]) assert s.arg_min() == 0 assert s.arg_max() == 0 + + # boolean has null. 
+ s = pl.Series([None, True, False, True]) + assert s.arg_min() == 2 + assert s.arg_max() == 1 + s = pl.Series([None, True, True]) + assert s.arg_min() == 1 + assert s.arg_max() == 1 + s = pl.Series([None, False, False]) + assert s.arg_min() == 1 + assert s.arg_max() == 1 + + # boolean all null. + s = pl.Series([None, None], dtype=pl.Boolean) + assert s.arg_min() is None + assert s.arg_max() is None + + # utf8 no null s = pl.Series(["a", "c", "b"]) assert s.arg_min() == 0 assert s.arg_max() == 1 + # utf8 has null + s = pl.Series([None, "a", None, "b"]) + assert s.arg_min() == 1 + assert s.arg_max() == 3 + + # utf8 all null + s = pl.Series([None, None], dtype=pl.Utf8) + assert s.arg_min() is None + assert s.arg_max() is None + # test ascending and descending series - s = pl.Series("a", [1, 2, 3, 4, 5]) + s = pl.Series([None, 1, 2, 3, 4, 5]) s.sort(in_place=True) # set ascending sorted flag assert s.flags == {"SORTED_ASC": True, "SORTED_DESC": False} - assert s.arg_min() == 0 - assert s.arg_max() == 4 - s = pl.Series("a", [5, 4, 3, 2, 1]) + assert s.arg_min() == 1 + assert s.arg_max() == 5 + s = pl.Series([None, 5, 4, 3, 2, 1]) s.sort(descending=True, in_place=True) # set descing sorted flag assert s.flags == {"SORTED_ASC": False, "SORTED_DESC": True} - assert s.arg_min() == 4 - assert s.arg_max() == 0 + assert s.arg_min() == 5 + assert s.arg_max() == 1 - # test empty series - s = pl.Series("a", []) + # test ascending and descending numerical series + s = pl.Series([None, 1, 2, 3, 4, 5]) + s.sort(in_place=True) # set ascending sorted flag + assert s.flags == {"SORTED_ASC": True, "SORTED_DESC": False} + assert s.arg_min() == 1 + assert s.arg_max() == 5 + s = pl.Series([None, 5, 4, 3, 2, 1]) + s.sort(descending=True, in_place=True) # set descing sorted flag + assert s.flags == {"SORTED_ASC": False, "SORTED_DESC": True} + assert s.arg_min() == 5 + assert s.arg_max() == 1 + + # test ascending and descending utf8 series + s = pl.Series([None, "a", "b", "c", "d", "e"]) + 
s.sort(in_place=True) # set ascending sorted flag + assert s.flags == {"SORTED_ASC": True, "SORTED_DESC": False} + assert s.arg_min() == 1 + assert s.arg_max() == 5 + s = pl.Series([None, "e", "d", "c", "b", "a"]) + s.sort(descending=True, in_place=True) # set descending sorted flag + assert s.flags == {"SORTED_ASC": False, "SORTED_DESC": True} + assert s.arg_min() == 5 + assert s.arg_max() == 1 + + # test numerical empty series + s = pl.Series([], dtype=pl.Int32) + assert s.arg_min() is None + assert s.arg_max() is None + + # test boolean empty series + s = pl.Series([], dtype=pl.Boolean) + assert s.arg_min() is None + assert s.arg_max() is None + + # test utf8 empty series + s = pl.Series([], dtype=pl.Utf8) assert s.arg_min() is None assert s.arg_max() is None