Skip to content

Commit

Permalink
feat: allow mapping in replace_strict method (#1340)
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Nov 9, 2024
1 parent 7ecd092 commit cc84860
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 36 deletions.
2 changes: 1 addition & 1 deletion narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def unique(self: Self, *, maintain_order: bool = False) -> Self:
return reuse_series_implementation(self, "unique", maintain_order=maintain_order)

def replace_strict(
self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None
) -> Self:
return reuse_series_implementation(
self, "replace_strict", old, new, return_dtype=return_dtype
Expand Down
8 changes: 4 additions & 4 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,16 +656,16 @@ def unique(self: Self, *, maintain_order: bool = False) -> ArrowSeries:
return self._from_native_series(pc.unique(self._native_series))

def replace_strict(
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None
) -> ArrowSeries:
import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

# https://stackoverflow.com/a/79111029/4451315
idxs = pc.index_in(self._native_series, pa.array(old))
result_native = pc.take(pa.array(new), idxs).cast(
narwhals_to_native_dtype(return_dtype, self._dtypes)
)
result_native = pc.take(pa.array(new), idxs)
if return_dtype is not None:
result_native.cast(narwhals_to_native_dtype(return_dtype, self._dtypes))
result = self._from_native_series(result_native)
if result.is_null().sum() != self.is_null().sum():
msg = (
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def head(self) -> NoReturn:
raise NotImplementedError(msg)

def replace_strict(
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None
) -> Self:
msg = "`replace_strict` is not yet supported for Dask expressions"
raise NotImplementedError(msg)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def drop_nulls(self) -> Self:
return reuse_series_implementation(self, "drop_nulls")

def replace_strict(
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None
) -> Self:
return reuse_series_implementation(
self, "replace_strict", old, new, return_dtype=return_dtype
Expand Down
18 changes: 11 additions & 7 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,15 +492,19 @@ def shift(self, n: int) -> PandasLikeSeries:
return self._from_native_series(self._native_series.shift(n))

def replace_strict(
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None
) -> PandasLikeSeries:
tmp_name = f"{self.name}_tmp"
dtype = narwhals_to_native_dtype(
return_dtype,
self._native_series.dtype,
self._implementation,
self._backend_version,
self._dtypes,
dtype = (
narwhals_to_native_dtype(
return_dtype,
self._native_series.dtype,
self._implementation,
self._backend_version,
self._dtypes,
)
if return_dtype
else None
)
other = self.__native_namespace__().DataFrame(
{
Expand Down
6 changes: 4 additions & 2 deletions narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ def cast(self, dtype: DType) -> Self:
return self._from_native_expr(expr.cast(dtype))

def replace_strict(
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None
) -> Self:
expr = self._native_expr
return_dtype = narwhals_to_native_dtype(return_dtype, self._dtypes)
return_dtype = (
narwhals_to_native_dtype(return_dtype, self._dtypes) if return_dtype else None
)
if self._backend_version < (1,):
msg = f"`replace_strict` is only available in Polars>=1.0, found version {self._backend_version}"
raise NotImplementedError(msg)
Expand Down
6 changes: 4 additions & 2 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,12 @@ def cast(self, dtype: DType) -> Self:
return self._from_native_series(ser.cast(dtype))

def replace_strict(
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None
) -> Self:
ser = self._native_series
dtype = narwhals_to_native_dtype(return_dtype, self._dtypes)
dtype = (
narwhals_to_native_dtype(return_dtype, self._dtypes) if return_dtype else None
)
if self._backend_version < (1,):
msg = f"`replace_strict` is only available in Polars>=1.0, found version {self._backend_version}"
raise NotImplementedError(msg)
Expand Down
30 changes: 23 additions & 7 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Generic
from typing import Iterable
from typing import Literal
from typing import Mapping
from typing import Sequence
from typing import TypeVar

Expand Down Expand Up @@ -975,18 +976,25 @@ def shift(self, n: int) -> Self:
return self.__class__(lambda plx: self._call(plx).shift(n))

def replace_strict(
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | type[DType]
self,
old: Sequence[Any] | Mapping[Any, Any],
new: Sequence[Any] | None = None,
*,
return_dtype: DType | type[DType] | None = None,
) -> Self:
"""
Replace old values with new values.
Replace all values by different values.
This function must replace all non-null input values (else it raises an error),
and the return dtype must be specified.
This function must replace all non-null input values (else it raises an error).
Arguments:
old: Sequence of old values to replace.
new: Sequence of new values to replace.
return_dtype: Return dtype.
old: Sequence of values to replace. It also accepts a mapping of values to
their replacement as syntactic sugar for
`replace_strict(old=list(mapping.keys()), new=list(mapping.values()))`.
new: Sequence of values to replace by. Length must match the length of `old`.
return_dtype: The data type of the resulting expression. If set to `None`
(default), the data type is determined automatically based on the other
inputs.
Examples:
>>> import narwhals as nw
Expand Down Expand Up @@ -1037,6 +1045,14 @@ def replace_strict(
a: [[3,0,1,2]]
b: [["three","zero","one","two"]]
"""
if new is None:
if not isinstance(old, Mapping):
msg = "`new` argument is required if `old` argument is not a Mapping type"
raise TypeError(msg)

new = list(old.values())
old = list(old.keys())

return self.__class__(
lambda plx: self._call(plx).replace_strict(
old, new, return_dtype=return_dtype
Expand Down
30 changes: 23 additions & 7 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Generic
from typing import Iterator
from typing import Literal
from typing import Mapping
from typing import Sequence
from typing import TypeVar
from typing import overload
Expand Down Expand Up @@ -1350,18 +1351,25 @@ def rename(self, name: str) -> Self:
return self.alias(name=name)

def replace_strict(
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | type[DType]
self: Self,
old: Sequence[Any] | Mapping[Any, Any],
new: Sequence[Any] | None = None,
*,
return_dtype: DType | type[DType] | None = None,
) -> Self:
"""
Replace old values with values.
Replace all values by different values.
This function must replace all non-null input values (else it raises an error),
and the return dtype must be specified.
This function must replace all non-null input values (else it raises an error).
Arguments:
old: Sequence of old values to replace.
new: Sequence of new values to replace.
return_dtype: Return dtype.
old: Sequence of values to replace. It also accepts a mapping of values to
their replacement as syntactic sugar for
`replace_strict(old=list(mapping.keys()), new=list(mapping.values()))`.
new: Sequence of values to replace by. Length must match the length of `old`.
return_dtype: The data type of the resulting expression. If set to `None`
(default), the data type is determined automatically based on the other
inputs.
Examples:
>>> import narwhals as nw
Expand Down Expand Up @@ -1408,6 +1416,14 @@ def replace_strict(
]
]
"""
if new is None:
if not isinstance(old, Mapping):
msg = "`new` argument is required if `old` argument is not a Mapping type"
raise TypeError(msg)

new = list(old.values())
old = list(old.keys())

return self._from_compliant_series(
self._compliant_series.replace_strict(old, new, return_dtype=return_dtype)
)
Expand Down
74 changes: 70 additions & 4 deletions tests/expr_and_series/replace_strict_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import narwhals.stable.v1 as nw
Expand All @@ -8,17 +10,23 @@
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data

if TYPE_CHECKING:
from narwhals.dtypes import DType


@pytest.mark.skipif(
POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
)
def test_replace_strict(constructor: Constructor, request: pytest.FixtureRequest) -> None:
@pytest.mark.parametrize("return_dtype", [nw.String(), None])
def test_replace_strict(
constructor: Constructor, request: pytest.FixtureRequest, return_dtype: DType | None
) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor({"a": [1, 2, 3]}))
result = df.select(
nw.col("a").replace_strict(
[1, 2, 3], ["one", "two", "three"], return_dtype=nw.String
[1, 2, 3], ["one", "two", "three"], return_dtype=return_dtype
)
)
assert_equal_data(result, {"a": ["one", "two", "three"]})
Expand All @@ -27,10 +35,15 @@ def test_replace_strict(constructor: Constructor, request: pytest.FixtureRequest
@pytest.mark.skipif(
POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
)
def test_replace_strict_series(constructor_eager: ConstructorEager) -> None:
@pytest.mark.parametrize("return_dtype", [nw.String(), None])
def test_replace_strict_series(
constructor_eager: ConstructorEager, return_dtype: DType | None
) -> None:
df = nw.from_native(constructor_eager({"a": [1, 2, 3]}))
result = df.select(
df["a"].replace_strict([1, 2, 3], ["one", "two", "three"], return_dtype=nw.String)
df["a"].replace_strict(
[1, 2, 3], ["one", "two", "three"], return_dtype=return_dtype
)
)
assert_equal_data(result, {"a": ["one", "two", "three"]})

Expand All @@ -54,3 +67,56 @@ def test_replace_non_full(
else:
with pytest.raises((ValueError, PolarsError)):
df.select(nw.col("a").replace_strict([1, 3], [3, 4], return_dtype=nw.Int64))


@pytest.mark.skipif(
    POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
)
def test_replace_strict_mapping(
    constructor: Constructor, request: pytest.FixtureRequest
) -> None:
    """Expression-level `replace_strict` accepts a mapping in place of `old`/`new`."""
    # Dask constructors are marked as expected failures for this test.
    uses_dask = "dask" in str(constructor)
    if uses_dask:
        request.applymarker(pytest.mark.xfail)

    frame = nw.from_native(constructor({"a": [1, 2, 3]}))
    replacements = {1: "one", 2: "two", 3: "three"}
    replaced = nw.col("a").replace_strict(replacements, return_dtype=nw.String())
    result = frame.select(replaced)

    expected = {"a": ["one", "two", "three"]}
    assert_equal_data(result, expected)


@pytest.mark.skipif(
    POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
)
def test_replace_strict_series_mapping(constructor_eager: ConstructorEager) -> None:
    """Series-level `replace_strict` accepts a mapping in place of `old`/`new`."""
    frame = nw.from_native(constructor_eager({"a": [1, 2, 3]}))
    replacements = {1: "one", 2: "two", 3: "three"}
    replaced = frame["a"].replace_strict(replacements, return_dtype=nw.String())
    result = frame.select(replaced)

    expected = {"a": ["one", "two", "three"]}
    assert_equal_data(result, expected)


@pytest.mark.skipif(
    POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
)
def test_replace_strict_invalid(constructor: Constructor) -> None:
    """Omitting `new` with a non-mapping `old` raises `TypeError` for expressions."""
    expected_msg = (
        "`new` argument is required if `old` argument is not a Mapping type"
    )
    frame = nw.from_native(constructor({"a": [1, 2, 3]}))
    with pytest.raises(TypeError, match=expected_msg):
        frame.select(nw.col("a").replace_strict(old=[1, 2, 3]))


@pytest.mark.skipif(
    POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
)
def test_replace_strict_series_invalid(constructor_eager: ConstructorEager) -> None:
    """Omitting `new` with a non-mapping `old` raises `TypeError` for series."""
    expected_msg = (
        "`new` argument is required if `old` argument is not a Mapping type"
    )
    frame = nw.from_native(constructor_eager({"a": [1, 2, 3]}))
    with pytest.raises(TypeError, match=expected_msg):
        frame["a"].replace_strict([1, 2, 3])

0 comments on commit cc84860

Please sign in to comment.