Skip to content

Commit

Permalink
feat(rust, python): Add str.extract_groups (#10179)
Browse files Browse the repository at this point in the history
  • Loading branch information
cmdlineluser authored Aug 4, 2023
1 parent 6874e2f commit 1e1fd1d
Show file tree
Hide file tree
Showing 18 changed files with 334 additions and 4 deletions.
1 change: 1 addition & 0 deletions crates/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ list_sets = ["polars-plan/list_sets", "polars-ops/list_sets"]
list_any_all = ["polars-ops/list_any_all", "polars-plan/list_any_all"]
cutqcut = ["polars-plan/cutqcut", "polars-ops/cutqcut"]
rle = ["polars-plan/rle", "polars-ops/rle"]
extract_groups = ["polars-plan/extract_groups"]

binary_encoding = ["polars-plan/binary_encoding"]

Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,4 @@ semi_anti_join = ["polars-core/semi_anti_join"]
list_take = []
list_sets = []
list_any_all = []
extract_groups = ["dtype-struct", "polars-core/regex"]
77 changes: 77 additions & 0 deletions crates/polars-ops/src/chunked_array/strings/extract.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use arrow::array::{Array, MutableArray, MutableUtf8Array, StructArray, Utf8Array};
use polars_arrow::utils::combine_validities_and;
use polars_core::export::regex::Regex;

use super::*;

fn extract_groups_array(
arr: &Utf8Array<i64>,
reg: &Regex,
names: &[String],
data_type: ArrowDataType,
) -> PolarsResult<ArrayRef> {
let mut builders = (0..names.len())
.map(|_| MutableUtf8Array::<i64>::with_capacity(arr.len()))
.collect::<Vec<_>>();

arr.into_iter().for_each(|opt_v| {
// we combine the null validity later
if let Some(value) = opt_v {
let caps = reg.captures(value);
match caps {
None => builders.iter_mut().for_each(|arr| arr.push_null()),
Some(caps) => {
caps.iter()
.skip(1) // skip 0th group
.zip(builders.iter_mut())
.for_each(|(m, builder)| builder.push(m.map(|m| m.as_str())))
}
}
}
});

let values = builders
.into_iter()
.map(|group_array| {
let group_array: Utf8Array<i64> = group_array.into();
let final_validity = combine_validities_and(group_array.validity(), arr.validity());
group_array.with_validity(final_validity).to_boxed()
})
.collect();

Ok(StructArray::new(data_type.clone(), values, None).boxed())
}

pub(super) fn extract_groups(ca: &Utf8Chunked, pat: &str) -> PolarsResult<Series> {
let reg = Regex::new(pat)?;
let n_fields = reg.captures_len();

if n_fields == 1 {
return StructChunked::new(ca.name(), &[Series::new_null(ca.name(), ca.len())])
.map(|ca| ca.into_series());
}

let names = reg
.capture_names()
.enumerate()
.skip(1)
.map(|(idx, opt_name)| {
opt_name
.map(|name| name.to_string())
.unwrap_or_else(|| format!("{idx}"))
})
.collect::<Vec<_>>();
let data_type = ArrowDataType::Struct(
names
.iter()
.map(|name| ArrowField::new(name.as_str(), ArrowDataType::LargeUtf8, true))
.collect(),
);

let chunks = ca
.downcast_iter()
.map(|array| extract_groups_array(array, &reg, &names, data_type.clone()))
.collect::<PolarsResult<Vec<_>>>()?;

Series::try_from((ca.name(), chunks))
}
2 changes: 2 additions & 0 deletions crates/polars-ops/src/chunked_array/strings/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#[cfg(feature = "strings")]
mod case;
#[cfg(feature = "extract_groups")]
mod extract;
#[cfg(feature = "extract_jsonpath")]
mod json_path;
#[cfg(feature = "string_justify")]
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,13 @@ pub trait Utf8NameSpaceImpl: AsUtf8 {
Ok(builder.finish())
}

#[cfg(feature = "extract_groups")]
/// Extract all capture groups from pattern and return as a struct
fn extract_groups(&self, pat: &str) -> PolarsResult<Series> {
let ca = self.as_utf8();
super::extract::extract_groups(ca, pat)
}

/// Count all successive non-overlapping regex matches.
fn count_match(&self, pat: &str) -> PolarsResult<UInt32Chunked> {
let ca = self.as_utf8();
Expand Down
1 change: 1 addition & 0 deletions crates/polars-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ list_sets = ["polars-ops/list_sets"]
list_any_all = ["polars-ops/list_any_all"]
cutqcut = ["polars-ops/cutqcut"]
rle = ["polars-ops/rle"]
extract_groups = ["regex", "dtype-struct", "polars-ops/extract_groups"]

bigidx = ["polars-arrow/bigidx", "polars-core/bigidx", "polars-utils/bigidx"]

Expand Down
4 changes: 4 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,10 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
ExtractAll => {
map_as_slice!(strings::extract_all)
}
#[cfg(feature = "extract_groups")]
ExtractGroups { pat } => {
map!(strings::extract_groups, &pat)
}
NChars => map!(strings::n_chars),
Length => map!(strings::lengths),
#[cfg(feature = "string_justify")]
Expand Down
17 changes: 17 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ pub enum StringFunction {
group_index: usize,
},
ExtractAll,
#[cfg(feature = "extract_groups")]
ExtractGroups {
pat: String,
},
#[cfg(feature = "string_from_radix")]
FromRadix(u32, bool),
NChars,
Expand Down Expand Up @@ -90,6 +94,8 @@ impl StringFunction {
Explode => mapper.with_same_dtype(),
Extract { .. } => mapper.with_same_dtype(),
ExtractAll => mapper.with_dtype(DataType::List(Box::new(DataType::Utf8))),
#[cfg(feature = "extract_groups")]
ExtractGroups { .. } => mapper.with_same_dtype(),
#[cfg(feature = "string_from_radix")]
FromRadix { .. } => mapper.with_dtype(DataType::Int32),
#[cfg(feature = "extract_jsonpath")]
Expand Down Expand Up @@ -127,6 +133,8 @@ impl Display for StringFunction {
StringFunction::ConcatVertical(_) => "concat_vertical",
StringFunction::Explode => "explode",
StringFunction::ExtractAll => "extract_all",
#[cfg(feature = "extract_groups")]
StringFunction::ExtractGroups { .. } => "extract_groups",
#[cfg(feature = "string_from_radix")]
StringFunction::FromRadix { .. } => "from_radix",
#[cfg(feature = "extract_jsonpath")]
Expand Down Expand Up @@ -292,6 +300,15 @@ pub(super) fn extract(s: &Series, pat: &str, group_index: usize) -> PolarsResult
ca.extract(&pat, group_index).map(|ca| ca.into_series())
}

#[cfg(feature = "extract_groups")]
/// Extract all capture groups from a regex pattern as a struct
pub(super) fn extract_groups(s: &Series, pat: &str) -> PolarsResult<Series> {
let pat = pat.to_string();

let ca = s.utf8()?;
ca.extract_groups(&pat)
}

#[cfg(feature = "string_justify")]
pub(super) fn zfill(s: &Series, alignment: usize) -> PolarsResult<Series> {
let ca = s.utf8()?;
Expand Down
8 changes: 8 additions & 0 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ impl StringNameSpace {
.map_private(StringFunction::Extract { pat, group_index }.into())
}

#[cfg(feature = "extract_groups")]
// Extract all captures groups from a regex pattern as a struct
pub fn extract_groups(self, pat: &str) -> Expr {
let pat = pat.to_string();
self.0
.map_private(StringFunction::ExtractGroups { pat }.into())
}

/// Return a copy of the string left filled with ASCII '0' digits to make a string of length width.
/// A leading sign prefix ('+'/'-') is handled by inserting the padding after the sign character
/// rather than before.
Expand Down
2 changes: 2 additions & 0 deletions crates/polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ list_sets = ["polars-lazy/list_sets"]
list_any_all = ["polars-lazy/list_any_all"]
cutqcut = ["polars-lazy/cutqcut"]
rle = ["polars-lazy/rle"]
extract_groups = ["polars-lazy/extract_groups"]

test = [
"lazy",
Expand Down Expand Up @@ -327,6 +328,7 @@ docs-selection = [
"propagate_nans",
"coalesce",
"dynamic_groupby",
"extract_groups",
]

bench = [
Expand Down
9 changes: 5 additions & 4 deletions crates/polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,11 @@
//! - `cumulative_eval` - Apply expressions over cumulatively increasing windows.
//! - `arg_where` - Get indices where condition holds.
//! - `search_sorted` - Find indices where elements should be inserted to maintain order.
//! - `date_offset` Add an offset to dates that take months and leap years into account.
//! - `trigonometry` Trigonometric functions.
//! - `sign` Compute the element-wise sign of a Series.
//! - `propagate_nans` NaN propagating min/max aggregations.
//! - `date_offset` - Add an offset to dates that take months and leap years into account.
//! - `trigonometry` - Trigonometric functions.
//! - `sign` - Compute the element-wise sign of a Series.
//! - `propagate_nans` - NaN propagating min/max aggregations.
//! - `extract_groups` - Extract multiple regex groups from strings.
//! * `DataFrame` pretty printing
//! - `fmt` - Activate DataFrame formatting
//!
Expand Down
2 changes: 2 additions & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ list_sets = ["polars-lazy/list_sets"]
list_any_all = ["polars/list_any_all"]
cutqcut = ["polars/cutqcut"]
rle = ["polars/rle"]
extract_groups = ["polars/extract_groups"]

all = [
"json",
Expand Down Expand Up @@ -109,6 +110,7 @@ all = [
"list_any_all",
"cutqcut",
"rle",
"extract_groups",
]

# we cannot conditionally activate simd
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expressions/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The following methods are available under the `expr.str` attribute.
Expr.str.explode
Expr.str.extract
Expr.str.extract_all
Expr.str.extract_groups
Expr.str.json_extract
Expr.str.json_path_match
Expr.str.lengths
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The following methods are available under the `Series.str` attribute.
Series.str.explode
Series.str.extract
Series.str.extract_all
Series.str.extract_groups
Series.str.json_extract
Series.str.json_path_match
Series.str.lengths
Expand Down
88 changes: 88 additions & 0 deletions py-polars/polars/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,94 @@ def extract_all(self, pattern: str | Expr) -> Expr:
pattern = parse_as_expression(pattern, str_as_lit=True)
return wrap_expr(self._pyexpr.str_extract_all(pattern))

def extract_groups(self, pattern: str) -> Expr:
r"""
Extract all capture groups for the given regex pattern.
Parameters
----------
pattern
A valid regular expression pattern, compatible with the `regex crate
<https://docs.rs/regex/latest/regex/>`_.
Notes
-----
All group names are **strings**.
If your pattern contains unnamed groups, their numerical position is converted
to a string.
For example, here we access groups 2 and 3 via the names `"2"` and `"3"`::
>>> df = pl.DataFrame({"col": ["foo bar baz"]})
>>> (
... df.with_columns(
... pl.col("col").str.extract_groups(r"(\S+) (\S+) (.+)")
... ).select(pl.col("col").struct["2"], pl.col("col").struct["3"])
... )
shape: (1, 2)
┌─────┬─────┐
│ 2 ┆ 3 │
│ --- ┆ --- │
│ str ┆ str │
╞═════╪═════╡
│ bar ┆ baz │
└─────┴─────┘
Returns
-------
Expr
Expression of data type :class:`Struct` with fields of data type
:class:`Utf8`.
Examples
--------
>>> df = pl.DataFrame(
... data={
... "url": [
... "http://vote.com/ballon_dor?candidate=messi&ref=python",
... "http://vote.com/ballon_dor?candidate=weghorst&ref=polars",
... "http://vote.com/ballon_dor?error=404&ref=rust",
... ]
... }
... )
>>> pattern = r"candidate=(?<candidate>\w+)&ref=(?<ref>\w+)"
>>> df.select(captures=pl.col("url").str.extract_groups(pattern)).unnest(
... "captures"
... )
shape: (3, 2)
┌───────────┬────────┐
│ candidate ┆ ref │
│ --- ┆ --- │
│ str ┆ str │
╞═══════════╪════════╡
│ messi ┆ python │
│ weghorst ┆ polars │
│ null ┆ null │
└───────────┴────────┘
Unnamed groups have their numerical position converted to a string:
>>> pattern = r"candidate=(\w+)&ref=(\w+)"
>>> (
... df.with_columns(
... captures=pl.col("url").str.extract_groups(pattern)
... ).with_columns(name=pl.col("captures").struct["1"].str.to_uppercase())
... )
shape: (3, 3)
┌───────────────────────────────────┬───────────────────────┬──────────┐
│ url ┆ captures ┆ name │
│ --- ┆ --- ┆ --- │
│ str ┆ struct[2] ┆ str │
╞═══════════════════════════════════╪═══════════════════════╪══════════╡
│ http://vote.com/ballon_dor?candi… ┆ {"messi","python"} ┆ MESSI │
│ http://vote.com/ballon_dor?candi… ┆ {"weghorst","polars"} ┆ WEGHORST │
│ http://vote.com/ballon_dor?error… ┆ {null,null} ┆ null │
└───────────────────────────────────┴───────────────────────┴──────────┘
"""
return wrap_expr(self._pyexpr.str_extract_groups(pattern))

def count_match(self, pattern: str) -> Expr:
r"""
Count all successive non-overlapping regex matches.
Expand Down
Loading

0 comments on commit 1e1fd1d

Please sign in to comment.