diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 96394e90b4da..fcd7f21aa845 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -138,6 +138,7 @@ list_sets = ["polars-plan/list_sets", "polars-ops/list_sets"] list_any_all = ["polars-ops/list_any_all", "polars-plan/list_any_all"] cutqcut = ["polars-plan/cutqcut", "polars-ops/cutqcut"] rle = ["polars-plan/rle", "polars-ops/rle"] +extract_groups = ["polars-plan/extract_groups"] binary_encoding = ["polars-plan/binary_encoding"] diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index aa7b426d4b2d..0a8a66c92c88 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -88,3 +88,4 @@ semi_anti_join = ["polars-core/semi_anti_join"] list_take = [] list_sets = [] list_any_all = [] +extract_groups = ["dtype-struct", "polars-core/regex"] diff --git a/crates/polars-ops/src/chunked_array/strings/extract.rs b/crates/polars-ops/src/chunked_array/strings/extract.rs new file mode 100644 index 000000000000..175a71cdf8c3 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/extract.rs @@ -0,0 +1,77 @@ +use arrow::array::{Array, MutableArray, MutableUtf8Array, StructArray, Utf8Array}; +use polars_arrow::utils::combine_validities_and; +use polars_core::export::regex::Regex; + +use super::*; + +fn extract_groups_array( + arr: &Utf8Array, + reg: &Regex, + names: &[String], + data_type: ArrowDataType, +) -> PolarsResult { + let mut builders = (0..names.len()) + .map(|_| MutableUtf8Array::::with_capacity(arr.len())) + .collect::>(); + + arr.into_iter().for_each(|opt_v| { + // we combine the null validity later + if let Some(value) = opt_v { + let caps = reg.captures(value); + match caps { + None => builders.iter_mut().for_each(|arr| arr.push_null()), + Some(caps) => { + caps.iter() + .skip(1) // skip 0th group + .zip(builders.iter_mut()) + .for_each(|(m, builder)| builder.push(m.map(|m| m.as_str()))) + } + } + } + }); + + let 
values = builders + .into_iter() + .map(|group_array| { + let group_array: Utf8Array = group_array.into(); + let final_validity = combine_validities_and(group_array.validity(), arr.validity()); + group_array.with_validity(final_validity).to_boxed() + }) + .collect(); + + Ok(StructArray::new(data_type.clone(), values, None).boxed()) +} + +pub(super) fn extract_groups(ca: &Utf8Chunked, pat: &str) -> PolarsResult { + let reg = Regex::new(pat)?; + let n_fields = reg.captures_len(); + + if n_fields == 1 { + return StructChunked::new(ca.name(), &[Series::new_null(ca.name(), ca.len())]) + .map(|ca| ca.into_series()); + } + + let names = reg + .capture_names() + .enumerate() + .skip(1) + .map(|(idx, opt_name)| { + opt_name + .map(|name| name.to_string()) + .unwrap_or_else(|| format!("{idx}")) + }) + .collect::>(); + let data_type = ArrowDataType::Struct( + names + .iter() + .map(|name| ArrowField::new(name.as_str(), ArrowDataType::LargeUtf8, true)) + .collect(), + ); + + let chunks = ca + .downcast_iter() + .map(|array| extract_groups_array(array, ®, &names, data_type.clone())) + .collect::>>()?; + + Series::try_from((ca.name(), chunks)) +} diff --git a/crates/polars-ops/src/chunked_array/strings/mod.rs b/crates/polars-ops/src/chunked_array/strings/mod.rs index 25c1e4c74233..8b8c636da5f2 100644 --- a/crates/polars-ops/src/chunked_array/strings/mod.rs +++ b/crates/polars-ops/src/chunked_array/strings/mod.rs @@ -1,5 +1,7 @@ #[cfg(feature = "strings")] mod case; +#[cfg(feature = "extract_groups")] +mod extract; #[cfg(feature = "extract_jsonpath")] mod json_path; #[cfg(feature = "string_justify")] diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs index 1e056fbcc86b..badd4c0970eb 100644 --- a/crates/polars-ops/src/chunked_array/strings/namespace.rs +++ b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -353,6 +353,13 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { Ok(builder.finish()) } + 
/// Extract every capture group of `pat` and return the groups as the fields of a
/// struct series.
#[cfg(feature = "extract_groups")]
fn extract_groups(&self, pat: &str) -> PolarsResult<Series> {
    super::extract::extract_groups(self.as_utf8(), pat)
}
/// Extract all capture groups from a regex pattern as a struct.
///
/// Propagates the error from `Series::utf8` when `s` cannot be viewed as a string
/// column, and any error returned by the underlying `extract_groups` implementation.
#[cfg(feature = "extract_groups")]
pub(super) fn extract_groups(s: &Series, pat: &str) -> PolarsResult<Series> {
    let ca = s.utf8()?;
    // `pat` is already a `&str`; no owned copy is needed to forward it.
    ca.extract_groups(pat)
}
width. /// A leading sign prefix ('+'/'-') is handled by inserting the padding after the sign character /// rather than before. diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 46bf924c923d..cafd1c176331 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -184,6 +184,7 @@ list_sets = ["polars-lazy/list_sets"] list_any_all = ["polars-lazy/list_any_all"] cutqcut = ["polars-lazy/cutqcut"] rle = ["polars-lazy/rle"] +extract_groups = ["polars-lazy/extract_groups"] test = [ "lazy", @@ -327,6 +328,7 @@ docs-selection = [ "propagate_nans", "coalesce", "dynamic_groupby", + "extract_groups", ] bench = [ diff --git a/crates/polars/src/lib.rs b/crates/polars/src/lib.rs index cafafd5650e4..ff8167241c9a 100644 --- a/crates/polars/src/lib.rs +++ b/crates/polars/src/lib.rs @@ -251,10 +251,11 @@ //! - `cumulative_eval` - Apply expressions over cumulatively increasing windows. //! - `arg_where` - Get indices where condition holds. //! - `search_sorted` - Find indices where elements should be inserted to maintain order. -//! - `date_offset` Add an offset to dates that take months and leap years into account. -//! - `trigonometry` Trigonometric functions. -//! - `sign` Compute the element-wise sign of a Series. -//! - `propagate_nans` NaN propagating min/max aggregations. +//! - `date_offset` - Add an offset to dates that take months and leap years into account. +//! - `trigonometry` - Trigonometric functions. +//! - `sign` - Compute the element-wise sign of a Series. +//! - `propagate_nans` - NaN propagating min/max aggregations. +//! - `extract_groups` - Extract multiple regex groups from strings. //! * `DataFrame` pretty printing //! - `fmt` - Activate DataFrame formatting //! 
diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index 9d11dda96ddf..89b262947117 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -70,6 +70,7 @@ list_sets = ["polars-lazy/list_sets"] list_any_all = ["polars/list_any_all"] cutqcut = ["polars/cutqcut"] rle = ["polars/rle"] +extract_groups = ["polars/extract_groups"] all = [ "json", @@ -109,6 +110,7 @@ all = [ "list_any_all", "cutqcut", "rle", + "extract_groups", ] # we cannot conditionally activate simd diff --git a/py-polars/docs/source/reference/expressions/string.rst b/py-polars/docs/source/reference/expressions/string.rst index ba64cbb00acd..f2e5ad2e9945 100644 --- a/py-polars/docs/source/reference/expressions/string.rst +++ b/py-polars/docs/source/reference/expressions/string.rst @@ -18,6 +18,7 @@ The following methods are available under the `expr.str` attribute. Expr.str.explode Expr.str.extract Expr.str.extract_all + Expr.str.extract_groups Expr.str.json_extract Expr.str.json_path_match Expr.str.lengths diff --git a/py-polars/docs/source/reference/series/string.rst b/py-polars/docs/source/reference/series/string.rst index 6c3fcd521b82..59489085588c 100644 --- a/py-polars/docs/source/reference/series/string.rst +++ b/py-polars/docs/source/reference/series/string.rst @@ -18,6 +18,7 @@ The following methods are available under the `Series.str` attribute. 
def extract_groups(self, pattern: str) -> Expr:
    r"""
    Extract all capture groups for the given regex pattern.

    Parameters
    ----------
    pattern
        A valid regular expression pattern, compatible with the `regex crate
        <https://docs.rs/regex/latest/regex/>`_.

    Notes
    -----
    All group names are **strings**.

    If your pattern contains unnamed groups, their numerical position is converted
    to a string.

    For example, here we access groups 2 and 3 via the names `"2"` and `"3"`::

        >>> df = pl.DataFrame({"col": ["foo bar baz"]})
        >>> (
        ...     df.with_columns(
        ...         pl.col("col").str.extract_groups(r"(\S+) (\S+) (.+)")
        ...     ).select(pl.col("col").struct["2"], pl.col("col").struct["3"])
        ... )
        shape: (1, 2)
        ┌─────┬─────┐
        │ 2   ┆ 3   │
        │ --- ┆ --- │
        │ str ┆ str │
        ╞═════╪═════╡
        │ bar ┆ baz │
        └─────┴─────┘

    Returns
    -------
    Expr
        Expression of data type :class:`Struct` with fields of data type
        :class:`Utf8`.

    Examples
    --------
    >>> df = pl.DataFrame(
    ...     data={
    ...         "url": [
    ...             "http://vote.com/ballon_dor?candidate=messi&ref=python",
    ...             "http://vote.com/ballon_dor?candidate=weghorst&ref=polars",
    ...             "http://vote.com/ballon_dor?error=404&ref=rust",
    ...         ]
    ...     }
    ... )
    >>> pattern = r"candidate=(?<candidate>\w+)&ref=(?<ref>\w+)"
    >>> df.select(captures=pl.col("url").str.extract_groups(pattern)).unnest(
    ...     "captures"
    ... )
    shape: (3, 2)
    ┌───────────┬────────┐
    │ candidate ┆ ref    │
    │ ---       ┆ ---    │
    │ str       ┆ str    │
    ╞═══════════╪════════╡
    │ messi     ┆ python │
    │ weghorst  ┆ polars │
    │ null      ┆ null   │
    └───────────┴────────┘

    Unnamed groups have their numerical position converted to a string:

    >>> pattern = r"candidate=(\w+)&ref=(\w+)"
    >>> (
    ...     df.with_columns(
    ...         captures=pl.col("url").str.extract_groups(pattern)
    ...     ).with_columns(name=pl.col("captures").struct["1"].str.to_uppercase())
    ... )
    shape: (3, 3)
    ┌───────────────────────────────────┬───────────────────────┬──────────┐
    │ url                               ┆ captures              ┆ name     │
    │ ---                               ┆ ---                   ┆ ---      │
    │ str                               ┆ struct[2]             ┆ str      │
    ╞═══════════════════════════════════╪═══════════════════════╪══════════╡
    │ http://vote.com/ballon_dor?candi… ┆ {"messi","python"}    ┆ MESSI    │
    │ http://vote.com/ballon_dor?candi… ┆ {"weghorst","polars"} ┆ WEGHORST │
    │ http://vote.com/ballon_dor?error… ┆ {null,null}           ┆ null     │
    └───────────────────────────────────┴───────────────────────┴──────────┘

    """
    return wrap_expr(self._pyexpr.str_extract_groups(pattern))
def test_extract_groups() -> None:
    """Check `str.extract_groups` for named, unnamed, mixed and absent capture groups."""

    def _named_groups_builder(pattern: str, groups: dict[str, str]) -> str:
        # Substitute every `{name}` placeholder with a named capture group `(?<name>...)`.
        return pattern.format(
            **{name: f"(?<{name}>{value})" for name, value in groups.items()}
        )

    expected = {
        "authority": ["ISO", "ISO/IEC/IEEE"],
        "spec_num": ["80000", "29148"],
        "part_num": ["1", None],
        "revision_year": ["2009", "2018"],
    }

    pattern = _named_groups_builder(
        r"{authority}\s{spec_num}(?:-{part_num})?(?::{revision_year})",
        {
            "authority": r"^ISO(?:/[A-Z]+)*",
            "spec_num": r"\d+",
            "part_num": r"\d+",
            "revision_year": r"\d{4}",
        },
    )

    df = pl.DataFrame({"iso_code": ["ISO 80000-1:2009", "ISO/IEC/IEEE 29148:2018"]})

    # Named groups only: the struct unnests into one column per group name.
    assert (
        df.select(pl.col("iso_code").str.extract_groups(pattern))
        .unnest("iso_code")
        .to_dict(False)
        == expected
    )

    # No capture groups at all: a single null field named after the column.
    assert df.select(pl.col("iso_code").str.extract_groups("")).to_dict(False) == {
        "iso_code": [{"iso_code": None}, {"iso_code": None}]
    }

    # Unnamed groups are keyed by their position as a string.
    assert df.select(
        pl.col("iso_code").str.extract_groups(r"\A(ISO\S*).*?(\d+)")
    ).to_dict(False) == {
        "iso_code": [{"1": "ISO", "2": "80000"}, {"1": "ISO/IEC/IEEE", "2": "29148"}]
    }

    # Mixed unnamed and named groups.
    assert df.select(
        pl.col("iso_code").str.extract_groups(r"\A(ISO\S*).*?(?<year>\d+)\z")
    ).to_dict(False) == {
        "iso_code": [
            {"1": "ISO", "year": "2009"},
            {"1": "ISO/IEC/IEEE", "year": "2018"},
        ]
    }

    # A group that does not participate in the match comes back as null.
    assert pl.select(
        pl.lit(r"foobar").str.extract_groups(r"(?<foo>.{3})|(?<bar>...)")
    ).to_dict(False) == {"literal": [{"foo": "foo", "bar": None}]}