Skip to content

Commit

Permalink
fix(python): raise suitable error on invalid predicates passed to `fi…
Browse files Browse the repository at this point in the history
…lter` method
  • Loading branch information
alexander-beedie committed Oct 21, 2023
1 parent 04357ef commit 66b8adb
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
11 changes: 11 additions & 0 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
_process_null_values,
find_stacklevel,
is_bool_sequence,
is_sequence,
normalize_filepath,
)

Expand Down Expand Up @@ -2655,6 +2656,16 @@ def filter(
for p in predicates:
if is_bool_sequence(p):
boolean_masks.append(pl.Series(p, dtype=Boolean))
elif (
(is_seq := is_sequence(p))
and any(not isinstance(x, pl.Expr) for x in p)
) or (not is_seq and not isinstance(p, pl.Expr)):
err = (
f"Series(…, dtype={p.dtype})"
if isinstance(p, pl.Series)
else f"{p!r}"
)
raise ValueError(f"Invalid predicate for `filter`: {err}")
else:
all_predicates.extend(
wrap_expr(x) for x in parse_as_list_of_expressions(p)
Expand Down
7 changes: 7 additions & 0 deletions py-polars/polars/utils/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ def is_int_sequence(val: object) -> TypeGuard[Sequence[int]]:
return isinstance(val, Sequence) and _is_iterable_of(val, int)


def is_sequence(val: object) -> TypeGuard[Sequence[int]]:
"""Check whether the given input is a numpy array or python sequence."""
return (_check_for_numpy(val) and isinstance(val, np.ndarray)) or isinstance(
val, Sequence
)


def is_str_sequence(
val: object, *, allow_str: bool = False
) -> TypeGuard[Sequence[str]]:
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/unit/test_predicates.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import date, datetime, timedelta
from typing import Any

import numpy as np
import pytest

import polars as pl

Expand Down Expand Up @@ -203,3 +205,20 @@ def test_no_predicate_push_down_with_cast_and_alias_11883() -> None:
.filter((pl.col("b") >= 1) & (pl.col("b") < 1))
)
assert 'SELECTION: "None"' in out.explain(predicate_pushdown=True)


@pytest.mark.parametrize(
"predicate",
[
0,
"x",
[2, 3],
{"x": 1},
pl.Series([1, 2, 3]),
None,
],
)
def test_invalid_predicates(predicate: Any) -> None:
df = pl.DataFrame({"colx": ["aa", "bb", "cc", "dd"]})
with pytest.raises(ValueError, match="Invalid predicate"):
df.filter(predicate)

0 comments on commit 66b8adb

Please sign in to comment.