From ab83c2ad27b82148617af1f00bb26ae809c4aa04 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Wed, 22 Nov 2023 21:02:44 +0400 Subject: [PATCH] feat(python): extend recent `filter` syntax upgrades to `when/then` construct (#12603) --- py-polars/polars/expr/whenthen.py | 45 ++++++++--- py-polars/polars/functions/whenthen.py | 74 ++++++++++++++----- py-polars/polars/utils/_parse_expr_input.py | 33 ++++++++- .../tests/unit/functions/test_whenthen.py | 47 +++++++++++- 4 files changed, 166 insertions(+), 33 deletions(-) diff --git a/py-polars/polars/expr/whenthen.py b/py-polars/polars/expr/whenthen.py index 2a4fe40d8f7d..eb2561db288d 100644 --- a/py-polars/polars/expr/whenthen.py +++ b/py-polars/polars/expr/whenthen.py @@ -1,10 +1,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Iterable import polars.functions as F from polars.expr.expr import Expr -from polars.utils._parse_expr_input import parse_as_expression +from polars.utils._parse_expr_input import ( + parse_as_expression, + parse_when_constraint_expressions, +) from polars.utils._wrap import wrap_expr from polars.utils.deprecation import ( deprecate_renamed_parameter, @@ -67,18 +70,27 @@ def _pyexpr(self) -> PyExpr: return self._then.otherwise(F.lit(None)._pyexpr) @deprecate_renamed_parameter("predicate", "condition", version="0.18.9") - def when(self, condition: IntoExpr) -> ChainedWhen: + def when( + self, + *predicates: IntoExpr | Iterable[IntoExpr], + **constraints: Any, + ) -> ChainedWhen: """ Add a condition to the `when-then-otherwise` expression. Parameters ---------- - condition - The condition for applying the subsequent statement. - Accepts a boolean expression. String input is parsed as a column name. + predicates + Condition(s) that must be met in order to apply the subsequent statement. + Accepts one or more boolean expressions, which are implicitly combined with + `&`. String input is parsed as a column name. 
+ constraints + Apply conditions as `colname = value` keyword arguments that are treated as + equality matches, such as `x = 123`. As with the predicates parameter, + multiple conditions are implicitly combined using `&`. """ - condition_pyexpr = parse_as_expression(condition) + condition_pyexpr = parse_when_constraint_expressions(*predicates, **constraints) return ChainedWhen(self._then.when(condition_pyexpr)) @deprecate_renamed_parameter("expr", "statement", version="0.18.9") @@ -150,18 +162,27 @@ def _pyexpr(self) -> PyExpr: return self._chained_then.otherwise(F.lit(None)._pyexpr) @deprecate_renamed_parameter("predicate", "condition", version="0.18.9") - def when(self, condition: IntoExpr) -> ChainedWhen: + def when( + self, + *predicates: IntoExpr | Iterable[IntoExpr], + **constraints: Any, + ) -> ChainedWhen: """ Add another condition to the `when-then-otherwise` expression. Parameters ---------- - condition - The condition for applying the subsequent statement. - Accepts a boolean expression. String input is parsed as a column name. + predicates + Condition(s) that must be met in order to apply the subsequent statement. + Accepts one or more boolean expressions, which are implicitly combined with + `&`. String input is parsed as a column name. + constraints + Apply conditions as `colname = value` keyword arguments that are treated as + equality matches, such as `x = 123`. As with the predicates parameter, + multiple conditions are implicitly combined using `&`. 
 """ - condition_pyexpr = parse_as_expression(condition) + condition_pyexpr = parse_when_constraint_expressions(*predicates, **constraints) return ChainedWhen(self._chained_then.when(condition_pyexpr)) @deprecate_renamed_parameter("expr", "statement", version="0.18.9") diff --git a/py-polars/polars/functions/whenthen.py b/py-polars/polars/functions/whenthen.py index 17833097e482..3f5e197dd5de 100644 --- a/py-polars/polars/functions/whenthen.py +++ b/py-polars/polars/functions/whenthen.py @@ -1,27 +1,30 @@ from __future__ import annotations import contextlib -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Iterable import polars._reexport as pl -from polars.utils._parse_expr_input import parse_as_expression +from polars.utils._parse_expr_input import parse_when_constraint_expressions from polars.utils.deprecation import deprecate_renamed_parameter with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr if TYPE_CHECKING: - from polars.type_aliases import IntoExpr + from polars.type_aliases import IntoExprColumn @deprecate_renamed_parameter("expr", "condition", version="0.18.9") -def when(condition: IntoExpr) -> pl.When: +def when( + *predicates: IntoExprColumn | Iterable[IntoExprColumn] | bool, + **constraints: Any, +) -> pl.When: """ Start a `when-then-otherwise` expression. Expression similar to an `if-else` statement in Python. Always initiated by a - `pl.when().then()`. Optionally followed by chaining - one or more `.when().then()` statements. + `pl.when().then()`, and optionally followed by + chaining one or more `.when().then()` statements. Chained `when, thens` should be read as Python `if, elif, ... elif` blocks, not as `if, if, ... if`, i.e. the first condition that evaluates to True will be picked. @@ -32,9 +35,14 @@ def when(condition: IntoExpr) -> pl.When: Parameters ---------- - condition - The condition for applying the subsequent statement. - Accepts a boolean expression. 
String input is parsed as a column name. + predicates + Condition(s) that must be met in order to apply the subsequent statement. + Accepts one or more boolean expressions, which are implicitly combined with + `&`. String input is parsed as a column name. + constraints + Apply conditions as `colname = value` keyword arguments that are treated as + equality matches, such as `x = 123`. As with the predicates parameter, multiple + conditions are implicitly combined using `&`. Warnings -------- @@ -48,12 +56,7 @@ def when(condition: IntoExpr) -> pl.When: where it isn't. >>> df = pl.DataFrame({"foo": [1, 3, 4], "bar": [3, 4, 0]}) - >>> df.with_columns( - ... pl.when(pl.col("foo") > 2) - ... .then(pl.lit(1)) - ... .otherwise(pl.lit(-1)) - ... .alias("val") - ... ) + >>> df.with_columns(pl.when(pl.col("foo") > 2).then(1).otherwise(-1).alias("val")) shape: (3, 3) ┌─────┬─────┬─────┐ │ foo ┆ bar ┆ val │ @@ -93,7 +96,7 @@ def when(condition: IntoExpr) -> pl.When: The `otherwise` at the end is optional. If left out, any rows where none of the `when` expressions evaluate to True, are set to `null`: - >>> df.with_columns(pl.when(pl.col("foo") > 2).then(pl.lit(1)).alias("val")) + >>> df.with_columns(pl.when(pl.col("foo") > 2).then(1).alias("val")) shape: (3, 3) ┌─────┬─────┬──────┐ │ foo ┆ bar ┆ val │ @@ -105,6 +108,41 @@ def when(condition: IntoExpr) -> pl.When: │ 4 ┆ 0 ┆ 1 │ └─────┴─────┴──────┘ + Pass multiple predicates, each of which must be met: + + >>> df.with_columns( + ... val=pl.when( + ... pl.col("bar") > 0, + ... pl.col("foo") % 2 != 0, + ... ) + ... .then(99) + ... .otherwise(-1) + ... 
) + shape: (3, 3) + ┌─────┬─────┬─────┐ + │ foo ┆ bar ┆ val │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i32 │ + ╞═════╪═════╪═════╡ + │ 1 ┆ 3 ┆ 99 │ + │ 3 ┆ 4 ┆ 99 │ + │ 4 ┆ 0 ┆ -1 │ + └─────┴─────┴─────┘ + + Pass conditions as keyword arguments: + + >>> df.with_columns(val=pl.when(foo=4, bar=0).then(99).otherwise(-1)) + shape: (3, 3) + ┌─────┬─────┬─────┐ + │ foo ┆ bar ┆ val │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i32 │ + ╞═════╪═════╪═════╡ + │ 1 ┆ 3 ┆ -1 │ + │ 3 ┆ 4 ┆ -1 │ + │ 4 ┆ 0 ┆ 99 │ + └─────┴─────┴─────┘ + """ - condition_pyexpr = parse_as_expression(condition) - return pl.When(plr.when(condition_pyexpr)) + condition = parse_when_constraint_expressions(*predicates, **constraints) + return pl.When(plr.when(condition)) diff --git a/py-polars/polars/utils/_parse_expr_input.py b/py-polars/polars/utils/_parse_expr_input.py index 3c067610e807..43fa56480f99 100644 --- a/py-polars/polars/utils/_parse_expr_input.py +++ b/py-polars/polars/utils/_parse_expr_input.py @@ -1,11 +1,13 @@ from __future__ import annotations from datetime import date, datetime, time, timedelta -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING, Any, Iterable import polars._reexport as pl from polars import functions as F from polars.exceptions import ComputeError +from polars.utils._wrap import wrap_expr +from polars.utils.deprecation import issue_deprecation_warning if TYPE_CHECKING: from polars import Expr @@ -126,6 +128,35 @@ def parse_as_expression( return expr._pyexpr +def parse_when_constraint_expressions( + *predicates: IntoExpr | Iterable[IntoExpr], + **constraints: Any, +) -> PyExpr: + all_predicates: list[pl.Expr] = [] + for p in predicates: + all_predicates.extend(wrap_expr(x) for x in parse_as_list_of_expressions(p)) + + if "condition" in constraints: + if isinstance(constraints["condition"], pl.Expr): + all_predicates.append(constraints.pop("condition")) + issue_deprecation_warning( + "`when` no longer takes a 'condition' parameter.\n" + "To silence this 
warning you should omit the keyword and pass " + "as a positional argument instead.", + version="0.19.16", + ) + + all_predicates.extend(F.col(name).eq(value) for name, value in constraints.items()) + if not all_predicates: + raise ValueError("No predicates or constraints provided to `when`.") + + return ( + F.all_horizontal(*all_predicates) + if len(all_predicates) > 1 + else all_predicates[0] + )._pyexpr + + def _structify_expression(expr: Expr) -> Expr: unaliased_expr = expr.meta.undo_aliases() if unaliased_expr.meta.has_multiple_outputs(): diff --git a/py-polars/tests/unit/functions/test_whenthen.py b/py-polars/tests/unit/functions/test_whenthen.py index 5e25f6639ab0..ad4ecb7cb5bd 100644 --- a/py-polars/tests/unit/functions/test_whenthen.py +++ b/py-polars/tests/unit/functions/test_whenthen.py @@ -484,9 +484,15 @@ def test_when_then_binary_op_predicate_agg_12526() -> None: actual = df.group_by("a").agg( col=( - pl.when(pl.col("a").shift(1) > 2) + pl.when( + pl.col("a").shift(1) > 2, + pl.col("b").is_not_null(), + ) .then(pl.lit("abc")) - .when(pl.col("a").shift(1) > 1) + .when( + pl.col("a").shift(1) > 1, + pl.col("b").is_not_null(), + ) .then(pl.lit("def")) .otherwise(pl.lit(None)) .first() @@ -494,3 +500,40 @@ def test_when_then_binary_op_predicate_agg_12526() -> None: ) assert_frame_equal(expect, actual) + + +def test_when_then_deprecation() -> None: + df = pl.DataFrame({"foo": [5, 4, 3], "bar": [2, 1, 0]}) + for param_name in ("expr", "condition"): + with pytest.warns(DeprecationWarning, match="pass as a positional argument"): + df.select(pl.when(**{param_name: pl.col("bar") >= 0}).then(99)) + + +def test_when_predicates_kwargs() -> None: + df = pl.DataFrame( + { + "x": [10, 20, 30, 40], + "y": [15, -20, None, 1], + "z": ["a", "b", "c", "d"], + } + ) + assert_frame_equal( # kwargs only + df.select(matched=pl.when(x=30, z="c").then(True).otherwise(False)), + pl.DataFrame({"matched": [False, False, True, False]}), + ) + assert_frame_equal( # mixed predicates 
& kwargs + df.select(matched=pl.when(pl.col("x") < 30, z="b").then(True).otherwise(False)), + pl.DataFrame({"matched": [False, True, False, False]}), + ) + assert_frame_equal( # chained when/then with mixed predicates/kwargs + df.select( + misc=pl.when(pl.col("x") > 50) + .then(pl.lit("x>50")) + .when(y=1) + .then(pl.lit("y=1")) + .when(pl.col("z").is_in(["a", "b"]), pl.col("y") < 0) + .then(pl.lit("z in (a|b), y<0")) + .otherwise(pl.lit("?")) + ), + pl.DataFrame({"misc": ["?", "z in (a|b), y<0", "?", "y=1"]}), + )