From 55f88703d20b68a43269993c0729c609ff4f1422 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Mon, 18 Sep 2023 00:29:17 +0800 Subject: [PATCH] feat: list.join's separator can be expression --- .../src/chunked_array/list/iterator.rs | 8 ++ .../src/chunked_array/list/namespace.rs | 73 ++++++++++++++----- .../polars-plan/src/dsl/function_expr/list.rs | 8 ++ .../polars-plan/src/dsl/function_expr/mod.rs | 1 + .../src/dsl/function_expr/schema.rs | 1 + crates/polars-plan/src/dsl/list.rs | 18 ++--- crates/polars-sql/src/functions.rs | 9 +-- py-polars/polars/expr/list.py | 16 +++- py-polars/src/expr/list.rs | 4 +- py-polars/tests/unit/namespaces/test_list.py | 13 ++++ py-polars/tests/unit/test_exprs.py | 6 -- 11 files changed, 111 insertions(+), 46 deletions(-) diff --git a/crates/polars-core/src/chunked_array/list/iterator.rs b/crates/polars-core/src/chunked_array/list/iterator.rs index b30a339ba812..6bf9ed2f60d6 100644 --- a/crates/polars-core/src/chunked_array/list/iterator.rs +++ b/crates/polars-core/src/chunked_array/list/iterator.rs @@ -180,6 +180,14 @@ impl ListChunked { unsafe { self.amortized_iter().map(f).collect_ca(self.name()) } } + pub fn for_each_amortized<'a, F>(&'a self, f: F) + where + F: FnMut(Option>), + { + // SAFETY: unstable series never lives longer than the iterator. + unsafe { self.amortized_iter().for_each(f) } + } + /// Apply a closure `F` elementwise. #[must_use] pub fn apply_amortized<'a, F>(&'a self, mut f: F) -> Self diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index bcd545a1ba06..7090613b117b 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -76,22 +76,61 @@ fn cast_rhs( pub trait ListNameSpaceImpl: AsList { /// In case the inner dtype [`DataType::Utf8`], the individual items will be joined into a /// single string separated by `separator`. - fn lst_join(&self, separator: &str) -> PolarsResult { + fn lst_join(&self, separator: &Utf8Chunked) -> PolarsResult { let ca = self.as_list(); match ca.inner_dtype() { - DataType::Utf8 => { - // used to amortize heap allocs - let mut buf = String::with_capacity(128); + DataType::Utf8 => match separator.len() { + 1 => match separator.get(0) { + Some(separator) => self.join_literal(separator), + _ => Ok(Utf8Chunked::full_null(ca.name(), ca.len())), + }, + _ => self.join_many(separator), + }, + dt => polars_bail!(op = "`lst.join`", got = dt, expected = "Utf8"), + } + } - let mut builder = Utf8ChunkedBuilder::new( - ca.name(), - ca.len(), - ca.get_values_size() + separator.len() * ca.len(), - ); + fn join_literal(&self, separator: &str) -> PolarsResult { + let ca = self.as_list(); + // used to amortize heap allocs + let mut buf = String::with_capacity(128); + let mut builder = Utf8ChunkedBuilder::new( + ca.name(), + ca.len(), + ca.get_values_size() + separator.len() * ca.len(), + ); + + ca.for_each_amortized(|opt_s| { + let opt_val = opt_s.map(|s| { + // make sure that we don't write values of previous iteration + buf.clear(); + let ca = s.as_ref().utf8().unwrap(); + let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null")); + + for val in iter { + buf.write_str(val).unwrap(); + buf.write_str(separator).unwrap(); + } + // last value should not have a separator, so slice that off + // saturating sub because there might have been nothing written. + &buf[..buf.len().saturating_sub(separator.len())] + }); + builder.append_option(opt_val) + }); + Ok(builder.finish()) + } - // SAFETY: unstable series never lives longer than the iterator. - unsafe { - ca.amortized_iter().for_each(|opt_s| { + fn join_many(&self, separator: &Utf8Chunked) -> PolarsResult { + let ca = self.as_list(); + // used to amortize heap allocs + let mut buf = String::with_capacity(128); + let mut builder = + Utf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size() + ca.len()); + // SAFETY: unstable series never lives longer than the iterator. + unsafe { + ca.amortized_iter().zip(separator.into_iter()).for_each( + |(opt_s, opt_sep)| match opt_sep { + Some(separator) => { let opt_val = opt_s.map(|s| { // make sure that we don't write values of previous iteration buf.clear(); @@ -107,12 +146,12 @@ pub trait ListNameSpaceImpl: AsList { &buf[..buf.len().saturating_sub(separator.len())] }); builder.append_option(opt_val) - }) - }; - Ok(builder.finish()) - }, - dt => polars_bail!(op = "`lst.join`", got = dt, expected = "Utf8"), + }, + _ => builder.append_null(), + }, + ) } + Ok(builder.finish()) } fn lst_max(&self) -> Series { diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 856273fadc76..2fce49fd8585 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -22,6 +22,7 @@ pub enum ListFunction { Any, #[cfg(feature = "list_any_all")] All, + Join, } impl Display for ListFunction { @@ -45,6 +46,7 @@ impl Display for ListFunction { Any => "any", #[cfg(feature = "list_any_all")] All => "all", + Join => "join", }; write!(f, "{name}") } @@ -279,3 +281,9 @@ pub(super) fn lst_any(s: &Series) -> PolarsResult { pub(super) fn lst_all(s: &Series) -> PolarsResult { s.list()?.lst_all() } + +pub(super) fn join(s: &[Series]) -> PolarsResult { + let ca = s[0].list()?; + let separator = s[1].utf8()?; + Ok(ca.lst_join(&separator)?.into_series()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 9ffb7efd2a45..94574a0b7a67 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -551,6 +551,7 @@ impl From for SpecialEq> { Any => map!(list::lst_any), #[cfg(feature = "list_any_all")] All => map!(list::lst_all), + Join => map_as_slice!(list::join), } }, #[cfg(feature = "dtype-array")] diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 6080ad3a406a..843f91bc5e5e 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -113,6 +113,7 @@ impl FunctionExpr { Any => mapper.with_dtype(DataType::Boolean), #[cfg(feature = "list_any_all")] All => mapper.with_dtype(DataType::Boolean), + Join => mapper.with_dtype(DataType::Utf8), } }, #[cfg(feature = "dtype-array")] diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 427d99d5ca6c..4670f90e473a 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -162,18 +162,12 @@ impl ListNameSpace { /// Join all string items in a sublist and place a separator between them. /// # Error /// This errors if inner type of list `!= DataType::Utf8`. - pub fn join(self, separator: &str) -> Expr { - let separator = separator.to_string(); - self.0 - .map( - move |s| { - s.list()? - .lst_join(&separator) - .map(|ca| Some(ca.into_series())) - }, - GetOutput::from_type(DataType::Utf8), - ) - .with_fmt("list.join") + pub fn join(self, separator: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::ListExpr(ListFunction::Join), + &[separator], + false, + ) } /// Return the index of the minimal value of every sublist diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index aeca8376aff2..d4b73a371b97 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -682,14 +682,7 @@ impl SqlFunctionVisitor<'_> { ArrayReverse => self.visit_unary(|e| e.list().reverse()), ArraySum => self.visit_unary(|e| e.list().sum()), ArrayToString => self.try_visit_binary(|e, s| { - let sep = match s { - Expr::Literal(LiteralValue::Utf8(ref sep)) => sep, - _ => { - polars_bail!(InvalidOperation: "Invalid 'separator' for ArrayToString: {}", function.args[1]); - } - }; - - Ok(e.list().join(sep)) + Ok(e.list().join(s)) }), ArrayUnique => self.visit_unary(|e| e.list().unique()), Explode => self.visit_unary(|e| e.explode()), diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 642ca1afcabc..a6a13c87a2ae 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -459,7 +459,7 @@ def contains( item = parse_as_expression(item, str_as_lit=True) return wrap_expr(self._pyexpr.list_contains(item)) - def join(self, separator: str) -> Expr: + def join(self, separator: str | Expr) -> Expr: """ Join all string items in a sublist and place a separator between them. @@ -489,7 +489,21 @@ def join(self, separator: str) -> Expr: │ x y │ └───────┘ + >>> df = pl.DataFrame( + ... {"s": [["a", "b", "c"], ["x", "y"]], "separator": ["*", "_"]} + ... ) + >>> df.select(pl.col("s").list.join(pl.col("separator"))) + shape: (2, 1) + ┌───────┐ + │ s │ + │ --- │ + │ str │ + ╞═══════╡ + │ a*b*c │ + │ x_y │ + └───────┘ """ + separator = parse_as_expression(separator, str_as_lit=True) return wrap_expr(self._pyexpr.list_join(separator)) def arg_min(self) -> Expr: diff --git a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs index 3495718d817d..261a96fa6259 100644 --- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -49,8 +49,8 @@ impl PyExpr { self.inner.clone().list().get(index.inner).into() } - fn list_join(&self, separator: &str) -> Self { - self.inner.clone().list().join(separator).into() + fn list_join(&self, separator: PyExpr) -> Self { + self.inner.clone().list().join(separator.inner).into() } fn list_lengths(&self) -> Self { diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py index d84184fdf6d1..100581c80b49 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/test_list.py @@ -94,6 +94,19 @@ def test_list_concat() -> None: assert out_s[0].to_list() == [1, 2, 4, 1] +def test_list_join() -> None: + df = pl.DataFrame( + { + "a": [["ab", "c", "d"], ["e", "f"], ["g"], [], None], + "separator": ["&", None, "*", "_", "*"], + } + ) + out = df.select(pl.col("a").list.join("-")) + assert out.to_dict(False) == {"a": ["ab-c-d", "e-f", "g", "", None]} + out = df.select(pl.col("a").list.join(pl.col("separator"))) + assert out.to_dict(False) == {"a": ["ab&c&d", None, "g", "", None]} + + def test_list_arr_empty() -> None: df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []]}) diff --git a/py-polars/tests/unit/test_exprs.py b/py-polars/tests/unit/test_exprs.py index 8f10655a20a1..5201b91ae4ba 100644 --- a/py-polars/tests/unit/test_exprs.py +++ b/py-polars/tests/unit/test_exprs.py @@ -85,12 +85,6 @@ def test_filter_where() -> None: assert_frame_equal(result_filter, expected) -def test_list_join_strings() -> None: - s = pl.Series("a", [["ab", "c", "d"], ["e", "f"], ["g"], []]) - expected = pl.Series("a", ["ab-c-d", "e-f", "g", ""]) - assert_series_equal(s.list.join("-"), expected) - - def test_count_expr() -> None: df = pl.DataFrame({"a": [1, 2, 3, 3, 3], "b": ["a", "a", "b", "a", "a"]})