Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Expressify str.split argument. #11117

Merged
merged 5 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions crates/polars-core/src/chunked_array/ops/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,24 @@ where
ChunkedArray::from_chunk_iter(lhs.name(), iter)
}

#[inline]
pub fn binary_elementwise_for_each<T, U, F>(lhs: &ChunkedArray<T>, rhs: &ChunkedArray<U>, mut op: F)
ritchie46 marked this conversation as resolved.
Show resolved Hide resolved
where
T: PolarsDataType,
U: PolarsDataType,
F: for<'a> FnMut(Option<T::Physical<'a>>, Option<U::Physical<'a>>),
{
let (lhs, rhs) = align_chunks_binary(lhs, rhs);
lhs.downcast_iter()
.zip(rhs.downcast_iter())
.for_each(|(lhs_arr, rhs_arr)| {
lhs_arr
.iter()
.zip(rhs_arr.iter())
.for_each(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val));
})
}

#[inline]
pub fn try_binary_elementwise<T, U, V, F, K, E>(
lhs: &ChunkedArray<T>,
Expand Down
15 changes: 15 additions & 0 deletions crates/polars-core/src/chunked_array/ops/for_each.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use crate::prelude::*;

impl<T> ChunkedArray<T>
where
    T: PolarsDataType,
{
    /// Calls `op` once for every value of this array, chunk by chunk, in order.
    ///
    /// Each value is handed to `op` as an `Option` of the physical type
    /// (presumably `None` for null entries — this follows from the per-chunk
    /// array iterator, which is not shown here). Nothing is returned; `op` is
    /// run for its side effects only.
    pub fn for_each<'a, F>(&'a self, mut op: F)
    where
        F: FnMut(Option<T::Physical<'a>>),
    {
        for chunk in self.downcast_iter() {
            for opt_val in chunk.iter() {
                op(opt_val);
            }
        }
    }
}
1 change: 1 addition & 0 deletions crates/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ mod explode_and_offsets;
mod extend;
mod fill_null;
mod filter;
mod for_each;
pub mod full;
#[cfg(feature = "interpolate")]
mod interpolate;
Expand Down
62 changes: 61 additions & 1 deletion crates/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use polars_arrow::kernels::string::*;
#[cfg(feature = "string_from_radix")]
use polars_core::export::num::Num;
use polars_core::export::regex::Regex;
use polars_core::prelude::arity::try_binary_elementwise;
use polars_core::prelude::arity::{binary_elementwise_for_each, try_binary_elementwise};
use polars_utils::cache::FastFixedCache;
use regex::escape;

Expand Down Expand Up @@ -311,6 +311,66 @@ pub trait Utf8NameSpaceImpl: AsUtf8 {
Ok(builder.finish())
}

/// Split every string by the fixed separator `by`, producing one list of
/// substrings per input value. A missing input value yields a null list entry.
fn split(&self, by: &str) -> ListChunked {
    let ca = self.as_utf8();
    let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size());

    ca.for_each(|opt_v| {
        if let Some(val) = opt_v {
            builder.append_values_iter(val.split(by))
        } else {
            builder.append_null()
        }
    });

    builder.finish()
}

/// Split every string by the separator found at the same position in `by`.
///
/// A list entry is produced only when both the string and its separator are
/// present; in every other case a null list entry is appended.
fn split_many(&self, by: &Utf8Chunked) -> ListChunked {
    let ca = self.as_utf8();
    let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size());

    binary_elementwise_for_each(ca, by, |opt_s, opt_by| {
        if let (Some(s), Some(sep)) = (opt_s, opt_by) {
            builder.append_values_iter(s.split(sep));
        } else {
            builder.append_null()
        }
    });

    builder.finish()
}

/// Split every string by the fixed separator `by` using `str::split_inclusive`,
/// i.e. each piece keeps the separator that terminated it. A missing input
/// value yields a null list entry.
fn split_inclusive(&self, by: &str) -> ListChunked {
    let ca = self.as_utf8();
    let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size());

    ca.for_each(|opt_v| match opt_v {
        None => builder.append_null(),
        Some(val) => builder.append_values_iter(val.split_inclusive(by)),
    });

    builder.finish()
}

/// Split every string by the separator found at the same position in `by`,
/// keeping the separator at the end of each piece (`str::split_inclusive`).
///
/// A list entry is produced only when both the string and its separator are
/// present; in every other case a null list entry is appended.
fn split_inclusive_many(&self, by: &Utf8Chunked) -> ListChunked {
    let ca = self.as_utf8();
    let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size());

    binary_elementwise_for_each(ca, by, |opt_s, opt_by| {
        match (opt_s, opt_by) {
            (Some(s), Some(sep)) => {
                builder.append_values_iter(s.split_inclusive(sep));
            },
            _ => builder.append_null(),
        }
    });

    builder.finish()
}

/// Extract each successive non-overlapping regex match in an individual string as an array.
fn extract_all_many(&self, pat: &Utf8Chunked) -> PolarsResult<ListChunked> {
let ca = self.as_utf8();
Expand Down
6 changes: 6 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,12 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
Strptime(dtype, options) => {
map_as_slice!(strings::strptime, dtype.clone(), &options)
},
Split => {
map_as_slice!(strings::split)
},
SplitInclusive => {
map_as_slice!(strings::split_inclusive)
},
#[cfg(feature = "concat_str")]
ConcatVertical(delimiter) => map!(strings::concat, &delimiter),
#[cfg(feature = "concat_str")]
Expand Down
43 changes: 43 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ pub enum StringFunction {
StripSuffix(String),
#[cfg(feature = "temporal")]
Strptime(DataType, StrptimeOptions),
Split,
SplitInclusive,
#[cfg(feature = "dtype-decimal")]
ToDecimal(usize),
#[cfg(feature = "nightly")]
Expand Down Expand Up @@ -109,6 +111,7 @@ impl StringFunction {
Replace { .. } => mapper.with_same_dtype(),
#[cfg(feature = "temporal")]
Strptime(dtype, _) => mapper.with_dtype(dtype.clone()),
Split | SplitInclusive => mapper.with_dtype(DataType::List(Box::new(DataType::Utf8))),
#[cfg(feature = "nightly")]
Titlecase => mapper.with_same_dtype(),
#[cfg(feature = "dtype-decimal")]
Expand Down Expand Up @@ -165,6 +168,8 @@ impl Display for StringFunction {
StringFunction::StripSuffix(_) => "strip_suffix",
#[cfg(feature = "temporal")]
StringFunction::Strptime(_, _) => "strptime",
StringFunction::Split => "split",
StringFunction::SplitInclusive => "split_inclusive",
#[cfg(feature = "nightly")]
StringFunction::Titlecase => "titlecase",
#[cfg(feature = "dtype-decimal")]
Expand Down Expand Up @@ -441,6 +446,44 @@ pub(super) fn strptime(
}
}

/// Expression-level `str.split`: `s[0]` is the string column, `s[1]` the separator.
///
/// * separator has more or fewer than one value: delegate to `split_many`;
/// * separator is a single non-null value: split every string by it;
/// * separator is a single null: return an all-null `List<Utf8>` series with
///   the name and length of the string column.
///
/// NOTE(review): no check is made here that `s[0]` and `s[1]` have the same
/// length when the separator is not a single value — confirm how
/// `split_many` behaves on mismatched lengths.
pub(super) fn split(s: &[Series]) -> PolarsResult<Series> {
    let ca = s[0].utf8()?;
    let by = s[1].utf8()?;

    if by.len() != 1 {
        return Ok(ca.split_many(by).into_series());
    }

    let out = match by.get(0) {
        Some(by) => ca.split(by).into_series(),
        None => Series::full_null(
            ca.name(),
            ca.len(),
            &DataType::List(Box::new(DataType::Utf8)),
        ),
    };
    Ok(out)
}

/// Expression-level `str.split_inclusive`: `s[0]` is the string column, `s[1]`
/// the separator.
///
/// * separator has more or fewer than one value: delegate to
///   `split_inclusive_many`;
/// * separator is a single non-null value: split every string by it, keeping
///   the separator;
/// * separator is a single null: return an all-null `List<Utf8>` series with
///   the name and length of the string column.
///
/// NOTE(review): no check is made here that `s[0]` and `s[1]` have the same
/// length when the separator is not a single value — confirm how
/// `split_inclusive_many` behaves on mismatched lengths.
pub(super) fn split_inclusive(s: &[Series]) -> PolarsResult<Series> {
    let ca = s[0].utf8()?;
    let by = s[1].utf8()?;

    if by.len() != 1 {
        return Ok(ca.split_inclusive_many(by).into_series());
    }

    let out = match by.get(0) {
        Some(by) => ca.split_inclusive(by).into_series(),
        None => Series::full_null(
            ca.name(),
            ca.len(),
            &DataType::List(Box::new(DataType::Utf8)),
        ),
    };
    Ok(out)
}

fn handle_temporal_parsing_error(
ca: &Utf8Chunked,
out: &Series,
Expand Down
46 changes: 4 additions & 42 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,53 +214,15 @@ impl StringNameSpace {
}

/// Split the string by a substring. The resulting dtype is `List<Utf8>`.
pub fn split(self, by: &str) -> Expr {
let by = by.to_string();

let function = move |s: Series| {
let ca = s.utf8()?;

let mut builder = ListUtf8ChunkedBuilder::new(s.name(), s.len(), ca.get_values_size());
ca.into_iter().for_each(|opt_s| match opt_s {
None => builder.append_null(),
Some(s) => {
let iter = s.split(&by);
builder.append_values_iter(iter);
},
});
Ok(Some(builder.finish().into_series()))
};
pub fn split(self, by: Expr) -> Expr {
self.0
.map(
function,
GetOutput::from_type(DataType::List(Box::new(DataType::Utf8))),
)
.with_fmt("str.split")
.map_many_private(StringFunction::Split.into(), &[by], false)
}

/// Split the string by a substring and keep the substring. The resulting dtype is `List<Utf8>`.
pub fn split_inclusive(self, by: &str) -> Expr {
let by = by.to_string();

let function = move |s: Series| {
let ca = s.utf8()?;

let mut builder = ListUtf8ChunkedBuilder::new(s.name(), s.len(), ca.get_values_size());
ca.into_iter().for_each(|opt_s| match opt_s {
None => builder.append_null(),
Some(s) => {
let iter = s.split_inclusive(&by);
builder.append_values_iter(iter);
},
});
Ok(Some(builder.finish().into_series()))
};
pub fn split_inclusive(self, by: Expr) -> Expr {
self.0
.map(
function,
GetOutput::from_type(DataType::List(Box::new(DataType::Utf8))),
)
.with_fmt("str.split_inclusive")
.map_many_private(StringFunction::SplitInclusive.into(), &[by], false)
}

#[cfg(feature = "dtype-struct")]
Expand Down
2 changes: 1 addition & 1 deletion crates/polars/tests/it/lazy/explodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fn test_explode_row_numbers() -> PolarsResult<()> {
"text" => ["one two three four", "uno dos tres cuatro"]
]?
.lazy()
.select([col("text").str().split(" ").alias("tokens")])
.select([col("text").str().split(lit(" ")).alias("tokens")])
.with_row_count("row_nr", None)
.explode([col("tokens")])
.select([col("row_nr"), col("tokens")])
Expand Down
50 changes: 37 additions & 13 deletions py-polars/polars/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,7 +1511,7 @@ def count_matches(self, pattern: str | Expr, *, literal: bool = False) -> Expr:
pattern = parse_as_expression(pattern, str_as_lit=True)
return wrap_expr(self._pyexpr.str_count_matches(pattern, literal))

def split(self, by: str, *, inclusive: bool = False) -> Expr:
def split(self, by: str | Expr, *, inclusive: bool = False) -> Expr:
"""
Split the string by a substring.

Expand All @@ -1524,25 +1524,49 @@ def split(self, by: str, *, inclusive: bool = False) -> Expr:

Examples
--------
>>> df = pl.DataFrame({"s": ["foo bar", "foo-bar", "foo bar baz"]})
>>> df.select(pl.col("s").str.split(by=" "))
shape: (3, 1)
┌───────────────────────┐
│ s │
│ --- │
│ list[str] │
╞═══════════════════════╡
│ ["foo", "bar"] │
│ ["foo-bar"] │
│ ["foo", "bar", "baz"] │
└───────────────────────┘
>>> df = pl.DataFrame({"s": ["foo bar", "foo_bar", "foo_bar_baz"]})
>>> df.with_columns(
... pl.col("s").str.split(by="_").alias("split"),
... pl.col("s").str.split(by="_", inclusive=True).alias("split_inclusive"),
... )
shape: (3, 3)
┌─────────────┬───────────────────────┬─────────────────────────┐
│ s ┆ split ┆ split_inclusive │
│ --- ┆ --- ┆ --- │
│ str ┆ list[str] ┆ list[str] │
╞═════════════╪═══════════════════════╪═════════════════════════╡
│ foo bar ┆ ["foo bar"] ┆ ["foo bar"] │
│ foo_bar ┆ ["foo", "bar"] ┆ ["foo_", "bar"] │
│ foo_bar_baz ┆ ["foo", "bar", "baz"] ┆ ["foo_", "bar_", "baz"] │
└─────────────┴───────────────────────┴─────────────────────────┘

>>> df = pl.DataFrame(
... {"s": ["foo^bar", "foo_bar", "foo*bar*baz"], "by": ["_", "_", "*"]}
... )
>>> df.with_columns(
... pl.col("s").str.split(by=pl.col("by")).alias("split"),
... pl.col("s")
... .str.split(by=pl.col("by"), inclusive=True)
... .alias("split_inclusive"),
... )
shape: (3, 4)
┌─────────────┬─────┬───────────────────────┬─────────────────────────┐
│ s ┆ by ┆ split ┆ split_inclusive │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ list[str] ┆ list[str] │
╞═════════════╪═════╪═══════════════════════╪═════════════════════════╡
│ foo^bar ┆ _ ┆ ["foo^bar"] ┆ ["foo^bar"] │
│ foo_bar ┆ _ ┆ ["foo", "bar"] ┆ ["foo_", "bar"] │
│ foo*bar*baz ┆ * ┆ ["foo", "bar", "baz"] ┆ ["foo*", "bar*", "baz"] │
└─────────────┴─────┴───────────────────────┴─────────────────────────┘

Returns
-------
Expr
Expression of data type :class:`List` with inner data type :class:`Utf8`.

"""
by = parse_as_expression(by, str_as_lit=True)
if inclusive:
return wrap_expr(self._pyexpr.str_split_inclusive(by))
return wrap_expr(self._pyexpr.str_split(by))
Expand Down
8 changes: 4 additions & 4 deletions py-polars/src/expr/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,12 @@ impl PyExpr {
.into()
}

fn str_split(&self, by: &str) -> Self {
self.inner.clone().str().split(by).into()
fn str_split(&self, by: Self) -> Self {
self.inner.clone().str().split(by.inner).into()
}

fn str_split_inclusive(&self, by: &str) -> Self {
self.inner.clone().str().split_inclusive(by).into()
fn str_split_inclusive(&self, by: Self) -> Self {
self.inner.clone().str().split_inclusive(by.inner).into()
}

fn str_split_exact(&self, by: &str, n: usize) -> Self {
Expand Down
25 changes: 25 additions & 0 deletions py-polars/tests/unit/namespaces/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,31 @@ def test_split() -> None:
assert_frame_equal(df["x"].str.split("_", inclusive=True).to_frame(), expected)


def test_split_expr() -> None:
    """Check `str.split` with a per-row separator column, with and without `inclusive`."""
    df = pl.DataFrame({"x": ["a_a", None, "b", "c*c*c"], "by": ["_", "#", "^", "*"]})

    # (inclusive flag, expected list value for each row of "x")
    cases = [
        (False, [["a", "a"], None, ["b"], ["c", "c", "c"]]),
        (True, [["a_", "a"], None, ["b"], ["c*", "c*", "c"]]),
    ]

    for inclusive, rows in cases:
        out = df.select([pl.col("x").str.split(pl.col("by"), inclusive=inclusive)])
        expected = pl.DataFrame([{"x": row} for row in rows])
        assert_frame_equal(out, expected)


def test_split_exact() -> None:
df = pl.DataFrame({"x": ["a_a", None, "b", "c_c"]})
out = df.select([pl.col("x").str.split_exact("_", 2, inclusive=False)]).unnest("x")
Expand Down
Loading