Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add cat.contains and cat.contains_any #20582

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions crates/polars-plan/src/dsl/cat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,34 @@ impl CategoricalNameSpace {
suffix,
)))
}

/// Check whether the string representation of each value matches `pat`.
///
/// # Arguments
/// - `pat`: the pattern to look for.
/// - `literal`: treat `pat` as a literal substring instead of a pattern.
/// - `strict`: only carried through for non-literal patterns; it is stored as `false`
///   whenever `literal` is set.
#[cfg(all(feature = "strings", feature = "regex"))]
pub fn contains(self, pat: &str, literal: bool, strict: bool) -> Expr {
    let function = CategoricalFunction::Contains {
        pat: pat.into(),
        literal,
        // A literal search never carries the strict flag.
        strict: !literal && strict,
    };
    self.0.map_private(FunctionExpr::Categorical(function))
}

/// Uses aho-corasick to find many patterns.
///
/// Builds a `CategoricalFunction::ContainsMany` expression with `patterns` as its
/// single extra input.
///
/// # Arguments
/// - `patterns`: an expression that evaluates to a String column
/// - `ascii_case_insensitive`: Enable ASCII-aware case insensitive matching.
///   When this option is enabled, searching will be performed without respect to case for
///   ASCII letters (a-z and A-Z) only.
// The gate mirrors the one on `CategoricalFunction::ContainsMany`: the variant only exists
// when both `strings` and `find_many` are enabled, so this method must not be compiled
// with `find_many` alone.
#[cfg(all(feature = "strings", feature = "find_many"))]
pub fn contains_any(self, patterns: Expr, ascii_case_insensitive: bool) -> Expr {
    self.0.map_many_private(
        FunctionExpr::Categorical(CategoricalFunction::ContainsMany {
            ascii_case_insensitive,
        }),
        &[patterns],
        false,
        None,
    )
}
}
68 changes: 67 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/cat.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use polars_ops::chunked_array::strings;

use super::*;
use crate::map;
use crate::{map, map_as_slice};

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
Expand All @@ -13,6 +15,16 @@ pub enum CategoricalFunction {
StartsWith(String),
#[cfg(feature = "strings")]
EndsWith(String),
#[cfg(all(feature = "strings", feature = "regex"))]
Contains {
pat: PlSmallStr,
literal: bool,
strict: bool,
},
#[cfg(all(feature = "strings", feature = "find_many"))]
ContainsMany {
ascii_case_insensitive: bool,
},
}

impl CategoricalFunction {
Expand All @@ -28,6 +40,10 @@ impl CategoricalFunction {
StartsWith(_) => mapper.with_dtype(DataType::Boolean),
#[cfg(feature = "strings")]
EndsWith(_) => mapper.with_dtype(DataType::Boolean),
#[cfg(all(feature = "strings", feature = "regex"))]
Contains { .. } => mapper.with_dtype(DataType::Boolean),
#[cfg(all(feature = "strings", feature = "find_many"))]
ContainsMany { .. } => mapper.with_dtype(DataType::Boolean),
}
}
}
Expand All @@ -45,6 +61,10 @@ impl Display for CategoricalFunction {
StartsWith(_) => "starts_with",
#[cfg(feature = "strings")]
EndsWith(_) => "ends_with",
#[cfg(all(feature = "strings", feature = "regex"))]
Contains { .. } => "contains",
#[cfg(all(feature = "strings", feature = "find_many"))]
ContainsMany { .. } => "contains_many",
};
write!(f, "cat.{s}")
}
Expand All @@ -63,6 +83,18 @@ impl From<CategoricalFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
StartsWith(prefix) => map!(starts_with, prefix.as_str()),
#[cfg(feature = "strings")]
EndsWith(suffix) => map!(ends_with, suffix.as_str()),
#[cfg(all(feature = "strings", feature = "regex"))]
Contains {
pat,
literal,
strict,
} => map!(contains, pat.as_str(), literal, strict),
#[cfg(all(feature = "strings", feature = "find_many"))]
ContainsMany {
ascii_case_insensitive,
} => {
map_as_slice!(contains_many, ascii_case_insensitive)
},
}
}
}
Expand Down Expand Up @@ -114,6 +146,21 @@ where
Ok(out.into_column())
}

/// Fast path for fallible string operations on a categorical column.
///
/// Runs `op` on the categories obtained from `_get_cat_phys_map`, propagates any error it
/// returns, and otherwise gathers the per-category result through the physical index array
/// so that the output is broadcast back to the array.
fn try_apply_to_cats<F, T>(ca: &CategoricalChunked, mut op: F) -> PolarsResult<Column>
where
    F: FnMut(&StringChunked) -> PolarsResult<ChunkedArray<T>>,
    ChunkedArray<T>: IntoSeries,
    T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
{
    let (categories, phys) = _get_cat_phys_map(ca);
    op(&categories).map(|per_category| {
        // SAFETY: physical idx array is valid.
        let gathered = unsafe { per_category.take_unchecked(phys.idx().unwrap()) };
        gathered.into_column()
    })
}

/// Fast path: apply a binary function to the categories of a categorical column and broadcast the
/// result back to the array.
fn apply_to_cats_binary<F, T>(ca: &CategoricalChunked, mut op: F) -> PolarsResult<Column>
Expand Down Expand Up @@ -152,3 +199,22 @@ fn ends_with(s: &Column, suffix: &str) -> PolarsResult<Column> {
let ca = s.categorical()?;
apply_to_cats_binary(ca, |s| s.as_binary().ends_with(suffix.as_bytes()))
}

/// Evaluate `cat.contains` on a categorical column.
///
/// The search runs on the categories via `try_apply_to_cats`: a literal substring
/// search when `literal` is set, otherwise a pattern search that receives `strict`.
#[cfg(all(feature = "strings", feature = "regex"))]
pub(super) fn contains(s: &Column, pat: &str, literal: bool, strict: bool) -> PolarsResult<Column> {
    let cats = s.categorical()?;
    try_apply_to_cats(cats, |categories| {
        if literal {
            categories.contains_literal(pat)
        } else {
            categories.contains(pat, strict)
        }
    })
}

/// Evaluate `cat.contains_many` on a categorical column.
///
/// Expects `s[0]` to be the categorical column and `s[1]` the string patterns; the
/// multi-pattern search is applied to the categories via `try_apply_to_cats`.
#[cfg(all(feature = "strings", feature = "find_many"))]
fn contains_many(s: &[Column], ascii_case_insensitive: bool) -> PolarsResult<Column> {
    let cats = s[0].categorical()?;
    let needles = s[1].str()?;
    try_apply_to_cats(cats, |categories| {
        strings::contains_any(categories, needles, ascii_case_insensitive)
    })
}
16 changes: 16 additions & 0 deletions crates/polars-python/src/expr/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,20 @@ impl PyExpr {
fn cat_ends_with(&self, suffix: String) -> Self {
self.inner.clone().cat().ends_with(suffix).into()
}

// Python binding for `cat().contains`; a missing `literal` is treated as `false`.
#[pyo3(signature = (pat, literal, strict))]
#[cfg(feature = "regex")]
fn cat_contains(&self, pat: &str, literal: Option<bool>, strict: bool) -> Self {
    let expr = self.inner.clone();
    expr.cat()
        .contains(pat, literal.unwrap_or_default(), strict)
        .into()
}

// Python binding for `cat().contains_any`; forwards the inner expression of `patterns`.
#[cfg(feature = "find_many")]
fn cat_contains_any(&self, patterns: PyExpr, ascii_case_insensitive: bool) -> Self {
    let namespace = self.inner.clone().cat();
    let matched = namespace.contains_any(patterns.inner, ascii_case_insensitive);
    matched.into()
}
}
2 changes: 2 additions & 0 deletions py-polars/docs/source/reference/expressions/categories.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ The following methods are available under the `expr.cat` attribute.
:toctree: api/
:template: autosummary/accessor_method.rst

Expr.cat.contains
Expr.cat.contains_any
Expr.cat.ends_with
Expr.cat.get_categories
Expr.cat.len_bytes
Expand Down
2 changes: 2 additions & 0 deletions py-polars/docs/source/reference/series/categories.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ The following methods are available under the `Series.cat` attribute.
:toctree: api/
:template: autosummary/accessor_method.rst

Series.cat.contains
Series.cat.contains_any
Series.cat.ends_with
Series.cat.get_categories
Series.cat.is_local
Expand Down
138 changes: 138 additions & 0 deletions py-polars/polars/expr/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from typing import TYPE_CHECKING

from polars._utils.parse import parse_into_expression
from polars._utils.wrap import wrap_expr

if TYPE_CHECKING:
from polars import Expr
from polars._typing import IntoExpr


class ExprCatNameSpace:
Expand Down Expand Up @@ -237,3 +239,139 @@ def ends_with(self, suffix: str) -> Expr:
msg = f"'suffix' must be a string; found {type(suffix)!r}"
raise TypeError(msg)
return wrap_expr(self._pyexpr.cat_ends_with(suffix))

def contains(
    self, pattern: str, *, literal: bool = False, strict: bool = True
) -> Expr:
    """
    Check if the string representation contains a substring that matches a pattern.

    Parameters
    ----------
    pattern
        A valid regular expression pattern, compatible with the `regex crate
        <https://docs.rs/regex/latest/regex/>`_.
    literal
        Treat `pattern` as a literal string, not as a regular expression.
    strict
        Raise an error if the underlying pattern is not a valid regex,
        otherwise mask out with a null value.

    Notes
    -----
    To modify regular expression behaviour (such as case-sensitivity) with
    flags, use the inline `(?iLmsuxU)` syntax. For example:

    >>> pl.DataFrame({"s": ["AAA", "aAa", "aaa"]}).with_columns(
    ...     pl.col("s").cast(pl.Categorical)
    ... ).with_columns(
    ...     default_match=pl.col("s").cat.contains("AA"),
    ...     insensitive_match=pl.col("s").cat.contains("(?i)AA"),
    ... )
    shape: (3, 3)
    ┌─────┬───────────────┬───────────────────┐
    │ s   ┆ default_match ┆ insensitive_match │
    │ --- ┆ ---           ┆ ---               │
    │ cat ┆ bool          ┆ bool              │
    ╞═════╪═══════════════╪═══════════════════╡
    │ AAA ┆ true          ┆ true              │
    │ aAa ┆ false         ┆ true              │
    │ aaa ┆ false         ┆ true              │
    └─────┴───────────────┴───────────────────┘

    See the regex crate's section on `grouping and flags
    <https://docs.rs/regex/latest/regex/#grouping-and-flags>`_ for
    additional information about the use of inline expression modifiers.

    See Also
    --------
    starts_with : Check if string values start with a substring.
    ends_with : Check if string values end with a substring.

    Examples
    --------
    >>> df = pl.DataFrame(
    ...     {
    ...         "txt": pl.Series(
    ...             ["Crab", "cat and dog", "rab$bit", None],
    ...             dtype=pl.Categorical,
    ...         )
    ...     }
    ... )
    >>> df.select(
    ...     pl.col("txt"),
    ...     pl.col("txt").cat.contains("cat|bit").alias("regex"),
    ...     pl.col("txt").cat.contains("rab$", literal=True).alias("literal"),
    ... )
    shape: (4, 3)
    ┌─────────────┬───────┬─────────┐
    │ txt         ┆ regex ┆ literal │
    │ ---         ┆ ---   ┆ ---     │
    │ cat         ┆ bool  ┆ bool    │
    ╞═════════════╪═══════╪═════════╡
    │ Crab        ┆ false ┆ false   │
    │ cat and dog ┆ true  ┆ false   │
    │ rab$bit     ┆ true  ┆ true    │
    │ null        ┆ null  ┆ null    │
    └─────────────┴───────┴─────────┘
    """
    # Delegate to the native expression, then wrap it back into a Python `Expr`.
    py_result = self._pyexpr.cat_contains(pattern, literal, strict)
    return wrap_expr(py_result)

def contains_any(
    self, patterns: IntoExpr, *, ascii_case_insensitive: bool = False
) -> Expr:
    """
    Use the Aho-Corasick algorithm to find matches.

    Determines if any of the patterns are contained in the string representation.

    Parameters
    ----------
    patterns
        String patterns to search.
    ascii_case_insensitive
        Enable ASCII-aware case-insensitive matching.
        When this option is enabled, searching will be performed without respect
        to case for ASCII letters (a-z and A-Z) only.

    Notes
    -----
    This method supports matching on string literals only, and does not support
    regular expression matching.

    Examples
    --------
    >>> _ = pl.Config.set_fmt_str_lengths(100)
    >>> df = pl.DataFrame(
    ...     {
    ...         "lyrics": pl.Series(
    ...             [
    ...                 "Everybody wants to rule the world",
    ...                 "Tell me what you want, what you really really want",
    ...                 "Can you feel the love tonight",
    ...             ],
    ...             dtype=pl.Categorical,
    ...         )
    ...     }
    ... )
    >>> df.with_columns(
    ...     pl.col("lyrics").cat.contains_any(["you", "me"]).alias("contains_any")
    ... )
    shape: (3, 2)
    ┌────────────────────────────────────────────────────┬──────────────┐
    │ lyrics                                             ┆ contains_any │
    │ ---                                                ┆ ---          │
    │ cat                                                ┆ bool         │
    ╞════════════════════════════════════════════════════╪══════════════╡
    │ Everybody wants to rule the world                  ┆ false        │
    │ Tell me what you want, what you really really want ┆ true         │
    │ Can you feel the love tonight                      ┆ true         │
    └────────────────────────────────────────────────────┴──────────────┘
    """
    # Strings are parsed as column names (not literals); lists become a Series.
    patterns_pyexpr = parse_into_expression(
        patterns, str_as_lit=False, list_as_series=True
    )
    return wrap_expr(
        self._pyexpr.cat_contains_any(patterns_pyexpr, ascii_case_insensitive)
    )
Loading
Loading