Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rust, python): expressify offset and length parameters for str.slice #12071

Closed
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions crates/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,12 +524,9 @@ pub trait Utf8NameSpaceImpl: AsUtf8 {
///
/// Determines a substring starting from `start` and with optional length `length` of each of the elements in `array`.
/// `start` can be negative, in which case the start counts from the end of the string.
fn str_slice(&self, start: i64, length: Option<u64>) -> Utf8Chunked {
fn str_slice(&self, start: &Int64Chunked, length: &UInt64Chunked) -> Utf8Chunked {
let ca = self.as_utf8();
let iter = ca
.downcast_iter()
.map(|c| substring::utf8_substring(c, start, &length));
Utf8Chunked::from_chunk_iter_like(ca, iter)
super::substring::utf8_substring(ca, start, length)
}
}

Expand Down
95 changes: 53 additions & 42 deletions crates/polars-ops/src/chunked_array/strings/substring.rs
Original file line number Diff line number Diff line change
@@ -1,51 +1,62 @@
use arrow::array::Utf8Array;
use polars_core::prelude::arity::ternary_elementwise;

use crate::chunked_array::{Int64Chunked, UInt64Chunked, Utf8Chunked};

/// Returns a Utf8Array<O> with a substring starting from `start` and with optional length `length` of each of the elements in `array`.
/// `start` can be negative, in which case the start counts from the end of the string.
pub(super) fn utf8_substring(
array: &Utf8Array<i64>,
start: i64,
length: &Option<u64>,
) -> Utf8Array<i64> {
let length = length.map(|v| v as usize);
/// `offset` can be negative, in which case the offset counts from the end of the string.
fn utf8_substring_ternary(
opt_str_val: Option<&str>,
opt_offset: Option<i64>,
opt_length: Option<u64>,
) -> Option<&str> {
match (opt_str_val, opt_offset) {
(Some(str_val), Some(offset)) => {
// compute where we should offset slicing this entry.
let offset = if offset >= 0 {
offset as usize
} else {
let offset = (0i64 - offset) as usize;
str_val
.char_indices()
.rev()
.nth(offset)
.map(|(idx, _)| idx + 1)
.unwrap_or(0)
};

let iter = array.values_iter().map(|str_val| {
// compute where we should start slicing this entry.
let start = if start >= 0 {
start as usize
} else {
let start = (0i64 - start) as usize;
str_val
.char_indices()
.rev()
.nth(start)
.map(|(idx, _)| idx + 1)
.unwrap_or(0)
};
let mut iter_chars = str_val.char_indices();
if let Some((offset_idx, _)) = iter_chars.nth(offset) {
// length of the str
let len_end = str_val.len() - offset_idx;

let mut iter_chars = str_val.char_indices();
if let Some((start_idx, _)) = iter_chars.nth(start) {
// length of the str
let len_end = str_val.len() - start_idx;
// slice to end of str if no length given
let length = match opt_length {
Some(length) => length as usize,
_ => len_end,
};

// length to slice
let length = length.unwrap_or(len_end);
if length == 0 {
return Some("");
}
// compute
let end_idx = iter_chars
.nth(length.saturating_sub(1))
.map(|(idx, _)| idx)
.unwrap_or(str_val.len());

if length == 0 {
return "";
Some(&str_val[offset_idx..end_idx])
} else {
Some("")
}
// compute
let end_idx = iter_chars
.nth(length.saturating_sub(1))
.map(|(idx, _)| idx)
.unwrap_or(str_val.len());

&str_val[start_idx..end_idx]
} else {
""
}
});
},
_ => None,
}
}

let new = Utf8Array::<i64>::from_trusted_len_values_iter(iter);
new.with_validity(array.validity().cloned())
pub(super) fn utf8_substring(
ca: &Utf8Chunked,
offset: &Int64Chunked,
length: &UInt64Chunked,
) -> Utf8Chunked {
ternary_elementwise(ca, offset, length, utf8_substring_ternary)
}
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,7 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
StripSuffix => map_as_slice!(strings::strip_suffix),
#[cfg(feature = "string_from_radix")]
FromRadix(radix, strict) => map!(strings::from_radix, radix, strict),
Slice(start, length) => map!(strings::str_slice, start, length),
Slice => map_as_slice!(strings::str_slice),
Explode => map!(strings::explode),
#[cfg(feature = "dtype-decimal")]
ToDecimal(infer_len) => map!(strings::to_decimal, infer_len),
Expand Down
61 changes: 48 additions & 13 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub enum StringFunction {
length: usize,
fill_char: char,
},
Slice(i64, Option<u64>),
Slice,
StartsWith,
StripChars,
StripCharsStart,
Expand Down Expand Up @@ -125,14 +125,8 @@ impl StringFunction {
Titlecase => mapper.with_same_dtype(),
#[cfg(feature = "dtype-decimal")]
ToDecimal(_) => mapper.with_dtype(DataType::Decimal(None, None)),
Uppercase
| Lowercase
| StripChars
| StripCharsStart
| StripCharsEnd
| StripPrefix
| StripSuffix
| Slice(_, _) => mapper.with_same_dtype(),
Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix
| StripSuffix | Slice => mapper.with_same_dtype(),
#[cfg(feature = "string_pad")]
PadStart { .. } | PadEnd { .. } | ZFill { .. } => mapper.with_same_dtype(),
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -180,7 +174,7 @@ impl Display for StringFunction {
StringFunction::PadStart { .. } => "pad_start",
#[cfg(feature = "regex")]
StringFunction::Replace { .. } => "replace",
StringFunction::Slice(_, _) => "slice",
StringFunction::Slice => "slice",
StringFunction::StartsWith { .. } => "starts_with",
StringFunction::StripChars => "strip_chars",
StringFunction::StripCharsStart => "strip_chars_start",
Expand Down Expand Up @@ -732,9 +726,50 @@ pub(super) fn from_radix(s: &Series, radix: u32, strict: bool) -> PolarsResult<S
let ca = s.utf8()?;
ca.parse_int(radix, strict).map(|ok| ok.into_series())
}
pub(super) fn str_slice(s: &Series, start: i64, length: Option<u64>) -> PolarsResult<Series> {
let ca = s.utf8()?;
Ok(ca.str_slice(start, length).into_series())

pub(super) fn str_slice(s: &[Series]) -> PolarsResult<Series> {
let ca = s[0].utf8()?;

let s1 = &s[1];
let s2 = &s[2];

polars_ensure!(
s1.len() <= ca.len(),
ComputeError:
"too many `offset` values ({}) for column length ({})",
s1.len(), ca.len(),
);

polars_ensure!(
s2.len() <= ca.len(),
ComputeError:
"too many `length` values ({}) for column length ({})",
s2.len(), ca.len(),
);

let offset = match s1.len() {
1 => {
let offset = s1.get(0).unwrap();
s1.clear().extend_constant(offset, ca.len()).unwrap()
},
_ => s1.clone(),
};

let offset = offset.cast(&DataType::Int64)?;
let offset = offset.i64()?;

let length = match s2.len() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to add broadcasting by extending memory. See discussion here; #11900

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! Excellent, thanks @ritchie46 - that explains the need for the added length checks / fallback logic.

I'll switch it around.

1 => {
let length = s2.get(0).unwrap();
s2.clear().extend_constant(length, ca.len()).unwrap()
},
_ => s2.clone(),
};

let length = length.cast(&DataType::UInt64)?;
let length = length.u64()?;

Ok(ca.str_slice(offset, length).into_series())
}

pub(super) fn explode(s: &Series) -> PolarsResult<Series> {
Expand Down
12 changes: 7 additions & 5 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,11 +402,13 @@ impl StringNameSpace {
}

/// Slice the string values.
pub fn slice(self, start: i64, length: Option<u64>) -> Expr {
self.0
.map_private(FunctionExpr::StringExpr(StringFunction::Slice(
start, length,
)))
pub fn slice(self, offset: Expr, length: Expr) -> Expr {
self.0.map_many_private(
FunctionExpr::StringExpr(StringFunction::Slice),
&[offset, length],
false,
false,
)
}

pub fn explode(self) -> Expr {
Expand Down
39 changes: 19 additions & 20 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,12 +702,14 @@ impl SqlFunctionVisitor<'_> {
#[cfg(feature = "nightly")]
InitCap => self.visit_unary(|e| e.str().to_titlecase()),
Left => self.try_visit_binary(|e, length| {
Ok(e.str().slice(0, match length {
Expr::Literal(LiteralValue::Int64(n)) => Some(n as u64),
Ok(match length {
Expr::Literal(LiteralValue::Int64(_)) => {
e.str().slice(lit(0), length)
},
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Left: {}", function.args[1]);
polars_bail!(InvalidOperation: "Invalid 'length' for Left: {}", function.args[1])
}
}))
})
}),
Length => self.visit_unary(|e| e.str().len_chars()),
Lower => self.visit_unary(|e| e.str().to_lowercase()),
Expand Down Expand Up @@ -756,26 +758,23 @@ impl SqlFunctionVisitor<'_> {
StartsWith => self.visit_binary(|e, s| e.str().starts_with(s)),
Substring => match function.args.len() {
2 => self.try_visit_binary(|e, start| {
Ok(e.str().slice(match start {
Expr::Literal(LiteralValue::Int64(n)) => n,
Ok(match start {
Expr::Literal(LiteralValue::Int64(_)) => {
e.str().slice(start, lit(Null))
},
_ => {
polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]);
polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1])
}
}, None))
})
}),
3 => self.try_visit_ternary(|e, start, length| {
Ok(e.str().slice(
match start {
Expr::Literal(LiteralValue::Int64(n)) => n,
_ => {
polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]);
}
}, match length {
Expr::Literal(LiteralValue::Int64(n)) => Some(n as u64),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Substring: {}", function.args[2]);
}
}))
if !matches!(start, Expr::Literal(LiteralValue::Int64(_))) {
polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]);
}
if !matches!(length, Expr::Literal(LiteralValue::Int64(_))) {
polars_bail!(InvalidOperation: "Invalid 'length' for Substring: {}", function.args[2]);
}
Ok(e.str().slice(start, length))
}),
_ => polars_bail!(InvalidOperation:
"Invalid number of arguments for Substring: {}",
Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1852,7 +1852,7 @@ def replace_all(
value = parse_as_expression(value, str_as_lit=True)
return wrap_expr(self._pyexpr.str_replace_all(pattern, value, literal))

def slice(self, offset: int, length: int | None = None) -> Expr:
def slice(self, offset: IntoExpr, length: IntoExpr | None = None) -> Expr:
"""
Create subslices of the string values of a Utf8 Series.

Expand Down Expand Up @@ -1905,6 +1905,8 @@ def slice(self, offset: int, length: int | None = None) -> Expr:
└─────────────┴──────────┘

"""
offset = parse_as_expression(offset)
length = parse_as_expression(length)
return wrap_expr(self._pyexpr.str_slice(offset, length))

def explode(self) -> Expr:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,7 @@ def to_titlecase(self) -> Series:

"""

def slice(self, offset: int, length: int | None = None) -> Series:
def slice(self, offset: IntoExpr, length: IntoExpr | None = None) -> Series:
"""
Create subslices of the string values of a Utf8 Series.

Expand Down
8 changes: 6 additions & 2 deletions py-polars/src/expr/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,12 @@ impl PyExpr {
self.inner.clone().str().strip_suffix(suffix.inner).into()
}

fn str_slice(&self, start: i64, length: Option<u64>) -> Self {
self.inner.clone().str().slice(start, length).into()
fn str_slice(&self, offset: Self, length: Self) -> Self {
self.inner
.clone()
.str()
.slice(offset.inner, length.inner)
.into()
}

fn str_explode(&self) -> Self {
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/namespaces/string/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,20 @@ def test_str_slice() -> None:
assert df.select([pl.col("a").str.slice(2, 4)])["a"].to_list() == ["obar", "rfoo"]


def test_str_slice_expressions() -> None:
df = pl.DataFrame({"a": ["foobar", "barfoo"], "offset": [1, 3], "length": [3, 4]})

out = df.select(pl.col("a").str.slice("offset", "length"))

expected = pl.DataFrame({"a": ["oob", "foo"]})
assert out.frame_equal(expected)

out = df.select(pl.col("a").str.slice(-3, "length"))

expected = pl.DataFrame({"a": ["bar", "foo"]})
assert out.frame_equal(expected)


def test_str_concat() -> None:
s = pl.Series(["1", None, "2"])
result = s.str.concat()
Expand Down