diff --git a/crates/polars-core/src/chunked_array/ops/arity.rs b/crates/polars-core/src/chunked_array/ops/arity.rs index 6bd63a488c10..e275cf202350 100644 --- a/crates/polars-core/src/chunked_array/ops/arity.rs +++ b/crates/polars-core/src/chunked_array/ops/arity.rs @@ -7,8 +7,18 @@ use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter, StaticArray}; use crate::prelude::{ChunkedArray, PolarsDataType}; use crate::utils::{align_chunks_binary, align_chunks_ternary}; +// We need this helper because for<'a> notation can't yet be applied properly +// on the return type. +pub trait BinaryFnMut: FnMut(A1, A2) -> Self::Ret { + type Ret; +} + +impl R> BinaryFnMut for T { + type Ret = R; +} + #[inline] -pub fn binary_elementwise( +pub fn binary_elementwise( lhs: &ChunkedArray, rhs: &ChunkedArray, mut op: F, @@ -17,8 +27,10 @@ where T: PolarsDataType, U: PolarsDataType, V: PolarsDataType, - F: for<'a> FnMut(Option>, Option>) -> Option, - V::Array: ArrayFromIter>, + F: for<'a> BinaryFnMut>, Option>>, + V::Array: for<'a> ArrayFromIter< + >, Option>>>::Ret, + >, { let (lhs, rhs) = align_chunks_binary(lhs, rhs); let iter = lhs diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index bb04ddc9c976..6bcda0bfec29 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -177,10 +177,8 @@ pub mod checked { // see check_div for chunkedarray let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; - Ok(arity::binary_elementwise::<_, _, Float32Type, _, _>( - lhs, - rhs, - |opt_l, opt_r| match (opt_l, opt_r) { + let ca: Float32Chunked = + arity::binary_elementwise(lhs, rhs, |opt_l, opt_r| match (opt_l, opt_r) { (Some(l), Some(r)) => { if r.is_zero() { None @@ -189,9 +187,8 @@ pub mod checked { } }, _ => None, - }, - ) - .into_series()) + }); + Ok(ca.into_series()) } } @@ -201,10 +198,8 @@ pub mod checked { // see check_div let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; - Ok(arity::binary_elementwise::<_, _, Float64Type, _, _>( - lhs, - rhs, - |opt_l, opt_r| match (opt_l, opt_r) { + let ca: Float64Chunked = + arity::binary_elementwise(lhs, rhs, |opt_l, opt_r| match (opt_l, opt_r) { (Some(l), Some(r)) => { if r.is_zero() { None @@ -213,9 +208,8 @@ pub mod checked { } }, _ => None, - }, - ) - .into_series()) + }); + Ok(ca.into_series()) } } diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs index fae3a127143c..de73bd294c52 100644 --- a/crates/polars-ops/src/chunked_array/strings/namespace.rs +++ b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -14,6 +14,23 @@ use super::*; #[cfg(feature = "binary_encoding")] use crate::chunked_array::binary::BinaryNameSpaceImpl; +// We need this to infer the right lifetimes for the match closure. +#[inline(always)] +fn infer_re_match(f: F) -> F +where + F: for<'a, 'b> FnMut(Option<&'a str>, Option<&'b str>) -> Option, +{ + f +} + +fn opt_strip_prefix<'a>(s: Option<&'a str>, prefix: Option<&str>) -> Option<&'a str> { + Some(s?.strip_prefix(prefix?).unwrap_or(s?)) +} + +fn opt_strip_suffix<'a>(s: Option<&'a str>, suffix: Option<&str>) -> Option<&'a str> { + Some(s?.strip_suffix(suffix?).unwrap_or(s?)) +} + pub trait Utf8NameSpaceImpl: AsUtf8 { #[cfg(not(feature = "binary_encoding"))] fn hex_decode(&self) -> PolarsResult { @@ -122,15 +139,14 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { } else { // A sqrt(n) regex cache is not too small, not too large. let mut reg_cache = FastFixedCache::new((ca.len() as f64).sqrt() as usize); - Ok(binary_elementwise(ca, pat, |opt_src, opt_pat| { - match (opt_src, opt_pat) { - (Some(src), Some(pat)) => { - let reg = reg_cache.try_get_or_insert_with(pat, |p| Regex::new(p)); - reg.ok().map(|re| re.is_match(src)) - }, - _ => None, - } - })) + Ok(binary_elementwise( + ca, + pat, + infer_re_match(|src, pat| { + let reg = reg_cache.try_get_or_insert_with(pat?, |p| Regex::new(p)); + Some(reg.ok()?.is_match(src?)) + }), + )) } }, } @@ -334,6 +350,32 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { Ok(builder.finish()) } + fn strip_prefix(&self, prefix: &Utf8Chunked) -> Utf8Chunked { + let ca = self.as_utf8(); + match prefix.len() { + 1 => match prefix.get(0) { + Some(prefix) => { + ca.apply_generic(|opt_s| opt_s.map(|s| s.strip_prefix(prefix).unwrap_or(s))) + }, + _ => Utf8Chunked::full_null(ca.name(), ca.len()), + }, + _ => binary_elementwise(ca, prefix, opt_strip_prefix), + } + } + + fn strip_suffix(&self, suffix: &Utf8Chunked) -> Utf8Chunked { + let ca = self.as_utf8(); + match suffix.len() { + 1 => match suffix.get(0) { + Some(suffix) => { + ca.apply_generic(|opt_s| opt_s.map(|s| s.strip_suffix(suffix).unwrap_or(s))) + }, + _ => Utf8Chunked::full_null(ca.name(), ca.len()), + }, + _ => binary_elementwise(ca, suffix, opt_strip_suffix), + } + } + fn split(&self, by: &str) -> ListChunked { let ca = self.as_utf8(); let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 12b7341c35bd..de10aa810c74 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -738,8 +738,8 @@ impl From for SpecialEq> { StripChars(matches) => map!(strings::strip_chars, matches.as_deref()), StripCharsStart(matches) => map!(strings::strip_chars_start, matches.as_deref()), StripCharsEnd(matches) => map!(strings::strip_chars_end, matches.as_deref()), - StripPrefix(prefix) => map!(strings::strip_prefix, &prefix), - StripSuffix(suffix) => map!(strings::strip_suffix, &suffix), + StripPrefix => map_as_slice!(strings::strip_prefix), + StripSuffix => map_as_slice!(strings::strip_suffix), #[cfg(feature = "string_from_radix")] FromRadix(radix, strict) => map!(strings::from_radix, radix, strict), Slice(start, length) => map!(strings::str_slice, start, length), diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs index d9f72c0b1ff6..12260dba49dd 100644 --- a/crates/polars-plan/src/dsl/function_expr/strings.rs +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -71,8 +71,8 @@ pub enum StringFunction { StripChars(Option), StripCharsStart(Option), StripCharsEnd(Option), - StripPrefix(String), - StripSuffix(String), + StripPrefix, + StripSuffix, #[cfg(feature = "temporal")] Strptime(DataType, StrptimeOptions), Split, @@ -121,8 +121,8 @@ impl StringFunction { | StripChars(_) | StripCharsStart(_) | StripCharsEnd(_) - | StripPrefix(_) - | StripSuffix(_) + | StripPrefix + | StripSuffix | Slice(_, _) => mapper.with_same_dtype(), #[cfg(feature = "string_justify")] Zfill { .. } | LJust { .. } | RJust { .. } => mapper.with_same_dtype(), @@ -164,8 +164,8 @@ impl Display for StringFunction { StringFunction::StripChars(_) => "strip_chars", StringFunction::StripCharsStart(_) => "strip_chars_start", StringFunction::StripCharsEnd(_) => "strip_chars_end", - StringFunction::StripPrefix(_) => "strip_prefix", - StringFunction::StripSuffix(_) => "strip_suffix", + StringFunction::StripPrefix => "strip_prefix", + StringFunction::StripSuffix => "strip_suffix", #[cfg(feature = "temporal")] StringFunction::Strptime(_, _) => "strptime", StringFunction::Split => "split", @@ -325,18 +325,16 @@ pub(super) fn strip_chars_end(s: &Series, matches: Option<&str>) -> PolarsResult } } -pub(super) fn strip_prefix(s: &Series, prefix: &str) -> PolarsResult { - let ca = s.utf8()?; - Ok(ca - .apply_values(|s| Cow::Borrowed(s.strip_prefix(prefix).unwrap_or(s))) - .into_series()) +pub(super) fn strip_prefix(s: &[Series]) -> PolarsResult { + let ca = s[0].utf8()?; + let prefix = s[1].utf8()?; + Ok(ca.strip_prefix(prefix).into_series()) } -pub(super) fn strip_suffix(s: &Series, suffix: &str) -> PolarsResult { - let ca = s.utf8()?; - Ok(ca - .apply_values(|s| Cow::Borrowed(s.strip_suffix(suffix).unwrap_or(s))) - .into_series()) +pub(super) fn strip_suffix(s: &[Series]) -> PolarsResult { + let ca = s[0].utf8()?; + let suffix = s[1].utf8()?; + Ok(ca.strip_suffix(suffix).into_series()) } pub(super) fn extract_all(args: &[Series]) -> PolarsResult { diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index e7d0a7322061..1022d87bed52 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -441,19 +441,21 @@ impl StringNameSpace { } /// Remove prefix. - pub fn strip_prefix(self, prefix: String) -> Expr { - self.0 - .map_private(FunctionExpr::StringExpr(StringFunction::StripPrefix( - prefix, - ))) + pub fn strip_prefix(self, prefix: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::StripPrefix), + &[prefix], + false, + ) } /// Remove suffix. - pub fn strip_suffix(self, suffix: String) -> Expr { - self.0 - .map_private(FunctionExpr::StringExpr(StringFunction::StripSuffix( - suffix, - ))) + pub fn strip_suffix(self, suffix: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::StripSuffix), + &[suffix], + false, + ) } /// Convert all characters to lowercase. diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index 1bd7d945eb09..8dd24580e10e 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -680,7 +680,7 @@ def strip_chars_end(self, characters: str | None = None) -> Expr: """ return wrap_expr(self._pyexpr.str_strip_chars_end(characters)) - def strip_prefix(self, prefix: str) -> Expr: + def strip_prefix(self, prefix: IntoExpr) -> Expr: """ Remove prefix. @@ -708,9 +708,10 @@ def strip_prefix(self, prefix: str) -> Expr: └───────────┴──────────┘ """ + prefix = parse_as_expression(prefix, str_as_lit=True) return wrap_expr(self._pyexpr.str_strip_prefix(prefix)) - def strip_suffix(self, suffix: str) -> Expr: + def strip_suffix(self, suffix: IntoExpr) -> Expr: """ Remove suffix. @@ -738,6 +739,7 @@ def strip_suffix(self, suffix: str) -> Expr: └───────────┴──────────┘ """ + suffix = parse_as_expression(suffix, str_as_lit=True) return wrap_expr(self._pyexpr.str_strip_suffix(suffix)) def zfill(self, alignment: int) -> Expr: diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index 0a00f53c4e23..f3801fff2190 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -1208,7 +1208,7 @@ def strip_chars_end(self, characters: str | None = None) -> Series: """ - def strip_prefix(self, prefix: str) -> Series: + def strip_prefix(self, prefix: IntoExpr) -> Series: """ Remove prefix. @@ -1234,7 +1234,7 @@ def strip_prefix(self, prefix: str) -> Series: """ - def strip_suffix(self, suffix: str) -> Series: + def strip_suffix(self, suffix: IntoExpr) -> Series: """ Remove suffix. diff --git a/py-polars/src/expr/string.rs b/py-polars/src/expr/string.rs index e466e5256254..0fad6bfa58d2 100644 --- a/py-polars/src/expr/string.rs +++ b/py-polars/src/expr/string.rs @@ -75,12 +75,12 @@ impl PyExpr { self.inner.clone().str().strip_chars_end(matches).into() } - fn str_strip_prefix(&self, prefix: String) -> Self { - self.inner.clone().str().strip_prefix(prefix).into() + fn str_strip_prefix(&self, prefix: Self) -> Self { + self.inner.clone().str().strip_prefix(prefix.inner).into() } - fn str_strip_suffix(&self, suffix: String) -> Self { - self.inner.clone().str().strip_suffix(suffix).into() + fn str_strip_suffix(&self, suffix: Self) -> Self { + self.inner.clone().str().strip_suffix(suffix.inner).into() } fn str_slice(&self, start: i64, length: Option) -> Self { diff --git a/py-polars/tests/unit/namespaces/test_string.py b/py-polars/tests/unit/namespaces/test_string.py index c8777d562dc5..8309ff68f053 100644 --- a/py-polars/tests/unit/namespaces/test_string.py +++ b/py-polars/tests/unit/namespaces/test_string.py @@ -260,16 +260,40 @@ def test_str_strip_deprecated() -> None: pl.Series(["a", "b", "c"]).str.rstrip() -def test_str_strip_prefix() -> None: - s = pl.Series(["foo:bar", "foofoo:bar", "bar:bar", "foo", ""]) - expected = pl.Series([":bar", "foo:bar", "bar:bar", "", ""]) +def test_str_strip_prefix_literal() -> None: + s = pl.Series(["foo:bar", "foofoo:bar", "bar:bar", "foo", "", None]) + expected = pl.Series([":bar", "foo:bar", "bar:bar", "", "", None]) assert_series_equal(s.str.strip_prefix("foo"), expected) + # test null literal + expected = pl.Series([None, None, None, None, None, None], dtype=pl.Utf8) + assert_series_equal(s.str.strip_prefix(pl.lit(None, dtype=pl.Utf8)), expected) + + +def test_str_strip_prefix_suffix_expr() -> None: + df = pl.DataFrame( + { + "s": ["foo-bar", "foobarbar", "barfoo", "", "anything", None], + "prefix": ["foo", "foobar", "foo", "", None, "bar"], + "suffix": ["bar", "barbar", "bar", "", None, "foo"], + } + ) + out = df.select( + pl.col("s").str.strip_prefix(pl.col("prefix")).alias("strip_prefix"), + pl.col("s").str.strip_suffix(pl.col("suffix")).alias("strip_suffix"), + ) + assert out.to_dict(False) == { + "strip_prefix": ["-bar", "bar", "barfoo", "", None, None], + "strip_suffix": ["foo-", "foo", "barfoo", "", None, None], + } def test_str_strip_suffix() -> None: - s = pl.Series(["foo:bar", "foo:barbar", "foo:foo", "bar", ""]) - expected = pl.Series(["foo:", "foo:bar", "foo:foo", "", ""]) + s = pl.Series(["foo:bar", "foo:barbar", "foo:foo", "bar", "", None]) + expected = pl.Series(["foo:", "foo:bar", "foo:foo", "", "", None]) assert_series_equal(s.str.strip_suffix("bar"), expected) + # test null literal + expected = pl.Series([None, None, None, None, None, None], dtype=pl.Utf8) + assert_series_equal(s.str.strip_suffix(pl.lit(None, dtype=pl.Utf8)), expected) def test_str_split() -> None: