diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 06af399d3..51c82648f 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -258,13 +258,21 @@ def concat( ) raise NotImplementedError - def when(self, *predicates: IntoPandasLikeExpr, **constraints: Any) -> PandasWhen: # noqa: ARG002 + def when(self, *predicates: IntoPandasLikeExpr, **constraints: Any) -> PandasWhen: # noqa: ARG002 plx = self.__class__(self._implementation, self._backend_version) condition = plx.all_horizontal(*predicates) return PandasWhen(condition, self._implementation, self._backend_version) + class PandasWhen: - def __init__(self, condition: PandasLikeExpr, implementation: Implementation, backend_version: tuple[int, ...], then_value: Any = None, otherise_value: Any = None) -> None: + def __init__( + self, + condition: PandasLikeExpr, + implementation: Implementation, + backend_version: tuple[int, ...], + then_value: Any = None, + otherise_value: Any = None, + ) -> None: self._implementation = implementation self._backend_version = backend_version self._condition = condition @@ -274,18 +282,21 @@ def __init__(self, condition: PandasLikeExpr, implementation: Implementation, ba def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: from narwhals._pandas_like.namespace import PandasLikeNamespace - plx = PandasLikeNamespace(implementation=self._implementation, backend_version=self._backend_version) + plx = PandasLikeNamespace( + implementation=self._implementation, backend_version=self._backend_version + ) condition = self._condition._call(df)[0] - value_series = plx._create_broadcast_series_from_scalar(self._then_value, condition) - otherwise_series = plx._create_broadcast_series_from_scalar(self._otherwise_value, condition) - return [ - value_series.zip_with(condition, otherwise_series) - ] + value_series = plx._create_broadcast_series_from_scalar( + self._then_value, condition + ) + otherwise_series = plx._create_broadcast_series_from_scalar( + self._otherwise_value, condition + ) + return [value_series.zip_with(condition, otherwise_series)] def then(self, value: Any) -> PandasThen: - self._then_value = value return PandasThen( @@ -298,8 +309,8 @@ def then(self, value: Any) -> PandasThen: backend_version=self._condition._backend_version, ) -class PandasThen(PandasLikeExpr): +class PandasThen(PandasLikeExpr): def __init__( self, call: PandasWhen, diff --git a/narwhals/expr.py b/narwhals/expr.py index 5aba8cc61..6eb0dd135 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3605,6 +3605,7 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: ) ) + class When: def __init__(self, condition: Expr) -> None: self._condition = condition @@ -3614,14 +3615,16 @@ def __init__(self, condition: Expr) -> None: def then(self, value: Any) -> Then: return Then(lambda plx: plx.when(self._condition._call(plx)).then(value)) + class Then(Expr): - def __init__(self, call) -> None: # noqa: ANN001 + def __init__(self, call) -> None: # noqa: ANN001 self._call = call def otherwise(self, value: Any) -> Expr: return Expr(lambda plx: self._call(plx).otherwise(value)) -def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001 + +def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001 """ Start a `when-then-otherwise` expression. Expression similar to an `if-else` statement in Python. Always initiated by a `pl.when().then()`., and optionally followed by chaining one or more `.when().then()` statements. @@ -3646,7 +3649,10 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When >>> @nw.narwhalify ... def func(df_any): ... from narwhals.expr import when - ... return df_any.with_columns(when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when")) + ... + ... return df_any.with_columns( + ... when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when") + ... ) We can then pass either pandas or polars to `func`: diff --git a/tests/test_when.py b/tests/test_when.py index cc95cc347..90df13180 100644 --- a/tests/test_when.py +++ b/tests/test_when.py @@ -31,6 +31,7 @@ def test_when(request: Any, constructor: Any) -> None: } compare_dicts(result, expected) + def test_when_otherwise(request: Any, constructor: Any) -> None: if "pyarrow_table" in str(constructor): request.applymarker(pytest.mark.xfail)