Skip to content

Commit

Permalink
feat: str.strip_chars supports take an expr argument
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Sep 25, 2023
1 parent 34043ce commit ca09ce8
Show file tree
Hide file tree
Showing 13 changed files with 309 additions and 141 deletions.
6 changes: 6 additions & 0 deletions crates/polars-ops/src/chunked_array/strings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,22 @@ mod justify;
mod namespace;
#[cfg(feature = "strings")]
mod replace;
#[cfg(feature = "strings")]
mod split;
#[cfg(feature = "strings")]
mod strip;
#[cfg(feature = "strings")]
mod substring;

#[cfg(feature = "extract_jsonpath")]
pub use json_path::*;
#[cfg(feature = "strings")]
pub use namespace::*;
use polars_core::prelude::*;
#[cfg(feature = "strings")]
pub use split::*;
#[cfg(feature = "strings")]
pub use strip::*;

pub trait AsUtf8 {
fn as_utf8(&self) -> &Utf8Chunked;
Expand Down
55 changes: 29 additions & 26 deletions crates/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,6 @@ where
f
}

fn opt_strip_prefix<'a>(s: Option<&'a str>, prefix: Option<&str>) -> Option<&'a str> {
Some(s?.strip_prefix(prefix?).unwrap_or(s?))
}

fn opt_strip_suffix<'a>(s: Option<&'a str>, suffix: Option<&str>) -> Option<&'a str> {
Some(s?.strip_suffix(suffix?).unwrap_or(s?))
}

pub trait Utf8NameSpaceImpl: AsUtf8 {
#[cfg(not(feature = "binary_encoding"))]
fn hex_decode(&self) -> PolarsResult<Utf8Chunked> {
Expand Down Expand Up @@ -350,32 +342,43 @@ pub trait Utf8NameSpaceImpl: AsUtf8 {
Ok(builder.finish())
}

fn strip_prefix(&self, prefix: &Utf8Chunked) -> Utf8Chunked {
fn strip_chars(&self, pat: &Series) -> PolarsResult<Utf8Chunked> {
let ca = self.as_utf8();
match prefix.len() {
1 => match prefix.get(0) {
Some(prefix) => {
ca.apply_generic(|opt_s| opt_s.map(|s| s.strip_prefix(prefix).unwrap_or(s)))
},
_ => Utf8Chunked::full_null(ca.name(), ca.len()),
},
_ => binary_elementwise(ca, prefix, opt_strip_prefix),
if pat.dtype() == &DataType::Null {
Ok(ca.apply_generic(|opt_s| opt_s.map(|s| s.trim())))
} else {
Ok(strip_chars(ca, pat.utf8()?))
}
}

fn strip_suffix(&self, suffix: &Utf8Chunked) -> Utf8Chunked {
fn strip_chars_start(&self, pat: &Series) -> PolarsResult<Utf8Chunked> {
let ca = self.as_utf8();
match suffix.len() {
1 => match suffix.get(0) {
Some(suffix) => {
ca.apply_generic(|opt_s| opt_s.map(|s| s.strip_suffix(suffix).unwrap_or(s)))
},
_ => Utf8Chunked::full_null(ca.name(), ca.len()),
},
_ => binary_elementwise(ca, suffix, opt_strip_suffix),
if pat.dtype() == &DataType::Null {
return Ok(ca.apply_generic(|opt_s| opt_s.map(|s| s.trim_start())));
} else {
Ok(strip_chars_start(ca, pat.utf8()?))
}
}

fn strip_chars_end(&self, pat: &Series) -> PolarsResult<Utf8Chunked> {
let ca = self.as_utf8();
if pat.dtype() == &DataType::Null {
return Ok(ca.apply_generic(|opt_s| opt_s.map(|s| s.trim_end())));
} else {
Ok(strip_chars_end(ca, pat.utf8()?))
}
}

fn strip_prefix(&self, prefix: &Utf8Chunked) -> Utf8Chunked {
let ca = self.as_utf8();
strip_prefix(ca, prefix)
}

fn strip_suffix(&self, suffix: &Utf8Chunked) -> Utf8Chunked {
let ca = self.as_utf8();
strip_suffix(ca, suffix)
}

#[cfg(feature = "dtype-struct")]
fn split_exact(&self, by: &Utf8Chunked, n: usize) -> PolarsResult<StructChunked> {
let ca = self.as_utf8();
Expand Down
139 changes: 139 additions & 0 deletions crates/polars-ops/src/chunked_array/strings/strip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use polars_core::prelude::arity::binary_elementwise;

use super::*;

fn strip_chars_binary<'a>(opt_s: Option<&'a str>, opt_pat: Option<&str>) -> Option<&'a str> {
match (opt_s, opt_pat) {
(Some(s), Some(pat)) => {
if pat.chars().count() == 1 {
Some(s.trim_matches(pat.chars().next().unwrap()))
} else {
Some(s.trim_matches(|c| pat.contains(c)))
}
},
(Some(s), _) => Some(s.trim()),
_ => None,
}
}

fn strip_chars_start_binary<'a>(opt_s: Option<&'a str>, opt_pat: Option<&str>) -> Option<&'a str> {
match (opt_s, opt_pat) {
(Some(s), Some(pat)) => {
if pat.chars().count() == 1 {
Some(s.trim_start_matches(pat.chars().next().unwrap()))
} else {
Some(s.trim_start_matches(|c| pat.contains(c)))
}
},
(Some(s), _) => Some(s.trim_start()),
_ => None,
}
}

fn strip_chars_end_binary<'a>(opt_s: Option<&'a str>, opt_pat: Option<&str>) -> Option<&'a str> {
match (opt_s, opt_pat) {
(Some(s), Some(pat)) => {
if pat.chars().count() == 1 {
Some(s.trim_end_matches(pat.chars().next().unwrap()))
} else {
Some(s.trim_end_matches(|c| pat.contains(c)))
}
},
(Some(s), _) => Some(s.trim_end()),
_ => None,
}
}

fn strip_prefix_binary<'a>(s: Option<&'a str>, prefix: Option<&str>) -> Option<&'a str> {
Some(s?.strip_prefix(prefix?).unwrap_or(s?))
}

fn strip_suffix_binary<'a>(s: Option<&'a str>, suffix: Option<&str>) -> Option<&'a str> {
Some(s?.strip_suffix(suffix?).unwrap_or(s?))
}

pub fn strip_chars(ca: &Utf8Chunked, pat: &Utf8Chunked) -> Utf8Chunked {
match pat.len() {
1 => {
if let Some(pat) = pat.get(0) {
if pat.chars().count() == 1 {
// Fast path for when a single character is passed
ca.apply_generic(|opt_s| {
opt_s.map(|s| s.trim_matches(pat.chars().next().unwrap()))
})
} else {
ca.apply_generic(|opt_s| opt_s.map(|s| s.trim_matches(|c| pat.contains(c))))
}
} else {
ca.apply_generic(|opt_s| opt_s.map(|s| s.trim()))
}
},
_ => binary_elementwise(ca, pat, strip_chars_binary),
}
}

pub fn strip_chars_start(ca: &Utf8Chunked, pat: &Utf8Chunked) -> Utf8Chunked {
match pat.len() {
1 => {
if let Some(pat) = pat.get(0) {
if pat.chars().count() == 1 {
// Fast path for when a single character is passed
ca.apply_generic(|opt_s| {
opt_s.map(|s| s.trim_start_matches(pat.chars().next().unwrap()))
})
} else {
ca.apply_generic(|opt_s| {
opt_s.map(|s| s.trim_start_matches(|c| pat.contains(c)))
})
}
} else {
ca.apply_generic(|opt_s| opt_s.map(|s| s.trim_start()))
}
},
_ => binary_elementwise(ca, pat, strip_chars_start_binary),
}
}

pub fn strip_chars_end(ca: &Utf8Chunked, pat: &Utf8Chunked) -> Utf8Chunked {
match pat.len() {
1 => {
if let Some(pat) = pat.get(0) {
if pat.chars().count() == 1 {
// Fast path for when a single character is passed
ca.apply_generic(|opt_s| {
opt_s.map(|s| s.trim_end_matches(pat.chars().next().unwrap()))
})
} else {
ca.apply_generic(|opt_s| opt_s.map(|s| s.trim_end_matches(|c| pat.contains(c))))
}
} else {
ca.apply_generic(|opt_s| opt_s.map(|s| s.trim_end()))
}
},
_ => binary_elementwise(ca, pat, strip_chars_end_binary),
}
}

pub fn strip_prefix(ca: &Utf8Chunked, prefix: &Utf8Chunked) -> Utf8Chunked {
match prefix.len() {
1 => match prefix.get(0) {
Some(prefix) => {
ca.apply_generic(|opt_s| opt_s.map(|s| s.strip_prefix(prefix).unwrap_or(s)))
},
_ => Utf8Chunked::full_null(ca.name(), ca.len()),
},
_ => binary_elementwise(ca, prefix, strip_prefix_binary),
}
}

pub fn strip_suffix(ca: &Utf8Chunked, suffix: &Utf8Chunked) -> Utf8Chunked {
match suffix.len() {
1 => match suffix.get(0) {
Some(suffix) => {
ca.apply_generic(|opt_s| opt_s.map(|s| s.strip_suffix(suffix).unwrap_or(s)))
},
_ => Utf8Chunked::full_null(ca.name(), ca.len()),
},
_ => binary_elementwise(ca, suffix, strip_suffix_binary),
}
}
6 changes: 3 additions & 3 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,9 +746,9 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
Lowercase => map!(strings::lowercase),
#[cfg(feature = "nightly")]
Titlecase => map!(strings::titlecase),
StripChars(matches) => map!(strings::strip_chars, matches.as_deref()),
StripCharsStart(matches) => map!(strings::strip_chars_start, matches.as_deref()),
StripCharsEnd(matches) => map!(strings::strip_chars_end, matches.as_deref()),
StripChars => map_as_slice!(strings::strip_chars),
StripCharsStart => map_as_slice!(strings::strip_chars_start),
StripCharsEnd => map_as_slice!(strings::strip_chars_end),
StripPrefix => map_as_slice!(strings::strip_prefix),
StripSuffix => map_as_slice!(strings::strip_suffix),
#[cfg(feature = "string_from_radix")]
Expand Down
87 changes: 21 additions & 66 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ pub enum StringFunction {
},
Slice(i64, Option<u64>),
StartsWith,
StripChars(Option<String>),
StripCharsStart(Option<String>),
StripCharsEnd(Option<String>),
StripChars,
StripCharsStart,
StripCharsEnd,
StripPrefix,
StripSuffix,
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -127,9 +127,9 @@ impl StringFunction {
ToDecimal(_) => mapper.with_dtype(DataType::Decimal(None, None)),
Uppercase
| Lowercase
| StripChars(_)
| StripCharsStart(_)
| StripCharsEnd(_)
| StripChars
| StripCharsStart
| StripCharsEnd
| StripPrefix
| StripSuffix
| Slice(_, _) => mapper.with_same_dtype(),
Expand Down Expand Up @@ -182,9 +182,9 @@ impl Display for StringFunction {
StringFunction::Replace { .. } => "replace",
StringFunction::Slice(_, _) => "str_slice",
StringFunction::StartsWith { .. } => "starts_with",
StringFunction::StripChars(_) => "strip_chars",
StringFunction::StripCharsStart(_) => "strip_chars_start",
StringFunction::StripCharsEnd(_) => "strip_chars_end",
StringFunction::StripChars => "strip_chars",
StringFunction::StripCharsStart => "strip_chars_start",
StringFunction::StripCharsEnd => "strip_chars_end",
StringFunction::StripPrefix => "strip_prefix",
StringFunction::StripSuffix => "strip_suffix",
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -298,67 +298,22 @@ pub(super) fn rjust(s: &Series, width: usize, fillchar: char) -> PolarsResult<Se
Ok(ca.rjust(width, fillchar).into_series())
}

pub(super) fn strip_chars(s: &Series, matches: Option<&str>) -> PolarsResult<Series> {
let ca = s.utf8()?;
if let Some(matches) = matches {
if matches.chars().count() == 1 {
// Fast path for when a single character is passed
Ok(ca
.apply_values(|s| Cow::Borrowed(s.trim_matches(matches.chars().next().unwrap())))
.into_series())
} else {
Ok(ca
.apply_values(|s| Cow::Borrowed(s.trim_matches(|c| matches.contains(c))))
.into_series())
}
} else {
Ok(ca.apply_values(|s| Cow::Borrowed(s.trim())).into_series())
}
pub(super) fn strip_chars(s: &[Series]) -> PolarsResult<Series> {
let ca = s[0].utf8()?;
let pat_s = &s[1];
ca.strip_chars(pat_s).map(|ok| ok.into_series())
}

pub(super) fn strip_chars_start(s: &Series, matches: Option<&str>) -> PolarsResult<Series> {
let ca = s.utf8()?;

if let Some(matches) = matches {
if matches.chars().count() == 1 {
// Fast path for when a single character is passed
Ok(ca
.apply_values(|s| {
Cow::Borrowed(s.trim_start_matches(matches.chars().next().unwrap()))
})
.into_series())
} else {
Ok(ca
.apply_values(|s| Cow::Borrowed(s.trim_start_matches(|c| matches.contains(c))))
.into_series())
}
} else {
Ok(ca
.apply_values(|s| Cow::Borrowed(s.trim_start()))
.into_series())
}
pub(super) fn strip_chars_start(s: &[Series]) -> PolarsResult<Series> {
let ca = s[0].utf8()?;
let pat_s = &s[1];
ca.strip_chars_start(pat_s).map(|ok| ok.into_series())
}

pub(super) fn strip_chars_end(s: &Series, matches: Option<&str>) -> PolarsResult<Series> {
let ca = s.utf8()?;
if let Some(matches) = matches {
if matches.chars().count() == 1 {
// Fast path for when a single character is passed
Ok(ca
.apply_values(|s| {
Cow::Borrowed(s.trim_end_matches(matches.chars().next().unwrap()))
})
.into_series())
} else {
Ok(ca
.apply_values(|s| Cow::Borrowed(s.trim_end_matches(|c| matches.contains(c))))
.into_series())
}
} else {
Ok(ca
.apply_values(|s| Cow::Borrowed(s.trim_end()))
.into_series())
}
pub(super) fn strip_chars_end(s: &[Series]) -> PolarsResult<Series> {
let ca = s[0].utf8()?;
let pat_s = &s[1];
ca.strip_chars_end(pat_s).map(|ok| ok.into_series())
}

pub(super) fn strip_prefix(s: &[Series]) -> PolarsResult<Series> {
Expand Down
Loading

0 comments on commit ca09ce8

Please sign in to comment.