diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 592fe615d..fd0863d17 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -322,7 +322,7 @@ def unique(self: Self, *, maintain_order: bool = False) -> Self: return reuse_series_implementation(self, "unique", maintain_order=maintain_order) def replace_strict( - self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType + self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None ) -> Self: return reuse_series_implementation( self, "replace_strict", old, new, return_dtype=return_dtype diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 15f64110a..a43e5cbcc 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -656,16 +656,16 @@ def unique(self: Self, *, maintain_order: bool = False) -> ArrowSeries: return self._from_native_series(pc.unique(self._native_series)) def replace_strict( - self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType + self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None ) -> ArrowSeries: import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import # https://stackoverflow.com/a/79111029/4451315 idxs = pc.index_in(self._native_series, pa.array(old)) - result_native = pc.take(pa.array(new), idxs).cast( - narwhals_to_native_dtype(return_dtype, self._dtypes) - ) + result_native = pc.take(pa.array(new), idxs) + if return_dtype is not None: + result_native.cast(narwhals_to_native_dtype(return_dtype, self._dtypes)) result = self._from_native_series(result_native) if result.is_null().sum() != self.is_null().sum(): msg = ( diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 3a5e69af3..1dfa47b27 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -479,7 +479,7 @@ def head(self) -> NoReturn: raise NotImplementedError(msg) def replace_strict( - self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType + self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None ) -> Self: msg = "`replace_strict` is not yet supported for Dask expressions" raise NotImplementedError(msg) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index a3c3aae07..edbe9a488 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -274,7 +274,7 @@ def drop_nulls(self) -> Self: return reuse_series_implementation(self, "drop_nulls") def replace_strict( - self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType + self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None ) -> Self: return reuse_series_implementation( self, "replace_strict", old, new, return_dtype=return_dtype diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index c7e7d67b9..558e8961e 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -492,15 +492,19 @@ def shift(self, n: int) -> PandasLikeSeries: return self._from_native_series(self._native_series.shift(n)) def replace_strict( - self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType + self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None ) -> PandasLikeSeries: tmp_name = f"{self.name}_tmp" - dtype = narwhals_to_native_dtype( - return_dtype, - self._native_series.dtype, - self._implementation, - self._backend_version, - self._dtypes, + dtype = ( + narwhals_to_native_dtype( + return_dtype, + self._native_series.dtype, + self._implementation, + self._backend_version, + self._dtypes, + ) + if return_dtype + else None ) other = self.__native_namespace__().DataFrame( { diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index fb5319430..46b415ddc 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -48,10 +48,12 @@ def cast(self, dtype: DType) -> Self: return self._from_native_expr(expr.cast(dtype)) def replace_strict( - self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType + self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None ) -> Self: expr = self._native_expr - return_dtype = narwhals_to_native_dtype(return_dtype, self._dtypes) + return_dtype = ( + narwhals_to_native_dtype(return_dtype, self._dtypes) if return_dtype else None + ) if self._backend_version < (1,): msg = f"`replace_strict` is only available in Polars>=1.0, found version {self._backend_version}" raise NotImplementedError(msg) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 2abe02577..db33bcf82 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -106,10 +106,12 @@ def cast(self, dtype: DType) -> Self: return self._from_native_series(ser.cast(dtype)) def replace_strict( - self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType + self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None ) -> Self: ser = self._native_series - dtype = narwhals_to_native_dtype(return_dtype, self._dtypes) + dtype = ( + narwhals_to_native_dtype(return_dtype, self._dtypes) if return_dtype else None + ) if self._backend_version < (1,): msg = f"`replace_strict` is only available in Polars>=1.0, found version {self._backend_version}" raise NotImplementedError(msg) diff --git a/narwhals/expr.py b/narwhals/expr.py index 2d43df716..de0546c7c 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -6,6 +6,7 @@ from typing import Generic from typing import Iterable from typing import Literal +from typing import Mapping from typing import Sequence from typing import TypeVar @@ -975,18 +976,25 @@ def shift(self, n: int) -> Self: return self.__class__(lambda plx: self._call(plx).shift(n)) def replace_strict( - self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | type[DType] + self, + old: Sequence[Any] | Mapping[Any, Any], + new: Sequence[Any] | None = None, + *, + return_dtype: DType | type[DType] | None = None, ) -> Self: """ - Replace old values with new values. + Replace all values by different values. - This function must replace all non-null input values (else it raises an error), - and the return dtype must be specified. + This function must replace all non-null input values (else it raises an error). Arguments: - old: Sequence of old values to replace. - new: Sequence of new values to replace. - return_dtype: Return dtype. + old: Sequence of values to replace. It also accepts a mapping of values to + their replacement as syntactic sugar for + `replace_all(old=list(mapping.keys()), new=list(mapping.values()))`. + new: Sequence of values to replace by. Length must match the length of `old`. + return_dtype: The data type of the resulting expression. If set to `None` + (default), the data type is determined automatically based on the other + inputs. Examples: >>> import narwhals as nw @@ -1037,6 +1045,14 @@ def replace_strict( a: [[3,0,1,2]] b: [["three","zero","one","two"]] """ + if new is None: + if not isinstance(old, Mapping): + msg = "`new` argument is required if `old` argument is not a Mapping type" + raise TypeError(msg) + + new = list(old.values()) + old = list(old.keys()) + return self.__class__( lambda plx: self._call(plx).replace_strict( old, new, return_dtype=return_dtype diff --git a/narwhals/series.py b/narwhals/series.py index 7b179cce8..d7845655a 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -6,6 +6,7 @@ from typing import Generic from typing import Iterator from typing import Literal +from typing import Mapping from typing import Sequence from typing import TypeVar from typing import overload @@ -1350,18 +1351,25 @@ def rename(self, name: str) -> Self: return self.alias(name=name) def replace_strict( - self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | type[DType] + self: Self, + old: Sequence[Any] | Mapping[Any, Any], + new: Sequence[Any] | None = None, + *, + return_dtype: DType | type[DType] | None = None, ) -> Self: """ - Replace old values with values. + Replace all values by different values. - This function must replace all non-null input values (else it raises an error), - and the return dtype must be specified. + This function must replace all non-null input values (else it raises an error). Arguments: - old: Sequence of old values to replace. - new: Sequence of new values to replace. - return_dtype: Return dtype. + old: Sequence of values to replace. It also accepts a mapping of values to + their replacement as syntactic sugar for + `replace_all(old=list(mapping.keys()), new=list(mapping.values()))`. + new: Sequence of values to replace by. Length must match the length of `old`. + return_dtype: The data type of the resulting expression. If set to `None` + (default), the data type is determined automatically based on the other + inputs. Examples: >>> import narwhals as nw @@ -1408,6 +1416,14 @@ def replace_strict( ] ] """ + if new is None: + if not isinstance(old, Mapping): + msg = "`new` argument is required if `old` argument is not a Mapping type" + raise TypeError(msg) + + new = list(old.values()) + old = list(old.keys()) + return self._from_compliant_series( self._compliant_series.replace_strict(old, new, return_dtype=return_dtype) ) diff --git a/tests/expr_and_series/replace_strict_test.py b/tests/expr_and_series/replace_strict_test.py index 8362897a1..b1449af24 100644 --- a/tests/expr_and_series/replace_strict_test.py +++ b/tests/expr_and_series/replace_strict_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest import narwhals.stable.v1 as nw @@ -8,17 +10,23 @@ from tests.utils import ConstructorEager from tests.utils import assert_equal_data +if TYPE_CHECKING: + from narwhals.dtypes import DType + @pytest.mark.skipif( POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0" ) -def test_replace_strict(constructor: Constructor, request: pytest.FixtureRequest) -> None: +@pytest.mark.parametrize("return_dtype", [nw.String(), None]) +def test_replace_strict( + constructor: Constructor, request: pytest.FixtureRequest, return_dtype: DType | None +) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 2, 3]})) result = df.select( nw.col("a").replace_strict( - [1, 2, 3], ["one", "two", "three"], return_dtype=nw.String + [1, 2, 3], ["one", "two", "three"], return_dtype=return_dtype ) ) assert_equal_data(result, {"a": ["one", "two", "three"]}) @@ -27,10 +35,15 @@ def test_replace_strict(constructor: Constructor, request: pytest.FixtureRequest @pytest.mark.skipif( POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0" ) -def test_replace_strict_series(constructor_eager: ConstructorEager) -> None: +@pytest.mark.parametrize("return_dtype", [nw.String(), None]) +def test_replace_strict_series( + constructor_eager: ConstructorEager, return_dtype: DType | None +) -> None: df = nw.from_native(constructor_eager({"a": [1, 2, 3]})) result = df.select( - df["a"].replace_strict([1, 2, 3], ["one", "two", "three"], return_dtype=nw.String) + df["a"].replace_strict( + [1, 2, 3], ["one", "two", "three"], return_dtype=return_dtype + ) ) assert_equal_data(result, {"a": ["one", "two", "three"]}) @@ -54,3 +67,56 @@ def test_replace_non_full( else: with pytest.raises((ValueError, PolarsError)): df.select(nw.col("a").replace_strict([1, 3], [3, 4], return_dtype=nw.Int64)) + + +@pytest.mark.skipif( + POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0" +) +def test_replace_strict_mapping( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "dask" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor({"a": [1, 2, 3]})) + result = df.select( + nw.col("a").replace_strict( + {1: "one", 2: "two", 3: "three"}, return_dtype=nw.String() + ) + ) + assert_equal_data(result, {"a": ["one", "two", "three"]}) + + +@pytest.mark.skipif( + POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0" +) +def test_replace_strict_series_mapping(constructor_eager: ConstructorEager) -> None: + df = nw.from_native(constructor_eager({"a": [1, 2, 3]})) + result = df.select( + df["a"].replace_strict({1: "one", 2: "two", 3: "three"}, return_dtype=nw.String()) + ) + assert_equal_data(result, {"a": ["one", "two", "three"]}) + + +@pytest.mark.skipif( + POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0" +) +def test_replace_strict_invalid(constructor: Constructor) -> None: + df = nw.from_native(constructor({"a": [1, 2, 3]})) + with pytest.raises( + TypeError, + match="`new` argument is required if `old` argument is not a Mapping type", + ): + df.select(nw.col("a").replace_strict(old=[1, 2, 3])) + + +@pytest.mark.skipif( + POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0" +) +def test_replace_strict_series_invalid(constructor_eager: ConstructorEager) -> None: + df = nw.from_native(constructor_eager({"a": [1, 2, 3]})) + with pytest.raises( + TypeError, + match="`new` argument is required if `old` argument is not a Mapping type", + ): + df["a"].replace_strict([1, 2, 3])