Skip to content

Commit

Permalink
feat(python): Expose 'strict' argument to 'is_in' (#17776)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 22, 2024
1 parent 7888d3b commit a57c75c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
13 changes: 10 additions & 3 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5698,14 +5698,21 @@ def xor(self, other: Any) -> Expr:
"""
return self.__xor__(other)

def is_in(self, other: Expr | Collection[Any] | Series) -> Expr:
def is_in(
self, other: Expr | Collection[Any] | Series, *, strict: bool = True
) -> Expr:
"""
Check if elements of this expression are present in the other Series.
Parameters
----------
other
Series or sequence of primitive type.
Series or sequence to test membership of.
strict
If a python collection is given, `strict`
will be passed to the `Series` constructor
and indicates how different types should be
handled.
Returns
-------
Expand All @@ -5732,7 +5739,7 @@ def is_in(self, other: Expr | Collection[Any] | Series) -> Expr:
if isinstance(other, Collection) and not isinstance(other, str):
if isinstance(other, (Set, FrozenSet)):
other = list(other)
other = F.lit(pl.Series(other))._pyexpr
other = F.lit(pl.Series(other, strict=strict))._pyexpr
else:
other = parse_into_expression(other)
return self._from_pyexpr(self._pyexpr.is_in(other))
Expand Down
12 changes: 11 additions & 1 deletion py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3703,10 +3703,20 @@ def is_not_nan(self) -> Series:
]
"""

def is_in(self, other: Series | Collection[Any]) -> Series:
def is_in(self, other: Series | Collection[Any], *, strict: bool = True) -> Series:
"""
Check if elements of this Series are in the other Series.
Parameters
----------
other
Series or sequence to test membership of.
strict
If a python collection is given, `strict`
will be passed to the `Series` constructor
and indicates how different types should be
handled.
Returns
-------
Series
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,8 @@ def test_is_first_last_distinct_all_null(dtypes: PolarsDataType) -> None:
s = pl.Series([None, None, None], dtype=dtypes)
assert s.is_first_distinct().to_list() == [True, False, False]
assert s.is_last_distinct().to_list() == [False, False, True]


def test_is_in_non_strict() -> None:
s = pl.Series([1, 2, 3, 4])
assert s.is_in([2, 2.5], strict=False).to_list() == [False, True, False, False]

0 comments on commit a57c75c

Please sign in to comment.