Skip to content

Commit

Permalink
test: add test for multiple predicates
Browse files Browse the repository at this point in the history
  • Loading branch information
aivanoved committed Jul 25, 2024
1 parent 0454ac4 commit 4ad28b7
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions tests/expr_and_series/when_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tests.utils import compare_dicts

data = {
"a": [1, 1, 2],
"a": [1, 2, 3],
"b": ["a", "b", "c"],
"c": [4.1, 5.0, 6.0],
"d": [True, False, True],
Expand All @@ -23,11 +23,11 @@ def test_when(request: Any, constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.with_columns(when(nw.col("a") == 1).then(value=3).alias("a_when"))
expected = {
"a": [1, 1, 2],
"a": [1, 2, 3],
"b": ["a", "b", "c"],
"c": [4.1, 5.0, 6.0],
"d": [True, False, True],
"a_when": [3, 3, None],
"a_when": [3, None, None],
}
compare_dicts(result, expected)

Expand All @@ -39,10 +39,28 @@ def test_when_otherwise(request: Any, constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.with_columns(when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when"))
expected = {
"a": [1, 1, 2],
"a": [1, 2, 3],
"b": ["a", "b", "c"],
"c": [4.1, 5.0, 6.0],
"d": [True, False, True],
"a_when": [3, 3, 6],
"a_when": [3, 6, 6],
}
compare_dicts(result, expected)


def test_multiple_conditions(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
result = df.with_columns(
when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when")
)
expected = {
"a": [1, 2, 3],
"b": ["a", "b", "c"],
"c": [4.1, 5.0, 6.0],
"d": [True, False, True],
"a_when": [3, None, None],
}
compare_dicts(result, expected)

0 comments on commit 4ad28b7

Please sign in to comment.