From 7744eed6c3017255303f1485850df454bd6f207c Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Mon, 6 Jan 2025 16:36:58 -0500 Subject: [PATCH] Add contains --- crates/polars-plan/src/dsl/cat.rs | 30 ++++ .../polars-plan/src/dsl/function_expr/cat.rs | 68 ++++++++- crates/polars-python/src/expr/categorical.rs | 16 ++ .../reference/expressions/categories.rst | 2 + .../source/reference/series/categories.rst | 2 + py-polars/polars/expr/categorical.py | 138 ++++++++++++++++++ py-polars/polars/series/categorical.py | 113 ++++++++++++++ .../operations/namespaces/test_categorical.py | 117 +++++++++++++++ 8 files changed, 485 insertions(+), 1 deletion(-) diff --git a/crates/polars-plan/src/dsl/cat.rs b/crates/polars-plan/src/dsl/cat.rs index 66a147ebb9c7..ab02898f0246 100644 --- a/crates/polars-plan/src/dsl/cat.rs +++ b/crates/polars-plan/src/dsl/cat.rs @@ -36,4 +36,34 @@ impl CategoricalNameSpace { suffix, ))) } + + /// Check if a string value contains a literal substring. + #[cfg(all(feature = "strings", feature = "regex"))] + pub fn contains(self, pat: &str, literal: bool, strict: bool) -> Expr { + self.0 + .map_private(FunctionExpr::Categorical(CategoricalFunction::Contains { + pat: pat.into(), + literal, + strict: strict && !literal, // if literal, strict = false + })) + } + + /// Uses aho-corasick to find many patterns. + /// + /// # Arguments + /// - `patterns`: an expression that evaluates to a String column + /// - `ascii_case_insensitive`: Enable ASCII-aware case insensitive matching. + /// When this option is enabled, searching will be performed without respect to case for + /// ASCII letters (a-z and A-Z) only. + #[cfg(feature = "find_many")] + pub fn contains_any(self, patterns: Expr, ascii_case_insensitive: bool) -> Expr { + self.0.map_many_private( + FunctionExpr::Categorical(CategoricalFunction::ContainsMany { + ascii_case_insensitive, + }), + &[patterns], + false, + None, + ) + } } diff --git a/crates/polars-plan/src/dsl/function_expr/cat.rs b/crates/polars-plan/src/dsl/function_expr/cat.rs index 9a82fc2ea895..c940df88abd9 100644 --- a/crates/polars-plan/src/dsl/function_expr/cat.rs +++ b/crates/polars-plan/src/dsl/function_expr/cat.rs @@ -1,5 +1,7 @@ +use polars_ops::chunked_array::strings; + use super::*; -use crate::map; +use crate::{map, map_as_slice}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Clone, PartialEq, Debug, Eq, Hash)] @@ -13,6 +15,16 @@ pub enum CategoricalFunction { StartsWith(String), #[cfg(feature = "strings")] EndsWith(String), + #[cfg(all(feature = "strings", feature = "regex"))] + Contains { + pat: PlSmallStr, + literal: bool, + strict: bool, + }, + #[cfg(all(feature = "strings", feature = "find_many"))] + ContainsMany { + ascii_case_insensitive: bool, + }, } impl CategoricalFunction { @@ -28,6 +40,10 @@ impl CategoricalFunction { StartsWith(_) => mapper.with_dtype(DataType::Boolean), #[cfg(feature = "strings")] EndsWith(_) => mapper.with_dtype(DataType::Boolean), + #[cfg(all(feature = "strings", feature = "regex"))] + Contains { .. } => mapper.with_dtype(DataType::Boolean), + #[cfg(all(feature = "strings", feature = "find_many"))] + ContainsMany { .. } => mapper.with_dtype(DataType::Boolean), } } } @@ -45,6 +61,10 @@ impl Display for CategoricalFunction { StartsWith(_) => "starts_with", #[cfg(feature = "strings")] EndsWith(_) => "ends_with", + #[cfg(all(feature = "strings", feature = "regex"))] + Contains { .. } => "contains", + #[cfg(all(feature = "strings", feature = "find_many"))] + ContainsMany { .. } => "contains_many", }; write!(f, "cat.{s}") } @@ -63,6 +83,18 @@ impl From for SpecialEq> { StartsWith(prefix) => map!(starts_with, prefix.as_str()), #[cfg(feature = "strings")] EndsWith(suffix) => map!(ends_with, suffix.as_str()), + #[cfg(all(feature = "strings", feature = "regex"))] + Contains { + pat, + literal, + strict, + } => map!(contains, pat.as_str(), literal, strict), + #[cfg(all(feature = "strings", feature = "find_many"))] + ContainsMany { + ascii_case_insensitive, + } => { + map_as_slice!(contains_many, ascii_case_insensitive) + }, } } } @@ -114,6 +146,21 @@ where Ok(out.into_column()) } +/// Fast path: apply a fallible string function to the categories of a categorical column and +/// broadcast the result back to the array. +fn try_apply_to_cats(ca: &CategoricalChunked, mut op: F) -> PolarsResult +where + F: FnMut(&StringChunked) -> PolarsResult>, + ChunkedArray: IntoSeries, + T: PolarsDataType, +{ + let (categories, phys) = _get_cat_phys_map(ca); + let result = op(&categories)?; + // SAFETY: physical idx array is valid. + let out = unsafe { result.take_unchecked(phys.idx().unwrap()) }; + Ok(out.into_column()) +} + /// Fast path: apply a binary function to the categories of a categorical column and broadcast the /// result back to the array. fn apply_to_cats_binary(ca: &CategoricalChunked, mut op: F) -> PolarsResult @@ -152,3 +199,22 @@ fn ends_with(s: &Column, suffix: &str) -> PolarsResult { let ca = s.categorical()?; apply_to_cats_binary(ca, |s| s.as_binary().ends_with(suffix.as_bytes())) } + +#[cfg(all(feature = "strings", feature = "regex"))] +pub(super) fn contains(s: &Column, pat: &str, literal: bool, strict: bool) -> PolarsResult { + let ca = s.categorical()?; + if literal { + try_apply_to_cats(ca, |s| s.contains_literal(pat)) + } else { + try_apply_to_cats(ca, |s| s.contains(pat, strict)) + } +} + +#[cfg(all(feature = "strings", feature = "find_many"))] +fn contains_many(s: &[Column], ascii_case_insensitive: bool) -> PolarsResult { + let ca = s[0].categorical()?; + let patterns = s[1].str()?; + try_apply_to_cats(ca, |s| { + strings::contains_any(s, patterns, ascii_case_insensitive) + }) +} diff --git a/crates/polars-python/src/expr/categorical.rs b/crates/polars-python/src/expr/categorical.rs index b1dcb816b80e..64e6c58adf53 100644 --- a/crates/polars-python/src/expr/categorical.rs +++ b/crates/polars-python/src/expr/categorical.rs @@ -23,4 +23,20 @@ impl PyExpr { fn cat_ends_with(&self, suffix: String) -> Self { self.inner.clone().cat().ends_with(suffix).into() } + + #[pyo3(signature = (pat, literal, strict))] + #[cfg(feature = "regex")] + fn cat_contains(&self, pat: &str, literal: Option, strict: bool) -> Self { + let lit = literal.unwrap_or(false); + self.inner.clone().cat().contains(pat, lit, strict).into() + } + + #[cfg(feature = "find_many")] + fn cat_contains_any(&self, patterns: PyExpr, ascii_case_insensitive: bool) -> Self { + self.inner + .clone() + .cat() + .contains_any(patterns.inner, ascii_case_insensitive) + .into() + } } diff --git a/py-polars/docs/source/reference/expressions/categories.rst b/py-polars/docs/source/reference/expressions/categories.rst index a437704f6094..5771fbcc3ea2 100644 --- a/py-polars/docs/source/reference/expressions/categories.rst +++ b/py-polars/docs/source/reference/expressions/categories.rst @@ -9,6 +9,8 @@ The following methods are available under the `expr.cat` attribute. :toctree: api/ :template: autosummary/accessor_method.rst + Expr.cat.contains + Expr.cat.contains_any Expr.cat.ends_with Expr.cat.get_categories Expr.cat.len_bytes diff --git a/py-polars/docs/source/reference/series/categories.rst b/py-polars/docs/source/reference/series/categories.rst index 46db6491f100..6804951bd4c1 100644 --- a/py-polars/docs/source/reference/series/categories.rst +++ b/py-polars/docs/source/reference/series/categories.rst @@ -9,6 +9,8 @@ The following methods are available under the `Series.cat` attribute. :toctree: api/ :template: autosummary/accessor_method.rst + Series.cat.contains + Series.cat.contains_any Series.cat.ends_with Series.cat.get_categories Series.cat.is_local diff --git a/py-polars/polars/expr/categorical.py b/py-polars/polars/expr/categorical.py index 140c06f4ad53..72bb00cc1c74 100644 --- a/py-polars/polars/expr/categorical.py +++ b/py-polars/polars/expr/categorical.py @@ -2,10 +2,12 @@ from typing import TYPE_CHECKING +from polars._utils.parse import parse_into_expression from polars._utils.wrap import wrap_expr if TYPE_CHECKING: from polars import Expr + from polars._typing import IntoExpr class ExprCatNameSpace: @@ -237,3 +239,139 @@ def ends_with(self, suffix: str) -> Expr: msg = f"'suffix' must be a string; found {type(suffix)!r}" raise TypeError(msg) return wrap_expr(self._pyexpr.cat_ends_with(suffix)) + + def contains( + self, pattern: str, *, literal: bool = False, strict: bool = True + ) -> Expr: + """ + Check if the string representation contains a substring that matches a pattern. + + Parameters + ---------- + pattern + A valid regular expression pattern, compatible with the `regex crate + `_. + literal + Treat `pattern` as a literal string, not as a regular expression. + strict + Raise an error if the underlying pattern is not a valid regex, + otherwise mask out with a null value. + + Notes + ----- + To modify regular expression behaviour (such as case-sensitivity) with + flags, use the inline `(?iLmsuxU)` syntax. For example: + + >>> pl.DataFrame({"s": ["AAA", "aAa", "aaa"]}).with_columns( + ... pl.col("s").cast(pl.Categorical) + ... ).with_columns( + ... default_match=pl.col("s").cat.contains("AA"), + ... insensitive_match=pl.col("s").cat.contains("(?i)AA"), + ... ) + shape: (3, 3) + ┌─────┬───────────────┬───────────────────┐ + │ s ┆ default_match ┆ insensitive_match │ + │ --- ┆ --- ┆ --- │ + │ cat ┆ bool ┆ bool │ + ╞═════╪═══════════════╪═══════════════════╡ + │ AAA ┆ true ┆ true │ + │ aAa ┆ false ┆ true │ + │ aaa ┆ false ┆ true │ + └─────┴───────────────┴───────────────────┘ + + See the regex crate's section on `grouping and flags + `_ for + additional information about the use of inline expression modifiers. + + See Also + -------- + starts_with : Check if string values start with a substring. + ends_with : Check if string values end with a substring. + find: Return the index of the first substring matching a pattern. + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "txt": pl.Series( + ... ["Crab", "cat and dog", "rab$bit", None], + ... dtype=pl.Categorical, + ... ) + ... } + ... ) + >>> df.select( + ... pl.col("txt"), + ... pl.col("txt").cat.contains("cat|bit").alias("regex"), + ... pl.col("txt").cat.contains("rab$", literal=True).alias("literal"), + ... ) + shape: (4, 3) + ┌─────────────┬───────┬─────────┐ + │ txt ┆ regex ┆ literal │ + │ --- ┆ --- ┆ --- │ + │ cat ┆ bool ┆ bool │ + ╞═════════════╪═══════╪═════════╡ + │ Crab ┆ false ┆ false │ + │ cat and dog ┆ true ┆ false │ + │ rab$bit ┆ true ┆ true │ + │ null ┆ null ┆ null │ + └─────────────┴───────┴─────────┘ + """ + return wrap_expr(self._pyexpr.cat_contains(pattern, literal, strict)) + + def contains_any( + self, patterns: IntoExpr, *, ascii_case_insensitive: bool = False + ) -> Expr: + """ + Use the Aho-Corasick algorithm to find matches. + + Determines if any of the patterns are contained in the string representation. + + Parameters + ---------- + patterns + String patterns to search. + ascii_case_insensitive + Enable ASCII-aware case-insensitive matching. + When this option is enabled, searching will be performed without respect + to case for ASCII letters (a-z and A-Z) only. + + Notes + ----- + This method supports matching on string literals only, and does not support + regular expression matching. + + Examples + -------- + >>> _ = pl.Config.set_fmt_str_lengths(100) + >>> df = pl.DataFrame( + ... { + ... "lyrics": pl.Series( + ... [ + ... "Everybody wants to rule the world", + ... "Tell me what you want, what you really really want", + ... "Can you feel the love tonight", + ... ], + ... dtype=pl.Categorical, + ... ) + ... } + ... ) + >>> df.with_columns( + ... pl.col("lyrics").cat.contains_any(["you", "me"]).alias("contains_any") + ... ) + shape: (3, 2) + ┌────────────────────────────────────────────────────┬──────────────┐ + │ lyrics ┆ contains_any │ + │ --- ┆ --- │ + │ cat ┆ bool │ + ╞════════════════════════════════════════════════════╪══════════════╡ + │ Everybody wants to rule the world ┆ false │ + │ Tell me what you want, what you really really want ┆ true │ + │ Can you feel the love tonight ┆ true │ + └────────────────────────────────────────────────────┴──────────────┘ + """ + patterns = parse_into_expression( + patterns, str_as_lit=False, list_as_series=True + ) + return wrap_expr( + self._pyexpr.cat_contains_any(patterns, ascii_case_insensitive) + ) diff --git a/py-polars/polars/series/categorical.py b/py-polars/polars/series/categorical.py index f7f04be28305..4d23f993bfd5 100644 --- a/py-polars/polars/series/categorical.py +++ b/py-polars/polars/series/categorical.py @@ -239,3 +239,116 @@ def ends_with(self, suffix: str) -> Series: null ] """ + + def contains( + self, pattern: str, *, literal: bool = False, strict: bool = True + ) -> Series: + """ + Check if the string representation contains a substring that matches a pattern. + + Parameters + ---------- + pattern + A valid regular expression pattern, compatible with the `regex crate + `_. + literal + Treat `pattern` as a literal string, not as a regular expression. + strict + Raise an error if the underlying pattern is not a valid regex, + otherwise mask out with a null value. + + Notes + ----- + To modify regular expression behaviour (such as case-sensitivity) with + flags, use the inline `(?iLmsuxU)` syntax. For example: + + Default (case-sensitive) match: + + >>> s = pl.Series("s", ["AAA", "aAa", "aaa"], dtype=pl.Categorical) + >>> s.cat.contains("AA").to_list() + [True, False, False] + + Case-insensitive match, using an inline flag: + + >>> s = pl.Series("s", ["AAA", "aAa", "aaa"], dtype=pl.Categorical) + >>> s.cat.contains("(?i)AA").to_list() + [True, True, True] + + See the regex crate's section on `grouping and flags + `_ for + additional information about the use of inline expression modifiers. + + Returns + ------- + Series + Series of data type :class:`Boolean`. + + Examples + -------- + >>> s = pl.Series( + ... ["Crab", "cat and dog", "rab$bit", None], + ... dtype=pl.Categorical, + ... ) + >>> s.cat.contains("cat|bit") + shape: (4,) + Series: '' [bool] + [ + false + true + true + null + ] + >>> s.cat.contains("rab$", literal=True) + shape: (4,) + Series: '' [bool] + [ + false + false + true + null + ] + """ + + def contains_any( + self, patterns: Series | list[str], *, ascii_case_insensitive: bool = False + ) -> Series: + """ + Use the Aho-Corasick algorithm to find matches. + + Determines if any of the patterns are contained in the string representation. + + Parameters + ---------- + patterns + String patterns to search. + ascii_case_insensitive + Enable ASCII-aware case-insensitive matching. + When this option is enabled, searching will be performed without respect + to case for ASCII letters (a-z and A-Z) only. + + Notes + ----- + This method supports matching on string literals only, and does not support + regular expression matching. + + Examples + -------- + >>> _ = pl.Config.set_fmt_str_lengths(100) + >>> s = pl.Series( + ... "lyrics", + ... [ + ... "Everybody wants to rule the world", + ... "Tell me what you want, what you really really want", + ... "Can you feel the love tonight", + ... ], + ... dtype=pl.Categorical, + ... ) + >>> s.cat.contains_any(["you", "me"]) + shape: (3,) + Series: 'lyrics' [bool] + [ + false + true + true + ] + """ diff --git a/py-polars/tests/unit/operations/namespaces/test_categorical.py b/py-polars/tests/unit/operations/namespaces/test_categorical.py index 3de783baf8d9..f218b64b5443 100644 --- a/py-polars/tests/unit/operations/namespaces/test_categorical.py +++ b/py-polars/tests/unit/operations/namespaces/test_categorical.py @@ -3,6 +3,7 @@ import pytest import polars as pl +from polars.exceptions import ComputeError from polars.testing import assert_frame_equal, assert_series_equal @@ -278,3 +279,119 @@ def test_starts_ends_with() -> None: with pytest.raises(TypeError, match="'suffix' must be a string; found"): df.select(pl.col("a").cat.ends_with(None)) # type: ignore[arg-type] + + +def test_cat_contains() -> None: + s = pl.Series(["messi", "ronaldo", "ibrahimovic", "messi"], dtype=pl.Categorical) + expected = pl.Series([True, False, False, True]) + assert_series_equal(s.cat.contains("mes"), expected) + + +def test_contains() -> None: + s_txt = pl.Series(["123", "456", "789", "123"], dtype=pl.Categorical) + assert ( + pl.Series([None, None, None, None]).cast(pl.Boolean).to_list() + == s_txt.cat.contains("(not_valid_regex", literal=False, strict=False).to_list() + ) + with pytest.raises(ComputeError): + s_txt.cat.contains("(not_valid_regex", literal=False, strict=True) + assert ( + pl.Series([True, False, False, True]).cast(pl.Boolean).to_list() + == s_txt.cat.contains("1", literal=False, strict=False).to_list() + ) + + df = pl.DataFrame( + data=[ + (1, "some * * text"), + (2, "(with) special\n * chars"), + (3, "**etc...?$"), + (4, "some * * text"), + ], + schema={"idx": pl.get_index_type(), "text": pl.Categorical}, + orient="row", + ) + for pattern, as_literal, expected in ( + (r"\* \*", False, [True, False, False, True]), + (r"* *", True, [True, False, False, True]), + (r"^\(", False, [False, True, False, False]), + (r"^\(", True, [False, False, False, False]), + (r"(", True, [False, True, False, False]), + (r"e", False, [True, True, True, True]), + (r"e", True, [True, True, True, True]), + (r"^\S+$", False, [False, False, True, False]), + (r"\?\$", False, [False, False, True, False]), + (r"?$", True, [False, False, True, False]), + ): + # series + assert ( + expected == df["text"].cat.contains(pattern, literal=as_literal).to_list() + ) + # frame select + assert ( + expected + == df.select(pl.col("text").cat.contains(pattern, literal=as_literal))[ + "text" + ].to_list() + ) + # frame filter + assert sum(expected) == len( + df.filter(pl.col("text").cat.contains(pattern, literal=as_literal)) + ) + + +@pytest.mark.parametrize( + ("pattern", "case_insensitive", "expected"), + [ + (["me"], False, [True, False, False, True]), + (["Me"], False, [False, False, True, False]), + (["Me"], True, [True, False, True, True]), + (pl.Series(["me", "they"]), False, [True, False, True, True]), + (pl.Series(["Me", "they"]), False, [False, False, True, False]), + (pl.Series(["Me", "they"]), True, [True, False, True, True]), + (["me", "they"], False, [True, False, True, True]), + (["Me", "they"], False, [False, False, True, False]), + (["Me", "they"], True, [True, False, True, True]), + ], +) +def test_contains_any( + pattern: pl.Series | list[str], + case_insensitive: bool, + expected: list[bool], +) -> None: + df = pl.DataFrame( + { + "text": pl.Series( + [ + "Tell me what you want", + "Tell you what I want", + "Tell Me what they want", + "Tell me what you want", + ], + dtype=pl.Categorical, + ) + } + ) + # series + assert ( + expected + == df["text"] + .cat.contains_any(pattern, ascii_case_insensitive=case_insensitive) + .to_list() + ) + # expr + assert ( + expected + == df.select( + pl.col("text").cat.contains_any( + pattern, ascii_case_insensitive=case_insensitive + ) + )["text"].to_list() + ) + # frame filter + assert sum(expected) == len( + df.filter( + pl.col("text").cat.contains_any( + pattern, ascii_case_insensitive=case_insensitive + ) + ) + )