From 17c0a01bf77a4e9db9c84253645bbdf39587fa36 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 7 Aug 2023 14:32:06 +0200 Subject: [PATCH 1/6] feat(rust): re-use regex capture allocation (#10302) --- .../src/chunked_array/strings/extract.rs | 78 +++++++++++++------ .../src/chunked_array/strings/namespace.rs | 10 +-- 2 files changed, 57 insertions(+), 31 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/strings/extract.rs b/crates/polars-ops/src/chunked_array/strings/extract.rs index 8bff5f5fae16..d12aa8dfaa21 100644 --- a/crates/polars-ops/src/chunked_array/strings/extract.rs +++ b/crates/polars-ops/src/chunked_array/strings/extract.rs @@ -1,5 +1,4 @@ use arrow::array::{Array, MutableArray, MutableUtf8Array, StructArray, Utf8Array}; -use polars_arrow::utils::combine_validities_and; use polars_core::export::regex::Regex; use super::*; @@ -14,32 +13,30 @@ fn extract_groups_array( .map(|_| MutableUtf8Array::::with_capacity(arr.len())) .collect::>(); - arr.into_iter().for_each(|opt_v| { - // we combine the null validity later - if let Some(value) = opt_v { - let caps = reg.captures(value); - match caps { - None => builders.iter_mut().for_each(|arr| arr.push_null()), - Some(caps) => { - caps.iter() - .skip(1) // skip 0th group - .zip(builders.iter_mut()) - .for_each(|(m, builder)| builder.push(m.map(|m| m.as_str()))) + let mut locs = reg.capture_locations(); + for opt_v in arr { + if let Some(s) = opt_v { + if let Some(_) = reg.captures_read(&mut locs, s) { + for (i, builder) in builders.iter_mut().enumerate() { + builder.push(locs.get(i + 1).map(|(start, stop)| &s[start..stop])); } + continue; } } - }); + + // Push nulls if either the string is null or there was no match. We + // distinguish later between the two by copying arr's validity mask. + builders.iter_mut().for_each(|arr| arr.push_null()); + } let values = builders .into_iter() - .map(|group_array| { - let group_array: Utf8Array = group_array.into(); - let final_validity = combine_validities_and(group_array.validity(), arr.validity()); - group_array.with_validity(final_validity).to_boxed() + .map(|a| { + let immutable_a: Utf8Array = a.into(); + immutable_a.to_boxed() }) .collect(); - - Ok(StructArray::new(data_type.clone(), values, None).boxed()) + Ok(StructArray::new(data_type.clone(), values, arr.validity().cloned()).boxed()) } pub(super) fn extract_groups( @@ -49,16 +46,14 @@ pub(super) fn extract_groups( ) -> PolarsResult { let reg = Regex::new(pat)?; let n_fields = reg.captures_len(); - if n_fields == 1 { return StructChunked::new(ca.name(), &[Series::new_null(ca.name(), ca.len())]) .map(|ca| ca.into_series()); } let data_type = dtype.to_arrow(); - // impl error if it isn't a struct let DataType::Struct(fields) = dtype else { - unreachable!() + unreachable!() // Implementation error if it isn't a struct. }; let names = fields .iter() @@ -72,3 +67,42 @@ pub(super) fn extract_groups( Series::try_from((ca.name(), chunks)) } + +fn extract_group_array( + arr: &Utf8Array, + reg: &Regex, + group_index: usize, +) -> PolarsResult> { + let mut builder = MutableUtf8Array::::with_capacity(arr.len()); + + let mut locs = reg.capture_locations(); + for opt_v in arr { + if let Some(s) = opt_v { + if let Some(_) = reg.captures_read(&mut locs, s) { + builder.push(locs.get(group_index).map(|(start, stop)| &s[start..stop])); + continue; + } + } + + // Push null if either the string is null or there was no match. + builder.push_null(); + } + + Ok(builder.into()) +} + +pub(super) fn extract_group( + ca: &Utf8Chunked, + pat: &str, + group_index: usize, +) -> PolarsResult { + let reg = Regex::new(pat)?; + + let chunks = ca + .downcast_iter() + .map(|array| Ok(extract_group_array(array, ®, group_index)?.to_boxed())) + .collect::>>>()?; + + // SAFETY: all chunks have type Utf8Array + unsafe { Ok(ChunkedArray::from_chunks(ca.name(), chunks)) } +} diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs index 18fa6f709338..29ffc8392f6c 100644 --- a/crates/polars-ops/src/chunked_array/strings/namespace.rs +++ b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -1,5 +1,3 @@ -use std::borrow::Cow; - #[cfg(feature = "string_encoding")] use base64::engine::general_purpose; #[cfg(feature = "string_encoding")] @@ -15,11 +13,6 @@ use super::*; #[cfg(feature = "binary_encoding")] use crate::chunked_array::binary::BinaryNameSpaceImpl; -fn f_regex_extract<'a>(reg: &Regex, input: &'a str, group_index: usize) -> Option> { - reg.captures(input) - .and_then(|cap| cap.get(group_index).map(|m| Cow::Borrowed(m.as_str()))) -} - pub trait Utf8NameSpaceImpl: AsUtf8 { #[cfg(not(feature = "binary_encoding"))] fn hex_decode(&self) -> PolarsResult { @@ -298,8 +291,7 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { /// Extract the nth capture group from pattern fn extract(&self, pat: &str, group_index: usize) -> PolarsResult { let ca = self.as_utf8(); - let reg = Regex::new(pat)?; - Ok(ca.apply_on_opt(|e| e.and_then(|input| f_regex_extract(®, input, group_index)))) + super::extract::extract_group(ca, pat, group_index) } /// Extract each successive non-overlapping regex match in an individual string as an array From 49e480cbba2262d6147bc7bb7d2d932dbde2f37f Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 7 Aug 2023 15:04:48 +0200 Subject: [PATCH 2/6] fix(rust): set cfg("extract_groups") appropriately --- crates/polars-ops/src/chunked_array/strings/extract.rs | 2 ++ crates/polars-ops/src/chunked_array/strings/mod.rs | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/polars-ops/src/chunked_array/strings/extract.rs b/crates/polars-ops/src/chunked_array/strings/extract.rs index d12aa8dfaa21..1c5e1cb33651 100644 --- a/crates/polars-ops/src/chunked_array/strings/extract.rs +++ b/crates/polars-ops/src/chunked_array/strings/extract.rs @@ -3,6 +3,7 @@ use polars_core::export::regex::Regex; use super::*; +#[cfg(feature = "extract_groups")] fn extract_groups_array( arr: &Utf8Array, reg: &Regex, @@ -39,6 +40,7 @@ fn extract_groups_array( Ok(StructArray::new(data_type.clone(), values, arr.validity().cloned()).boxed()) } +#[cfg(feature = "extract_groups")] pub(super) fn extract_groups( ca: &Utf8Chunked, pat: &str, diff --git a/crates/polars-ops/src/chunked_array/strings/mod.rs b/crates/polars-ops/src/chunked_array/strings/mod.rs index 8b8c636da5f2..2116578c2260 100644 --- a/crates/polars-ops/src/chunked_array/strings/mod.rs +++ b/crates/polars-ops/src/chunked_array/strings/mod.rs @@ -1,6 +1,5 @@ #[cfg(feature = "strings")] mod case; -#[cfg(feature = "extract_groups")] mod extract; #[cfg(feature = "extract_jsonpath")] mod json_path; From 8ab2e98612a1275467f4c8d2fc3af6add3b45a49 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 7 Aug 2023 15:05:24 +0200 Subject: [PATCH 3/6] chore(rust): silence clippy --- crates/polars-ops/src/chunked_array/strings/extract.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/strings/extract.rs b/crates/polars-ops/src/chunked_array/strings/extract.rs index 1c5e1cb33651..6b08d0fb18c3 100644 --- a/crates/polars-ops/src/chunked_array/strings/extract.rs +++ b/crates/polars-ops/src/chunked_array/strings/extract.rs @@ -17,7 +17,7 @@ fn extract_groups_array( let mut locs = reg.capture_locations(); for opt_v in arr { if let Some(s) = opt_v { - if let Some(_) = reg.captures_read(&mut locs, s) { + if reg.captures_read(&mut locs, s).is_some() { for (i, builder) in builders.iter_mut().enumerate() { builder.push(locs.get(i + 1).map(|(start, stop)| &s[start..stop])); } @@ -80,7 +80,7 @@ fn extract_group_array( let mut locs = reg.capture_locations(); for opt_v in arr { if let Some(s) = opt_v { - if let Some(_) = reg.captures_read(&mut locs, s) { + if reg.captures_read(&mut locs, s).is_some() { builder.push(locs.get(group_index).map(|(start, stop)| &s[start..stop])); continue; } From 745aea66c53a4eb88f8fb1a4c41fa6b80d047b1b Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 7 Aug 2023 15:17:25 +0200 Subject: [PATCH 4/6] fix(rust): set cfg appropriately, v2 --- crates/polars-ops/src/chunked_array/strings/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/polars-ops/src/chunked_array/strings/mod.rs b/crates/polars-ops/src/chunked_array/strings/mod.rs index 2116578c2260..3caaec8a9dba 100644 --- a/crates/polars-ops/src/chunked_array/strings/mod.rs +++ b/crates/polars-ops/src/chunked_array/strings/mod.rs @@ -1,5 +1,6 @@ #[cfg(feature = "strings")] mod case; +#[cfg(feature = "strings")] mod extract; #[cfg(feature = "extract_jsonpath")] mod json_path; From 61c50e4989c43f2e6380e608e840ab122af26055 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 7 Aug 2023 15:44:10 +0200 Subject: [PATCH 5/6] chore(rust): silence clippy, v2 --- crates/polars-ops/src/chunked_array/strings/extract.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/crates/polars-ops/src/chunked_array/strings/extract.rs b/crates/polars-ops/src/chunked_array/strings/extract.rs index 6b08d0fb18c3..91ea20ac331b 100644 --- a/crates/polars-ops/src/chunked_array/strings/extract.rs +++ b/crates/polars-ops/src/chunked_array/strings/extract.rs @@ -1,4 +1,7 @@ -use arrow::array::{Array, MutableArray, MutableUtf8Array, StructArray, Utf8Array}; +#[cfg(feature = "extract_groups")] +use arrow::array::StructArray; + +use arrow::array::{Array, MutableArray, MutableUtf8Array, Utf8Array}; use polars_core::export::regex::Regex; use super::*; From 77eaeb879da58f7e786de6d0736a6093fc9191ad Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 7 Aug 2023 16:00:55 +0200 Subject: [PATCH 6/6] chore(rust): cargo fmt --- crates/polars-ops/src/chunked_array/strings/extract.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/polars-ops/src/chunked_array/strings/extract.rs b/crates/polars-ops/src/chunked_array/strings/extract.rs index 91ea20ac331b..741ea18ec6f7 100644 --- a/crates/polars-ops/src/chunked_array/strings/extract.rs +++ b/crates/polars-ops/src/chunked_array/strings/extract.rs @@ -1,6 +1,5 @@ #[cfg(feature = "extract_groups")] use arrow::array::StructArray; - use arrow::array::{Array, MutableArray, MutableUtf8Array, Utf8Array}; use polars_core::export::regex::Regex;