From 49167f272c393d9497e585f08ae0f52656d9be8d Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Wed, 23 Aug 2023 00:35:46 +0800 Subject: [PATCH] feat(rust, python): Support min and max strategy for string columns fill null --- .../src/chunked_array/ops/aggregate/mod.rs | 74 +++++++++++++++---- .../src/chunked_array/ops/fill_null.rs | 6 ++ py-polars/tests/unit/series/test_series.py | 8 ++ 3 files changed, 72 insertions(+), 16 deletions(-) diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs index 5ed07614951a..6738991d2271 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs @@ -138,14 +138,14 @@ where IsSorted::Ascending => { self.last_non_null().and_then(|idx| { // Safety: - // first_non_null returns in bound index + // last_non_null returns in bound index unsafe { self.get_unchecked(idx) } }) }, IsSorted::Descending => { self.first_non_null().and_then(|idx| { // Safety: - // last returns in bound index + // first_non_null returns in bound index unsafe { self.get_unchecked(idx) } }) }, @@ -485,27 +485,69 @@ impl ChunkAggSeries for Utf8Chunked { } } +impl BinaryChunked { + pub(crate) fn max_binary(&self) -> Option<&[u8]> { + if self.is_empty() { + return None; + } + match self.is_sorted_flag() { + IsSorted::Ascending => { + self.last_non_null().and_then(|idx| { + // Safety: + // last_non_null returns in bound index + unsafe { self.get_unchecked(idx) } + }) + }, + IsSorted::Descending => { + self.first_non_null().and_then(|idx| { + // Safety: + // first_non_null returns in bound index + unsafe { self.get_unchecked(idx) } + }) + }, + IsSorted::Not => self + .downcast_iter() + .filter_map(compute::aggregate::max_binary) + .fold_first_(|acc, v| if acc > v { acc } else { v }), + } + } + + pub(crate) fn min_binary(&self) -> Option<&[u8]> { + if self.is_empty() { + return None; + } + match self.is_sorted_flag() { + IsSorted::Ascending => { + self.first_non_null().and_then(|idx| { + // Safety: + // first_non_null returns in bound index + unsafe { self.get_unchecked(idx) } + }) + }, + IsSorted::Descending => { + self.last_non_null().and_then(|idx| { + // Safety: + // last_non_null returns in bound index + unsafe { self.get_unchecked(idx) } + }) + }, + IsSorted::Not => self + .downcast_iter() + .filter_map(compute::aggregate::min_binary) + .fold_first_(|acc, v| if acc < v { acc } else { v }), + } + } +} + impl ChunkAggSeries for BinaryChunked { fn sum_as_series(&self) -> Series { BinaryChunked::full_null(self.name(), 1).into_series() } fn max_as_series(&self) -> Series { - Series::new( - self.name(), - &[self - .downcast_iter() - .filter_map(compute::aggregate::max_binary) - .fold_first_(|acc, v| if acc > v { acc } else { v })], - ) + Series::new(self.name(), [self.max_binary()]) } fn min_as_series(&self) -> Series { - Series::new( - self.name(), - &[self - .downcast_iter() - .filter_map(compute::aggregate::min_binary) - .fold_first_(|acc, v| if acc < v { acc } else { v })], - ) + Series::new(self.name(), [self.min_binary()]) } } diff --git a/crates/polars-core/src/chunked_array/ops/fill_null.rs b/crates/polars-core/src/chunked_array/ops/fill_null.rs index 0e6539f7e5c2..440fb6591ea8 100644 --- a/crates/polars-core/src/chunked_array/ops/fill_null.rs +++ b/crates/polars-core/src/chunked_array/ops/fill_null.rs @@ -363,6 +363,12 @@ fn fill_null_binary(ca: &BinaryChunked, strategy: FillNullStrategy) -> PolarsRes out.rename(ca.name()); Ok(out) }, + FillNullStrategy::Min => { + ca.fill_null_with_values(ca.min_binary().ok_or_else(err_fill_null)?) + }, + FillNullStrategy::Max => { + ca.fill_null_with_values(ca.max_binary().ok_or_else(err_fill_null)?) + }, strat => polars_bail!(InvalidOperation: "fill-null strategy {:?} is not supported", strat), } } diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index d8b6336ac9ed..dfc7827ed967 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -927,6 +927,14 @@ def test_fill_null() -> None: assert a.fill_null(strategy="backward").to_list() == [0.0, 1.0, 2.0, 2.0, 3.0, 3.0] assert a.fill_null(strategy="mean").to_list() == [0.0, 1.0, 1.5, 2.0, 1.5, 3.0] + b = pl.Series("b", ["a", None, "c", None, "e"]) + assert b.fill_null(strategy="min").to_list() == ["a", "a", "c", "a", "e"] + assert b.fill_null(strategy="max").to_list() == ["a", "e", "c", "e", "e"] + + c = pl.Series("c", [b"a", None, b"c", None, b"e"]) + assert c.fill_null(strategy="min").to_list() == [b"a", b"a", b"c", b"a", b"e"] + assert c.fill_null(strategy="max").to_list() == [b"a", b"e", b"c", b"e", b"e"] + df = pl.DataFrame( [ pl.Series("i32", [1, 2, None], dtype=pl.Int32),