Skip to content

Commit

Permalink
feat: Expressify str.split argument.
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Sep 14, 2023
1 parent 826a1e3 commit da3e144
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 60 deletions.
18 changes: 18 additions & 0 deletions crates/polars-core/src/chunked_array/ops/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,24 @@ where
ChunkedArray::from_chunk_iter(lhs.name(), iter)
}

#[inline]
pub fn binary_elementwise_for_each<T, U, F>(lhs: &ChunkedArray<T>, rhs: &ChunkedArray<U>, mut op: F)
where
T: PolarsDataType,
U: PolarsDataType,
F: for<'a> FnMut(Option<T::Physical<'a>>, Option<U::Physical<'a>>),
{
let (lhs, rhs) = align_chunks_binary(lhs, rhs);
lhs.downcast_iter()
.zip(rhs.downcast_iter())
.for_each(|(lhs_arr, rhs_arr)| {
lhs_arr
.iter()
.zip(rhs_arr.iter())
.for_each(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val));
})
}

#[inline]
pub fn try_binary_elementwise<T, U, V, F, K, E>(
lhs: &ChunkedArray<T>,
Expand Down
67 changes: 66 additions & 1 deletion crates/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use polars_arrow::kernels::string::*;
#[cfg(feature = "string_from_radix")]
use polars_core::export::num::Num;
use polars_core::export::regex::Regex;
use polars_core::prelude::arity::try_binary_elementwise;
use polars_core::prelude::arity::{binary_elementwise_for_each, try_binary_elementwise};
use polars_core::utils::rayon::iter::split;
use polars_utils::cache::FastFixedCache;
use regex::escape;

Expand Down Expand Up @@ -311,6 +312,70 @@ pub trait Utf8NameSpaceImpl: AsUtf8 {
Ok(builder.finish())
}

fn split(&self, by: &str) -> ListChunked {
let ca = self.as_utf8();
let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size());

ca.downcast_iter().for_each(|arr| {
arr.iter().for_each(|val| match val {
Some(val) => {
let iter = val.split(by);
builder.append_values_iter(iter)
},
_ => builder.append_null(),
})
});
builder.finish()
}

fn split_many(&self, by: &Utf8Chunked) -> ListChunked {
let ca = self.as_utf8();

let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size());

binary_elementwise_for_each(ca, by, |opt_s, opt_by| match (opt_s, opt_by) {
(Some(s), Some(by)) => {
let iter = s.split(by);
builder.append_values_iter(iter);
},
_ => builder.append_null(),
});

builder.finish()
}

fn split_inclusive(&self, by: &str) -> ListChunked {
let ca = self.as_utf8();
let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size());

ca.downcast_iter().for_each(|arr| {
arr.iter().for_each(|val| match val {
Some(val) => {
let iter = val.split_inclusive(by);
builder.append_values_iter(iter)
},
_ => builder.append_null(),
})
});
builder.finish()
}

fn split_inclusive_many(&self, by: &Utf8Chunked) -> ListChunked {
let ca = self.as_utf8();

let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size());

binary_elementwise_for_each(ca, by, |opt_s, opt_by| match (opt_s, opt_by) {
(Some(s), Some(by)) => {
let iter = s.split_inclusive(by);
builder.append_values_iter(iter);
},
_ => builder.append_null(),
});

builder.finish()
}

/// Extract each successive non-overlapping regex match in an individual string as an array.
fn extract_all_many(&self, pat: &Utf8Chunked) -> PolarsResult<ListChunked> {
let ca = self.as_utf8();
Expand Down
6 changes: 6 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,12 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
Strptime(dtype, options) => {
map_as_slice!(strings::strptime, dtype.clone(), &options)
},
Split => {
map_as_slice!(strings::split)
},
SplitInclusive => {
map_as_slice!(strings::split_inclusive)
},
#[cfg(feature = "concat_str")]
ConcatVertical(delimiter) => map!(strings::concat, &delimiter),
#[cfg(feature = "concat_str")]
Expand Down
43 changes: 43 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ pub enum StringFunction {
StripSuffix(String),
#[cfg(feature = "temporal")]
Strptime(DataType, StrptimeOptions),
Split,
SplitInclusive,
#[cfg(feature = "dtype-decimal")]
ToDecimal(usize),
#[cfg(feature = "nightly")]
Expand Down Expand Up @@ -109,6 +111,7 @@ impl StringFunction {
Replace { .. } => mapper.with_same_dtype(),
#[cfg(feature = "temporal")]
Strptime(dtype, _) => mapper.with_dtype(dtype.clone()),
Split | SplitInclusive => mapper.with_dtype(DataType::List(Box::new(DataType::Utf8))),
#[cfg(feature = "nightly")]
Titlecase => mapper.with_same_dtype(),
#[cfg(feature = "dtype-decimal")]
Expand Down Expand Up @@ -165,6 +168,8 @@ impl Display for StringFunction {
StringFunction::StripSuffix(_) => "strip_suffix",
#[cfg(feature = "temporal")]
StringFunction::Strptime(_, _) => "strptime",
StringFunction::Split => "split",
StringFunction::SplitInclusive => "split_inclusive",
#[cfg(feature = "nightly")]
StringFunction::Titlecase => "titlecase",
#[cfg(feature = "dtype-decimal")]
Expand Down Expand Up @@ -441,6 +446,44 @@ pub(super) fn strptime(
}
}

pub(super) fn split(s: &[Series]) -> PolarsResult<Series> {
let ca = s[0].utf8()?;
let by = s[1].utf8()?;

if by.len() == 1 {
if let Some(by) = by.get(0) {
Ok(ca.split(by).into_series())
} else {
Ok(Series::full_null(
ca.name(),
ca.len(),
&DataType::List(Box::new(DataType::Utf8)),
))
}
} else {
Ok(ca.split_many(by).into_series())
}
}

pub(super) fn split_inclusive(s: &[Series]) -> PolarsResult<Series> {
let ca = s[0].utf8()?;
let by = s[1].utf8()?;

if by.len() == 1 {
if let Some(by) = by.get(0) {
Ok(ca.split_inclusive(by).into_series())
} else {
Ok(Series::full_null(
ca.name(),
ca.len(),
&DataType::List(Box::new(DataType::Utf8)),
))
}
} else {
Ok(ca.split_inclusive_many(by).into_series())
}
}

fn handle_temporal_parsing_error(
ca: &Utf8Chunked,
out: &Series,
Expand Down
46 changes: 4 additions & 42 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,53 +214,15 @@ impl StringNameSpace {
}

/// Split the string by a substring. The resulting dtype is `List<Utf8>`.
pub fn split(self, by: &str) -> Expr {
let by = by.to_string();

let function = move |s: Series| {
let ca = s.utf8()?;

let mut builder = ListUtf8ChunkedBuilder::new(s.name(), s.len(), ca.get_values_size());
ca.into_iter().for_each(|opt_s| match opt_s {
None => builder.append_null(),
Some(s) => {
let iter = s.split(&by);
builder.append_values_iter(iter);
},
});
Ok(Some(builder.finish().into_series()))
};
pub fn split(self, by: Expr) -> Expr {
self.0
.map(
function,
GetOutput::from_type(DataType::List(Box::new(DataType::Utf8))),
)
.with_fmt("str.split")
.map_many_private(StringFunction::Split.into(), &[by], false)
}

/// Split the string by a substring and keep the substring. The resulting dtype is `List<Utf8>`.
pub fn split_inclusive(self, by: &str) -> Expr {
let by = by.to_string();

let function = move |s: Series| {
let ca = s.utf8()?;

let mut builder = ListUtf8ChunkedBuilder::new(s.name(), s.len(), ca.get_values_size());
ca.into_iter().for_each(|opt_s| match opt_s {
None => builder.append_null(),
Some(s) => {
let iter = s.split_inclusive(&by);
builder.append_values_iter(iter);
},
});
Ok(Some(builder.finish().into_series()))
};
pub fn split_inclusive(self, by: Expr) -> Expr {
self.0
.map(
function,
GetOutput::from_type(DataType::List(Box::new(DataType::Utf8))),
)
.with_fmt("str.split_inclusive")
.map_many_private(StringFunction::SplitInclusive.into(), &[by], false)
}

#[cfg(feature = "dtype-struct")]
Expand Down
50 changes: 37 additions & 13 deletions py-polars/polars/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,7 +1511,7 @@ def count_matches(self, pattern: str | Expr, *, literal: bool = False) -> Expr:
pattern = parse_as_expression(pattern, str_as_lit=True)
return wrap_expr(self._pyexpr.str_count_matches(pattern, literal))

def split(self, by: str, *, inclusive: bool = False) -> Expr:
def split(self, by: str | Expr, *, inclusive: bool = False) -> Expr:
"""
Split the string by a substring.
Expand All @@ -1524,25 +1524,49 @@ def split(self, by: str, *, inclusive: bool = False) -> Expr:
Examples
--------
>>> df = pl.DataFrame({"s": ["foo bar", "foo-bar", "foo bar baz"]})
>>> df.select(pl.col("s").str.split(by=" "))
shape: (3, 1)
┌───────────────────────┐
│ s │
│ --- │
│ list[str] │
╞═══════════════════════╡
│ ["foo", "bar"] │
│ ["foo-bar"] │
│ ["foo", "bar", "baz"] │
└───────────────────────┘
>>> df = pl.DataFrame({"s": ["foo bar", "foo_bar", "foo_bar_baz"]})
>>> df.with_columns(
... pl.col("s").str.split(by="_").alias("split"),
... pl.col("s").str.split(by="_", inclusive=True).alias("split_inclusive"),
... )
shape: (3, 3)
┌─────────────┬───────────────────────┬─────────────────────────┐
│ s ┆ split ┆ split_inclusive │
│ --- ┆ --- ┆ --- │
│ str ┆ list[str] ┆ list[str] │
╞═════════════╪═══════════════════════╪═════════════════════════╡
│ foo bar ┆ ["foo bar"] ┆ ["foo bar"] │
│ foo_bar ┆ ["foo", "bar"] ┆ ["foo_", "bar"] │
│ foo_bar_baz ┆ ["foo", "bar", "baz"] ┆ ["foo_", "bar_", "baz"] │
└─────────────┴───────────────────────┴─────────────────────────┘
>>> df = pl.DataFrame(
... {"s": ["foo^bar", "foo_bar", "foo*bar*baz"], "by": ["_", "_", "*"]}
... )
>>> df.with_columns(
... pl.col("s").str.split(by=pl.col("by")).alias("split"),
... pl.col("s")
... .str.split(by=pl.col("by"), inclusive=True)
... .alias("split_inclusive"),
... )
shape: (3, 4)
┌─────────────┬─────┬───────────────────────┬─────────────────────────┐
│ s ┆ by ┆ split ┆ split_inclusive │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ list[str] ┆ list[str] │
╞═════════════╪═════╪═══════════════════════╪═════════════════════════╡
│ foo^bar ┆ _ ┆ ["foo^bar"] ┆ ["foo^bar"] │
│ foo_bar ┆ _ ┆ ["foo", "bar"] ┆ ["foo_", "bar"] │
│ foo*bar*baz ┆ * ┆ ["foo", "bar", "baz"] ┆ ["foo*", "bar*", "baz"] │
└─────────────┴─────┴───────────────────────┴─────────────────────────┘
Returns
-------
Expr
Expression of data type :class:`Utf8`.
"""
by = parse_as_expression(by, str_as_lit=True)
if inclusive:
return wrap_expr(self._pyexpr.str_split_inclusive(by))
return wrap_expr(self._pyexpr.str_split(by))
Expand Down
8 changes: 4 additions & 4 deletions py-polars/src/expr/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,12 @@ impl PyExpr {
.into()
}

fn str_split(&self, by: &str) -> Self {
self.inner.clone().str().split(by).into()
fn str_split(&self, by: Self) -> Self {
self.inner.clone().str().split(by.inner).into()
}

fn str_split_inclusive(&self, by: &str) -> Self {
self.inner.clone().str().split_inclusive(by).into()
fn str_split_inclusive(&self, by: Self) -> Self {
self.inner.clone().str().split_inclusive(by.inner).into()
}

fn str_split_exact(&self, by: &str, n: usize) -> Self {
Expand Down
25 changes: 25 additions & 0 deletions py-polars/tests/unit/namespaces/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,31 @@ def test_split() -> None:
assert_frame_equal(df["x"].str.split("_", inclusive=True).to_frame(), expected)


def test_split_expr() -> None:
df = pl.DataFrame({"x": ["a_a", None, "b", "c*c*c"], "by": ["_", "#", "^", "*"]})
out = df.select([pl.col("x").str.split(pl.col("by"))])
expected = pl.DataFrame(
[
{"x": ["a", "a"]},
{"x": None},
{"x": ["b"]},
{"x": ["c", "c", "c"]},
]
)
assert_frame_equal(out, expected)

out = df.select([pl.col("x").str.split(pl.col("by"), inclusive=True)])
expected = pl.DataFrame(
[
{"x": ["a_", "a"]},
{"x": None},
{"x": ["b"]},
{"x": ["c*", "c*", "c"]},
]
)
assert_frame_equal(out, expected)


def test_split_exact() -> None:
df = pl.DataFrame({"x": ["a_a", None, "b", "c_c"]})
out = df.select([pl.col("x").str.split_exact("_", 2, inclusive=False)]).unnest("x")
Expand Down

0 comments on commit da3e144

Please sign in to comment.