From 36c4fcf220bc5b6eeced26dbe356c6412543084e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dea=20Mar=C3=ADa=20L=C3=A9on?= Date: Sat, 25 May 2024 19:00:07 +0200 Subject: [PATCH 1/2] fix pre-commit --- docs/api-reference/series.md | 1 + narwhals/_pandas_like/series.py | 5 ++++ narwhals/series.py | 49 +++++++++++++++++++++++++++++++++ utils/check_api_reference.py | 1 + 4 files changed, 56 insertions(+) diff --git a/docs/api-reference/series.md b/docs/api-reference/series.md index d65262a5b..e5ab96e7e 100644 --- a/docs/api-reference/series.md +++ b/docs/api-reference/series.md @@ -40,5 +40,6 @@ - to_pandas - unique - value_counts + - zip_with show_source: false show_bases: false diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 74c641a70..fe7428cb7 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -473,6 +473,11 @@ def value_counts(self: Self, *, sort: bool = False, parallel: bool = False) -> A implementation=self._implementation, ) + def zip_with(self: Self, mask: Any, other: Any) -> PandasSeries: + ser = self._series + res = ser.where(mask._series, other._series) + return self._from_series(res) + @property def str(self) -> PandasSeriesStringNamespace: return PandasSeriesStringNamespace(self) diff --git a/narwhals/series.py b/narwhals/series.py index 03f2e5ba9..cdc189f8e 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -1482,6 +1482,55 @@ def value_counts( return DataFrame(self._series.value_counts(sort=sort, parallel=parallel)) + def zip_with(self, mask: Any, other: Any) -> Self: + """ + Take values from self or other based on the given mask. Where mask evaluates true, take values from self. Where mask evaluates false, take values from other. + + Examples: + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> s1_pl = pl.Series([1, 2, 3, 4, 5]) + >>> s2_pl = pl.Series([5, 4, 3, 2, 1]) + >>> mask_pl = pl.Series([True, False, True, False, True]) + >>> s1_pd = pd.Series([1, 2, 3, 4, 5]) + >>> s2_pd = pd.Series([5, 4, 3, 2, 1]) + >>> mask_pd = pd.Series([True, False, True, False, True]) + + Let's define a dataframe-agnostic function: + + >>> def func(s1_any, mask_any, s2_any): + ... s1 = nw.from_native(s1_any, allow_series=True) + ... mask = nw.from_native(mask_any, series_only=True) + ... s2 = nw.from_native(s2_any, series_only=True) + ... s = s1.zip_with(mask, s2) + ... return nw.to_native(s) + + We can then pass either pandas or Polars to `func`: + + >>> func(s1_pl, mask_pl, s2_pl) # doctest: +NORMALIZE_WHITESPACE + shape: (5,) + Series: '' [i64] + [ + 1 + 4 + 3 + 2 + 5 + ] + >>> func(s1_pd, mask_pd, s2_pd) + 0 1 + 1 4 + 2 3 + 3 2 + 4 5 + dtype: int64 + """ + + return self._from_series( + self._series.zip_with(self._extract_native(mask), self._extract_native(other)) + ) + @property def str(self) -> SeriesStringNamespace: return SeriesStringNamespace(self) diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py index e1f74ac2b..f1b986fe2 100644 --- a/utils/check_api_reference.py +++ b/utils/check_api_reference.py @@ -145,6 +145,7 @@ "is_empty", "is_sorted", "value_counts", + "zip_with", } ) ): From d2891f80f69e9e60b63df0b70552d8bfc5c5d170 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dea=20Mar=C3=ADa=20L=C3=A9on?= Date: Sun, 26 May 2024 11:19:24 +0200 Subject: [PATCH 2/2] correct test --- tests/test_series.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_series.py b/tests/test_series.py index fa9d629ab..d54aabbda 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -416,3 +416,19 @@ def test_is_sorted_invalid(df_raw: Any) -> None: with pytest.raises(TypeError): series.is_sorted(descending="invalid_type") # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("df_raw", "mask", "expected"), + [ + (df_pandas, pd.Series([True, False, True]), pd.Series([1, 4, 2])), + (df_polars, pl.Series([True, False, True]), pl.Series([1, 4, 2])), + ], +) +def test_zip_with(df_raw: Any, mask: Any, expected: Any) -> None: + series1 = nw.Series(df_raw["a"]) + series2 = nw.Series(df_raw["b"]) + mask = nw.Series(mask) + result = series1.zip_with(mask, series2) + expected = nw.Series(expected) + assert result == expected