Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 23, 2024
1 parent 74937ea commit 5b030d6
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 13 deletions.
31 changes: 21 additions & 10 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,21 @@ def concat(
)
raise NotImplementedError

def when(self, *predicates: IntoPandasLikeExpr, **constraints: Any) -> PandasWhen: # noqa: ARG002
def when(self, *predicates: IntoPandasLikeExpr, **constraints: Any) -> PandasWhen: # noqa: ARG002
plx = self.__class__(self._implementation, self._backend_version)
condition = plx.all_horizontal(*predicates)
return PandasWhen(condition, self._implementation, self._backend_version)


class PandasWhen:
def __init__(self, condition: PandasLikeExpr, implementation: Implementation, backend_version: tuple[int, ...], then_value: Any = None, otherise_value: Any = None) -> None:
def __init__(
self,
condition: PandasLikeExpr,
implementation: Implementation,
backend_version: tuple[int, ...],
then_value: Any = None,
otherise_value: Any = None,
) -> None:
self._implementation = implementation
self._backend_version = backend_version
self._condition = condition
Expand All @@ -274,18 +282,21 @@ def __init__(self, condition: PandasLikeExpr, implementation: Implementation, ba
def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
from narwhals._pandas_like.namespace import PandasLikeNamespace

plx = PandasLikeNamespace(implementation=self._implementation, backend_version=self._backend_version)
plx = PandasLikeNamespace(
implementation=self._implementation, backend_version=self._backend_version
)

condition = self._condition._call(df)[0]

value_series = plx._create_broadcast_series_from_scalar(self._then_value, condition)
otherwise_series = plx._create_broadcast_series_from_scalar(self._otherwise_value, condition)
return [
value_series.zip_with(condition, otherwise_series)
]
value_series = plx._create_broadcast_series_from_scalar(
self._then_value, condition
)
otherwise_series = plx._create_broadcast_series_from_scalar(
self._otherwise_value, condition
)
return [value_series.zip_with(condition, otherwise_series)]

def then(self, value: Any) -> PandasThen:

self._then_value = value

return PandasThen(
Expand All @@ -298,8 +309,8 @@ def then(self, value: Any) -> PandasThen:
backend_version=self._condition._backend_version,
)

class PandasThen(PandasLikeExpr):

class PandasThen(PandasLikeExpr):
def __init__(
self,
call: PandasWhen,
Expand Down
12 changes: 9 additions & 3 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3605,6 +3605,7 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
)
)


class When:
def __init__(self, condition: Expr) -> None:
self._condition = condition
Expand All @@ -3614,14 +3615,16 @@ def __init__(self, condition: Expr) -> None:
def then(self, value: Any) -> Then:
return Then(lambda plx: plx.when(self._condition._call(plx)).then(value))


class Then(Expr):
def __init__(self, call) -> None: # noqa: ANN001
def __init__(self, call) -> None: # noqa: ANN001
self._call = call

def otherwise(self, value: Any) -> Expr:
return Expr(lambda plx: self._call(plx).otherwise(value))

def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001

def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When: # noqa: ARG001
"""
Start a `when-then-otherwise` expression.
Expression similar to an `if-else` statement in Python. Always initiated by a `pl.when(<condition>).then(<value if condition>)`., and optionally followed by chaining one or more `.when(<condition>).then(<value>)` statements.
Expand All @@ -3646,7 +3649,10 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any) -> When
>>> @nw.narwhalify
... def func(df_any):
... from narwhals.expr import when
... return df_any.with_columns(when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when"))
...
... return df_any.with_columns(
... when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when")
... )
We can then pass either pandas or polars to `func`:
Expand Down
1 change: 1 addition & 0 deletions tests/test_when.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_when(request: Any, constructor: Any) -> None:
}
compare_dicts(result, expected)


def test_when_otherwise(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)
Expand Down

0 comments on commit 5b030d6

Please sign in to comment.