From e12fb6f83afbe2d7f8d89b67759b9a2cc83687eb Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Fri, 15 Sep 2023 23:30:23 +0800 Subject: [PATCH] feat: Expressify str.split argument. (#11117) --- .../src/chunked_array/ops/arity.rs | 18 ++++++ .../src/chunked_array/ops/for_each.rs | 15 +++++ .../polars-core/src/chunked_array/ops/mod.rs | 1 + .../src/chunked_array/strings/namespace.rs | 62 ++++++++++++++++++- .../polars-plan/src/dsl/function_expr/mod.rs | 6 ++ .../src/dsl/function_expr/strings.rs | 43 +++++++++++++ crates/polars-plan/src/dsl/string.rs | 46 ++------------ crates/polars/tests/it/lazy/explodes.rs | 2 +- py-polars/polars/expr/string.py | 50 +++++++++++---- py-polars/src/expr/string.rs | 8 +-- .../tests/unit/namespaces/test_string.py | 25 ++++++++ 11 files changed, 215 insertions(+), 61 deletions(-) create mode 100644 crates/polars-core/src/chunked_array/ops/for_each.rs diff --git a/crates/polars-core/src/chunked_array/ops/arity.rs b/crates/polars-core/src/chunked_array/ops/arity.rs index 3df7c726df7a..788ec5db840f 100644 --- a/crates/polars-core/src/chunked_array/ops/arity.rs +++ b/crates/polars-core/src/chunked_array/ops/arity.rs @@ -34,6 +34,24 @@ where ChunkedArray::from_chunk_iter(lhs.name(), iter) } +#[inline] +pub fn binary_elementwise_for_each(lhs: &ChunkedArray, rhs: &ChunkedArray, mut op: F) +where + T: PolarsDataType, + U: PolarsDataType, + F: for<'a> FnMut(Option>, Option>), +{ + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + lhs.downcast_iter() + .zip(rhs.downcast_iter()) + .for_each(|(lhs_arr, rhs_arr)| { + lhs_arr + .iter() + .zip(rhs_arr.iter()) + .for_each(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val)); + }) +} + #[inline] pub fn try_binary_elementwise( lhs: &ChunkedArray, diff --git a/crates/polars-core/src/chunked_array/ops/for_each.rs b/crates/polars-core/src/chunked_array/ops/for_each.rs new file mode 100644 index 000000000000..42713e0cdff2 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/for_each.rs @@ -0,0 +1,15 @@ +use crate::prelude::*; + +impl ChunkedArray +where + T: PolarsDataType, +{ + pub fn for_each<'a, F>(&'a self, mut op: F) + where + F: FnMut(Option>), + { + self.downcast_iter().for_each(|arr| { + arr.iter().for_each(&mut op); + }) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index 4b4e4af1ec5e..5c2a079434b6 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -29,6 +29,7 @@ mod explode_and_offsets; mod extend; mod fill_null; mod filter; +mod for_each; pub mod full; #[cfg(feature = "interpolate")] mod interpolate; diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs index 0cc3b8721c71..51b443eab329 100644 --- a/crates/polars-ops/src/chunked_array/strings/namespace.rs +++ b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -8,7 +8,7 @@ use polars_arrow::kernels::string::*; #[cfg(feature = "string_from_radix")] use polars_core::export::num::Num; use polars_core::export::regex::Regex; -use polars_core::prelude::arity::try_binary_elementwise; +use polars_core::prelude::arity::{binary_elementwise_for_each, try_binary_elementwise}; use polars_utils::cache::FastFixedCache; use regex::escape; @@ -311,6 +311,66 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { Ok(builder.finish()) } + fn split(&self, by: &str) -> ListChunked { + let ca = self.as_utf8(); + let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); + + ca.for_each(|opt_v| match opt_v { + Some(val) => { + let iter = val.split(by); + builder.append_values_iter(iter) + }, + _ => builder.append_null(), + }); + builder.finish() + } + + fn split_many(&self, by: &Utf8Chunked) -> ListChunked { + let ca = self.as_utf8(); + + let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); + + binary_elementwise_for_each(ca, by, |opt_s, opt_by| match (opt_s, opt_by) { + (Some(s), Some(by)) => { + let iter = s.split(by); + builder.append_values_iter(iter); + }, + _ => builder.append_null(), + }); + + builder.finish() + } + + fn split_inclusive(&self, by: &str) -> ListChunked { + let ca = self.as_utf8(); + let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); + + ca.for_each(|opt_v| match opt_v { + Some(val) => { + let iter = val.split_inclusive(by); + builder.append_values_iter(iter) + }, + _ => builder.append_null(), + }); + builder.finish() + } + + fn split_inclusive_many(&self, by: &Utf8Chunked) -> ListChunked { + let ca = self.as_utf8(); + + let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); + + binary_elementwise_for_each(ca, by, |opt_s, opt_by| match (opt_s, opt_by) { + (Some(s), Some(by)) => { + let iter = s.split_inclusive(by); + builder.append_values_iter(iter); + }, + _ => builder.append_null(), + }); + + builder.finish() + } + /// Extract each successive non-overlapping regex match in an individual string as an array. fn extract_all_many(&self, pat: &Utf8Chunked) -> PolarsResult { let ca = self.as_utf8(); diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 8caabb2f9700..9ffb7efd2a45 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -700,6 +700,12 @@ impl From for SpecialEq> { Strptime(dtype, options) => { map_as_slice!(strings::strptime, dtype.clone(), &options) }, + Split => { + map_as_slice!(strings::split) + }, + SplitInclusive => { + map_as_slice!(strings::split_inclusive) + }, #[cfg(feature = "concat_str")] ConcatVertical(delimiter) => map!(strings::concat, &delimiter), #[cfg(feature = "concat_str")] diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs index ec86edfd41dd..f58e93042438 100644 --- a/crates/polars-plan/src/dsl/function_expr/strings.rs +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -75,6 +75,8 @@ pub enum StringFunction { StripSuffix(String), #[cfg(feature = "temporal")] Strptime(DataType, StrptimeOptions), + Split, + SplitInclusive, #[cfg(feature = "dtype-decimal")] ToDecimal(usize), #[cfg(feature = "nightly")] @@ -109,6 +111,7 @@ impl StringFunction { Replace { .. } => mapper.with_same_dtype(), #[cfg(feature = "temporal")] Strptime(dtype, _) => mapper.with_dtype(dtype.clone()), + Split | SplitInclusive => mapper.with_dtype(DataType::List(Box::new(DataType::Utf8))), #[cfg(feature = "nightly")] Titlecase => mapper.with_same_dtype(), #[cfg(feature = "dtype-decimal")] @@ -165,6 +168,8 @@ impl Display for StringFunction { StringFunction::StripSuffix(_) => "strip_suffix", #[cfg(feature = "temporal")] StringFunction::Strptime(_, _) => "strptime", + StringFunction::Split => "split", + StringFunction::SplitInclusive => "split_inclusive", #[cfg(feature = "nightly")] StringFunction::Titlecase => "titlecase", #[cfg(feature = "dtype-decimal")] @@ -441,6 +446,44 @@ pub(super) fn strptime( } } +pub(super) fn split(s: &[Series]) -> PolarsResult { + let ca = s[0].utf8()?; + let by = s[1].utf8()?; + + if by.len() == 1 { + if let Some(by) = by.get(0) { + Ok(ca.split(by).into_series()) + } else { + Ok(Series::full_null( + ca.name(), + ca.len(), + &DataType::List(Box::new(DataType::Utf8)), + )) + } + } else { + Ok(ca.split_many(by).into_series()) + } +} + +pub(super) fn split_inclusive(s: &[Series]) -> PolarsResult { + let ca = s[0].utf8()?; + let by = s[1].utf8()?; + + if by.len() == 1 { + if let Some(by) = by.get(0) { + Ok(ca.split_inclusive(by).into_series()) + } else { + Ok(Series::full_null( + ca.name(), + ca.len(), + &DataType::List(Box::new(DataType::Utf8)), + )) + } + } else { + Ok(ca.split_inclusive_many(by).into_series()) + } +} + fn handle_temporal_parsing_error( ca: &Utf8Chunked, out: &Series, diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index d4bec70dcbf0..e7d0a7322061 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -214,53 +214,15 @@ impl StringNameSpace { } /// Split the string by a substring. The resulting dtype is `List`. - pub fn split(self, by: &str) -> Expr { - let by = by.to_string(); - - let function = move |s: Series| { - let ca = s.utf8()?; - - let mut builder = ListUtf8ChunkedBuilder::new(s.name(), s.len(), ca.get_values_size()); - ca.into_iter().for_each(|opt_s| match opt_s { - None => builder.append_null(), - Some(s) => { - let iter = s.split(&by); - builder.append_values_iter(iter); - }, - }); - Ok(Some(builder.finish().into_series())) - }; + pub fn split(self, by: Expr) -> Expr { self.0 - .map( - function, - GetOutput::from_type(DataType::List(Box::new(DataType::Utf8))), - ) - .with_fmt("str.split") + .map_many_private(StringFunction::Split.into(), &[by], false) } /// Split the string by a substring and keep the substring. The resulting dtype is `List`. - pub fn split_inclusive(self, by: &str) -> Expr { - let by = by.to_string(); - - let function = move |s: Series| { - let ca = s.utf8()?; - - let mut builder = ListUtf8ChunkedBuilder::new(s.name(), s.len(), ca.get_values_size()); - ca.into_iter().for_each(|opt_s| match opt_s { - None => builder.append_null(), - Some(s) => { - let iter = s.split_inclusive(&by); - builder.append_values_iter(iter); - }, - }); - Ok(Some(builder.finish().into_series())) - }; + pub fn split_inclusive(self, by: Expr) -> Expr { self.0 - .map( - function, - GetOutput::from_type(DataType::List(Box::new(DataType::Utf8))), - ) - .with_fmt("str.split_inclusive") + .map_many_private(StringFunction::SplitInclusive.into(), &[by], false) } #[cfg(feature = "dtype-struct")] diff --git a/crates/polars/tests/it/lazy/explodes.rs b/crates/polars/tests/it/lazy/explodes.rs index 540af19a1525..01cc6ff69db7 100644 --- a/crates/polars/tests/it/lazy/explodes.rs +++ b/crates/polars/tests/it/lazy/explodes.rs @@ -9,7 +9,7 @@ fn test_explode_row_numbers() -> PolarsResult<()> { "text" => ["one two three four", "uno dos tres cuatro"] ]? .lazy() - .select([col("text").str().split(" ").alias("tokens")]) + .select([col("text").str().split(lit(" ")).alias("tokens")]) .with_row_count("row_nr", None) .explode([col("tokens")]) .select([col("row_nr"), col("tokens")]) diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index 58c44d3b50bc..f85a3c935b2b 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -1511,7 +1511,7 @@ def count_matches(self, pattern: str | Expr, *, literal: bool = False) -> Expr: pattern = parse_as_expression(pattern, str_as_lit=True) return wrap_expr(self._pyexpr.str_count_matches(pattern, literal)) - def split(self, by: str, *, inclusive: bool = False) -> Expr: + def split(self, by: str | Expr, *, inclusive: bool = False) -> Expr: """ Split the string by a substring. @@ -1524,18 +1524,41 @@ def split(self, by: str, *, inclusive: bool = False) -> Expr: Examples -------- - >>> df = pl.DataFrame({"s": ["foo bar", "foo-bar", "foo bar baz"]}) - >>> df.select(pl.col("s").str.split(by=" ")) - shape: (3, 1) - ┌───────────────────────┐ - │ s │ - │ --- │ - │ list[str] │ - ╞═══════════════════════╡ - │ ["foo", "bar"] │ - │ ["foo-bar"] │ - │ ["foo", "bar", "baz"] │ - └───────────────────────┘ + >>> df = pl.DataFrame({"s": ["foo bar", "foo_bar", "foo_bar_baz"]}) + >>> df.with_columns( + ... pl.col("s").str.split(by="_").alias("split"), + ... pl.col("s").str.split(by="_", inclusive=True).alias("split_inclusive"), + ... ) + shape: (3, 3) + ┌─────────────┬───────────────────────┬─────────────────────────┐ + │ s ┆ split ┆ split_inclusive │ + │ --- ┆ --- ┆ --- │ + │ str ┆ list[str] ┆ list[str] │ + ╞═════════════╪═══════════════════════╪═════════════════════════╡ + │ foo bar ┆ ["foo bar"] ┆ ["foo bar"] │ + │ foo_bar ┆ ["foo", "bar"] ┆ ["foo_", "bar"] │ + │ foo_bar_baz ┆ ["foo", "bar", "baz"] ┆ ["foo_", "bar_", "baz"] │ + └─────────────┴───────────────────────┴─────────────────────────┘ + + >>> df = pl.DataFrame( + ... {"s": ["foo^bar", "foo_bar", "foo*bar*baz"], "by": ["_", "_", "*"]} + ... ) + >>> df.with_columns( + ... pl.col("s").str.split(by=pl.col("by")).alias("split"), + ... pl.col("s") + ... .str.split(by=pl.col("by"), inclusive=True) + ... .alias("split_inclusive"), + ... ) + shape: (3, 4) + ┌─────────────┬─────┬───────────────────────┬─────────────────────────┐ + │ s ┆ by ┆ split ┆ split_inclusive │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ str ┆ str ┆ list[str] ┆ list[str] │ + ╞═════════════╪═════╪═══════════════════════╪═════════════════════════╡ + │ foo^bar ┆ _ ┆ ["foo^bar"] ┆ ["foo^bar"] │ + │ foo_bar ┆ _ ┆ ["foo", "bar"] ┆ ["foo_", "bar"] │ + │ foo*bar*baz ┆ * ┆ ["foo", "bar", "baz"] ┆ ["foo*", "bar*", "baz"] │ + └─────────────┴─────┴───────────────────────┴─────────────────────────┘ Returns ------- @@ -1543,6 +1566,7 @@ def split(self, by: str, *, inclusive: bool = False) -> Expr: Expression of data type :class:`Utf8`. """ + by = parse_as_expression(by, str_as_lit=True) if inclusive: return wrap_expr(self._pyexpr.str_split_inclusive(by)) return wrap_expr(self._pyexpr.str_split(by)) diff --git a/py-polars/src/expr/string.rs b/py-polars/src/expr/string.rs index ed542b6c0fb0..e466e5256254 100644 --- a/py-polars/src/expr/string.rs +++ b/py-polars/src/expr/string.rs @@ -275,12 +275,12 @@ impl PyExpr { .into() } - fn str_split(&self, by: &str) -> Self { - self.inner.clone().str().split(by).into() + fn str_split(&self, by: Self) -> Self { + self.inner.clone().str().split(by.inner).into() } - fn str_split_inclusive(&self, by: &str) -> Self { - self.inner.clone().str().split_inclusive(by).into() + fn str_split_inclusive(&self, by: Self) -> Self { + self.inner.clone().str().split_inclusive(by.inner).into() } fn str_split_exact(&self, by: &str, n: usize) -> Self { diff --git a/py-polars/tests/unit/namespaces/test_string.py b/py-polars/tests/unit/namespaces/test_string.py index 68da200cc7d1..a491a06fb70c 100644 --- a/py-polars/tests/unit/namespaces/test_string.py +++ b/py-polars/tests/unit/namespaces/test_string.py @@ -843,6 +843,31 @@ def test_split() -> None: assert_frame_equal(df["x"].str.split("_", inclusive=True).to_frame(), expected) +def test_split_expr() -> None: + df = pl.DataFrame({"x": ["a_a", None, "b", "c*c*c"], "by": ["_", "#", "^", "*"]}) + out = df.select([pl.col("x").str.split(pl.col("by"))]) + expected = pl.DataFrame( + [ + {"x": ["a", "a"]}, + {"x": None}, + {"x": ["b"]}, + {"x": ["c", "c", "c"]}, + ] + ) + assert_frame_equal(out, expected) + + out = df.select([pl.col("x").str.split(pl.col("by"), inclusive=True)]) + expected = pl.DataFrame( + [ + {"x": ["a_", "a"]}, + {"x": None}, + {"x": ["b"]}, + {"x": ["c*", "c*", "c"]}, + ] + ) + assert_frame_equal(out, expected) + + def test_split_exact() -> None: df = pl.DataFrame({"x": ["a_a", None, "b", "c_c"]}) out = df.select([pl.col("x").str.split_exact("_", 2, inclusive=False)]).unnest("x")