Skip to content

Commit

Permalink
Added test for Expr.str contains method
Browse files Browse the repository at this point in the history
  • Loading branch information
ugohuche committed Jun 14, 2024
1 parent 0b34024 commit aeff68a
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 2 deletions.
2 changes: 1 addition & 1 deletion narwhals/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,7 +1462,7 @@ def ends_with(self, suffix: str) -> Expr:
lambda plx: self._expr._call(plx).str.ends_with(suffix)
)

def contains(self, pattern: str, literal: bool = False) -> Expr:
def contains(self, pattern: str | Expr, literal: bool = False) -> Expr:
"""
Check if string contains a substring that matches a pattern.
Expand Down
2 changes: 1 addition & 1 deletion narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1612,7 +1612,7 @@ def __init__(self, series: Series) -> None:
def ends_with(self, suffix: str) -> Series:
return self._series.__class__(self._series._series.str.ends_with(suffix))

def contains(self, pattern: str, literal: bool = False) -> Series:
def contains(self, pattern: str | Expr, literal: bool = False) -> Series:
return self._series.__class__(
self._series._series.str.contains(pattern, literal=literal)
)
Expand Down
31 changes: 31 additions & 0 deletions tests/expr/str/contains_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Any

import pandas as pd
import polars as pl
import pytest

import narwhals as nw
from tests.utils import compare_dicts

data = {
"pets": ["cat", "dog", "rabbit and parrot", "dove"]
}

df_pandas = pd.DataFrame(data)
df_polars = pl.DataFrame(data)

@pytest.mark.parametrize("df_any", [df_pandas, df_polars])
def test_contains(df_any: Any) -> None:
df = nw.from_native(df_any, eager_only=True)
result = df.with_columns(
case_insensitive_match=nw.col("pets").str.contains("(?i)parrot|Dove")
)
expected = {
"pets": ["cat", "dog", "rabbit and parrot", "dove"],
"case_insensitive_match": [False, False, True, True]
}
compare_dicts(result, expected)
result = df.with_columns(
case_insensitive_match=df["pets"].str.contains("(?i)parrot|Dove")
)
compare_dicts(result, expected)

0 comments on commit aeff68a

Please sign in to comment.