Skip to content

Commit

Permalink
feat: Expressify str.strip_prefix & suffix (#11197)
Browse files Browse the repository at this point in the history
Co-authored-by: Orson Peters <[email protected]>
  • Loading branch information
reswqa and orlp authored Sep 21, 2023
1 parent 6c8b9b7 commit 431f85f
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 67 deletions.
18 changes: 15 additions & 3 deletions crates/polars-core/src/chunked_array/ops/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,18 @@ use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter, StaticArray};
use crate::prelude::{ChunkedArray, PolarsDataType};
use crate::utils::{align_chunks_binary, align_chunks_ternary};

// We need this helper because for<'a> notation can't yet be applied properly
// on the return type.
pub trait BinaryFnMut<A1, A2>: FnMut(A1, A2) -> Self::Ret {
type Ret;
}

impl<A1, A2, R, T: FnMut(A1, A2) -> R> BinaryFnMut<A1, A2> for T {
type Ret = R;
}

#[inline]
pub fn binary_elementwise<T, U, V, F, K>(
pub fn binary_elementwise<T, U, V, F>(
lhs: &ChunkedArray<T>,
rhs: &ChunkedArray<U>,
mut op: F,
Expand All @@ -17,8 +27,10 @@ where
T: PolarsDataType,
U: PolarsDataType,
V: PolarsDataType,
F: for<'a> FnMut(Option<T::Physical<'a>>, Option<U::Physical<'a>>) -> Option<K>,
V::Array: ArrayFromIter<Option<K>>,
F: for<'a> BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>,
V::Array: for<'a> ArrayFromIter<
<F as BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>>::Ret,
>,
{
let (lhs, rhs) = align_chunks_binary(lhs, rhs);
let iter = lhs
Expand Down
22 changes: 8 additions & 14 deletions crates/polars-core/src/series/arithmetic/borrowed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,8 @@ pub mod checked {
// see check_div for chunkedarray<T>
let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) };

Ok(arity::binary_elementwise::<_, _, Float32Type, _, _>(
lhs,
rhs,
|opt_l, opt_r| match (opt_l, opt_r) {
let ca: Float32Chunked =
arity::binary_elementwise(lhs, rhs, |opt_l, opt_r| match (opt_l, opt_r) {
(Some(l), Some(r)) => {
if r.is_zero() {
None
Expand All @@ -189,9 +187,8 @@ pub mod checked {
}
},
_ => None,
},
)
.into_series())
});
Ok(ca.into_series())
}
}

Expand All @@ -201,10 +198,8 @@ pub mod checked {
// see check_div
let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) };

Ok(arity::binary_elementwise::<_, _, Float64Type, _, _>(
lhs,
rhs,
|opt_l, opt_r| match (opt_l, opt_r) {
let ca: Float64Chunked =
arity::binary_elementwise(lhs, rhs, |opt_l, opt_r| match (opt_l, opt_r) {
(Some(l), Some(r)) => {
if r.is_zero() {
None
Expand All @@ -213,9 +208,8 @@ pub mod checked {
}
},
_ => None,
},
)
.into_series())
});
Ok(ca.into_series())
}
}

Expand Down
60 changes: 51 additions & 9 deletions crates/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,23 @@ use super::*;
#[cfg(feature = "binary_encoding")]
use crate::chunked_array::binary::BinaryNameSpaceImpl;

// We need this to infer the right lifetimes for the match closure.
#[inline(always)]
fn infer_re_match<F>(f: F) -> F
where
F: for<'a, 'b> FnMut(Option<&'a str>, Option<&'b str>) -> Option<bool>,
{
f
}

fn opt_strip_prefix<'a>(s: Option<&'a str>, prefix: Option<&str>) -> Option<&'a str> {
Some(s?.strip_prefix(prefix?).unwrap_or(s?))
}

fn opt_strip_suffix<'a>(s: Option<&'a str>, suffix: Option<&str>) -> Option<&'a str> {
Some(s?.strip_suffix(suffix?).unwrap_or(s?))
}

pub trait Utf8NameSpaceImpl: AsUtf8 {
#[cfg(not(feature = "binary_encoding"))]
fn hex_decode(&self) -> PolarsResult<Utf8Chunked> {
Expand Down Expand Up @@ -122,15 +139,14 @@ pub trait Utf8NameSpaceImpl: AsUtf8 {
} else {
// A sqrt(n) regex cache is not too small, not too large.
let mut reg_cache = FastFixedCache::new((ca.len() as f64).sqrt() as usize);
Ok(binary_elementwise(ca, pat, |opt_src, opt_pat| {
match (opt_src, opt_pat) {
(Some(src), Some(pat)) => {
let reg = reg_cache.try_get_or_insert_with(pat, |p| Regex::new(p));
reg.ok().map(|re| re.is_match(src))
},
_ => None,
}
}))
Ok(binary_elementwise(
ca,
pat,
infer_re_match(|src, pat| {
let reg = reg_cache.try_get_or_insert_with(pat?, |p| Regex::new(p));
Some(reg.ok()?.is_match(src?))
}),
))
}
},
}
Expand Down Expand Up @@ -334,6 +350,32 @@ pub trait Utf8NameSpaceImpl: AsUtf8 {
Ok(builder.finish())
}

fn strip_prefix(&self, prefix: &Utf8Chunked) -> Utf8Chunked {
let ca = self.as_utf8();
match prefix.len() {
1 => match prefix.get(0) {
Some(prefix) => {
ca.apply_generic(|opt_s| opt_s.map(|s| s.strip_prefix(prefix).unwrap_or(s)))
},
_ => Utf8Chunked::full_null(ca.name(), ca.len()),
},
_ => binary_elementwise(ca, prefix, opt_strip_prefix),
}
}

fn strip_suffix(&self, suffix: &Utf8Chunked) -> Utf8Chunked {
let ca = self.as_utf8();
match suffix.len() {
1 => match suffix.get(0) {
Some(suffix) => {
ca.apply_generic(|opt_s| opt_s.map(|s| s.strip_suffix(suffix).unwrap_or(s)))
},
_ => Utf8Chunked::full_null(ca.name(), ca.len()),
},
_ => binary_elementwise(ca, suffix, opt_strip_suffix),
}
}

fn split(&self, by: &str) -> ListChunked {
let ca = self.as_utf8();
let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size());
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -738,8 +738,8 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
StripChars(matches) => map!(strings::strip_chars, matches.as_deref()),
StripCharsStart(matches) => map!(strings::strip_chars_start, matches.as_deref()),
StripCharsEnd(matches) => map!(strings::strip_chars_end, matches.as_deref()),
StripPrefix(prefix) => map!(strings::strip_prefix, &prefix),
StripSuffix(suffix) => map!(strings::strip_suffix, &suffix),
StripPrefix => map_as_slice!(strings::strip_prefix),
StripSuffix => map_as_slice!(strings::strip_suffix),
#[cfg(feature = "string_from_radix")]
FromRadix(radix, strict) => map!(strings::from_radix, radix, strict),
Slice(start, length) => map!(strings::str_slice, start, length),
Expand Down
30 changes: 14 additions & 16 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ pub enum StringFunction {
StripChars(Option<String>),
StripCharsStart(Option<String>),
StripCharsEnd(Option<String>),
StripPrefix(String),
StripSuffix(String),
StripPrefix,
StripSuffix,
#[cfg(feature = "temporal")]
Strptime(DataType, StrptimeOptions),
Split,
Expand Down Expand Up @@ -121,8 +121,8 @@ impl StringFunction {
| StripChars(_)
| StripCharsStart(_)
| StripCharsEnd(_)
| StripPrefix(_)
| StripSuffix(_)
| StripPrefix
| StripSuffix
| Slice(_, _) => mapper.with_same_dtype(),
#[cfg(feature = "string_justify")]
Zfill { .. } | LJust { .. } | RJust { .. } => mapper.with_same_dtype(),
Expand Down Expand Up @@ -164,8 +164,8 @@ impl Display for StringFunction {
StringFunction::StripChars(_) => "strip_chars",
StringFunction::StripCharsStart(_) => "strip_chars_start",
StringFunction::StripCharsEnd(_) => "strip_chars_end",
StringFunction::StripPrefix(_) => "strip_prefix",
StringFunction::StripSuffix(_) => "strip_suffix",
StringFunction::StripPrefix => "strip_prefix",
StringFunction::StripSuffix => "strip_suffix",
#[cfg(feature = "temporal")]
StringFunction::Strptime(_, _) => "strptime",
StringFunction::Split => "split",
Expand Down Expand Up @@ -325,18 +325,16 @@ pub(super) fn strip_chars_end(s: &Series, matches: Option<&str>) -> PolarsResult
}
}

pub(super) fn strip_prefix(s: &Series, prefix: &str) -> PolarsResult<Series> {
let ca = s.utf8()?;
Ok(ca
.apply_values(|s| Cow::Borrowed(s.strip_prefix(prefix).unwrap_or(s)))
.into_series())
pub(super) fn strip_prefix(s: &[Series]) -> PolarsResult<Series> {
let ca = s[0].utf8()?;
let prefix = s[1].utf8()?;
Ok(ca.strip_prefix(prefix).into_series())
}

pub(super) fn strip_suffix(s: &Series, suffix: &str) -> PolarsResult<Series> {
let ca = s.utf8()?;
Ok(ca
.apply_values(|s| Cow::Borrowed(s.strip_suffix(suffix).unwrap_or(s)))
.into_series())
pub(super) fn strip_suffix(s: &[Series]) -> PolarsResult<Series> {
let ca = s[0].utf8()?;
let suffix = s[1].utf8()?;
Ok(ca.strip_suffix(suffix).into_series())
}

pub(super) fn extract_all(args: &[Series]) -> PolarsResult<Series> {
Expand Down
22 changes: 12 additions & 10 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,19 +441,21 @@ impl StringNameSpace {
}

/// Remove prefix.
pub fn strip_prefix(self, prefix: String) -> Expr {
self.0
.map_private(FunctionExpr::StringExpr(StringFunction::StripPrefix(
prefix,
)))
pub fn strip_prefix(self, prefix: Expr) -> Expr {
self.0.map_many_private(
FunctionExpr::StringExpr(StringFunction::StripPrefix),
&[prefix],
false,
)
}

/// Remove suffix.
pub fn strip_suffix(self, suffix: String) -> Expr {
self.0
.map_private(FunctionExpr::StringExpr(StringFunction::StripSuffix(
suffix,
)))
pub fn strip_suffix(self, suffix: Expr) -> Expr {
self.0.map_many_private(
FunctionExpr::StringExpr(StringFunction::StripSuffix),
&[suffix],
false,
)
}

/// Convert all characters to lowercase.
Expand Down
6 changes: 4 additions & 2 deletions py-polars/polars/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def strip_chars_end(self, characters: str | None = None) -> Expr:
"""
return wrap_expr(self._pyexpr.str_strip_chars_end(characters))

def strip_prefix(self, prefix: str) -> Expr:
def strip_prefix(self, prefix: IntoExpr) -> Expr:
"""
Remove prefix.
Expand Down Expand Up @@ -708,9 +708,10 @@ def strip_prefix(self, prefix: str) -> Expr:
└───────────┴──────────┘
"""
prefix = parse_as_expression(prefix, str_as_lit=True)
return wrap_expr(self._pyexpr.str_strip_prefix(prefix))

def strip_suffix(self, suffix: str) -> Expr:
def strip_suffix(self, suffix: IntoExpr) -> Expr:
"""
Remove suffix.
Expand Down Expand Up @@ -738,6 +739,7 @@ def strip_suffix(self, suffix: str) -> Expr:
└───────────┴──────────┘
"""
suffix = parse_as_expression(suffix, str_as_lit=True)
return wrap_expr(self._pyexpr.str_strip_suffix(suffix))

def zfill(self, alignment: int) -> Expr:
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/series/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,7 +1208,7 @@ def strip_chars_end(self, characters: str | None = None) -> Series:
"""

def strip_prefix(self, prefix: str) -> Series:
def strip_prefix(self, prefix: IntoExpr) -> Series:
"""
Remove prefix.
Expand All @@ -1234,7 +1234,7 @@ def strip_prefix(self, prefix: str) -> Series:
"""

def strip_suffix(self, suffix: str) -> Series:
def strip_suffix(self, suffix: IntoExpr) -> Series:
"""
Remove suffix.
Expand Down
8 changes: 4 additions & 4 deletions py-polars/src/expr/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ impl PyExpr {
self.inner.clone().str().strip_chars_end(matches).into()
}

fn str_strip_prefix(&self, prefix: String) -> Self {
self.inner.clone().str().strip_prefix(prefix).into()
fn str_strip_prefix(&self, prefix: Self) -> Self {
self.inner.clone().str().strip_prefix(prefix.inner).into()
}

fn str_strip_suffix(&self, suffix: String) -> Self {
self.inner.clone().str().strip_suffix(suffix).into()
fn str_strip_suffix(&self, suffix: Self) -> Self {
self.inner.clone().str().strip_suffix(suffix.inner).into()
}

fn str_slice(&self, start: i64, length: Option<u64>) -> Self {
Expand Down
34 changes: 29 additions & 5 deletions py-polars/tests/unit/namespaces/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,40 @@ def test_str_strip_deprecated() -> None:
pl.Series(["a", "b", "c"]).str.rstrip()


def test_str_strip_prefix() -> None:
s = pl.Series(["foo:bar", "foofoo:bar", "bar:bar", "foo", ""])
expected = pl.Series([":bar", "foo:bar", "bar:bar", "", ""])
def test_str_strip_prefix_literal() -> None:
s = pl.Series(["foo:bar", "foofoo:bar", "bar:bar", "foo", "", None])
expected = pl.Series([":bar", "foo:bar", "bar:bar", "", "", None])
assert_series_equal(s.str.strip_prefix("foo"), expected)
# test null literal
expected = pl.Series([None, None, None, None, None, None], dtype=pl.Utf8)
assert_series_equal(s.str.strip_prefix(pl.lit(None, dtype=pl.Utf8)), expected)


def test_str_strip_prefix_suffix_expr() -> None:
df = pl.DataFrame(
{
"s": ["foo-bar", "foobarbar", "barfoo", "", "anything", None],
"prefix": ["foo", "foobar", "foo", "", None, "bar"],
"suffix": ["bar", "barbar", "bar", "", None, "foo"],
}
)
out = df.select(
pl.col("s").str.strip_prefix(pl.col("prefix")).alias("strip_prefix"),
pl.col("s").str.strip_suffix(pl.col("suffix")).alias("strip_suffix"),
)
assert out.to_dict(False) == {
"strip_prefix": ["-bar", "bar", "barfoo", "", None, None],
"strip_suffix": ["foo-", "foo", "barfoo", "", None, None],
}


def test_str_strip_suffix() -> None:
s = pl.Series(["foo:bar", "foo:barbar", "foo:foo", "bar", ""])
expected = pl.Series(["foo:", "foo:bar", "foo:foo", "", ""])
s = pl.Series(["foo:bar", "foo:barbar", "foo:foo", "bar", "", None])
expected = pl.Series(["foo:", "foo:bar", "foo:foo", "", "", None])
assert_series_equal(s.str.strip_suffix("bar"), expected)
# test null literal
expected = pl.Series([None, None, None, None, None, None], dtype=pl.Utf8)
assert_series_equal(s.str.strip_suffix(pl.lit(None, dtype=pl.Utf8)), expected)


def test_str_split() -> None:
Expand Down

0 comments on commit 431f85f

Please sign in to comment.