Gate `replace_strict` on Polars >= 1.0 and thread `backend_version` through
the Polars expression wrappers.

* narwhals/_arrow/series.py: import `pyarrow.compute` as `pc` and call
  `index_in` / `take` through that alias instead of through `pa.compute`.
* narwhals/_polars/expr.py, narwhals/_polars/namespace.py: `PolarsExpr` and
  `PolarsSelectors` take and store `backend_version`; every construction
  site in these hunks passes it along.
* narwhals/_polars/expr.py, narwhals/_polars/series.py: `replace_strict`
  raises `NotImplementedError` when `self._backend_version < (1,)`.
* tests/expr_and_series/replace_strict_test.py: skip the three tests when
  `POLARS_VERSION < (1, 0)`; the lazy-Polars branch expects `PolarsError` only.

Review note: the `pc.take` change was added to this patch during review so
that both compute calls use the same alias. The post-image blob id on the
first `index` line predates that change.

diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py
index bf0ca35eb..4d08886bd 100644
--- a/narwhals/_arrow/series.py
+++ b/narwhals/_arrow/series.py
@@ -668,9 +668,10 @@ def replace_strict(
         self, mapping: Mapping[Any, Any], *, return_dtype: DType
     ) -> ArrowSeries:
         import pyarrow as pa  # ignore-banned-import
+        import pyarrow.compute as pc  # ignore-banned-import
 
         # https://stackoverflow.com/a/79111029/4451315
-        idxs = pa.compute.index_in(self._native_series, pa.array(list(mapping.keys())))
-        result_native = pa.compute.take(pa.array(list(mapping.values())), idxs).cast(
+        idxs = pc.index_in(self._native_series, pa.array(list(mapping.keys())))
+        result_native = pc.take(pa.array(list(mapping.values())), idxs).cast(
             narwhals_to_native_dtype(return_dtype, self._dtypes)
         )
diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py
index c443232f6..bf0082278 100644
--- a/narwhals/_polars/expr.py
+++ b/narwhals/_polars/expr.py
@@ -17,16 +17,21 @@
 
 
 class PolarsExpr:
-    def __init__(self, expr: Any, dtypes: DTypes) -> None:
+    def __init__(
+        self, expr: Any, dtypes: DTypes, backend_version: tuple[int, ...]
+    ) -> None:
         self._native_expr = expr
         self._implementation = Implementation.POLARS
         self._dtypes = dtypes
+        self._backend_version = backend_version
 
     def __repr__(self) -> str:  # pragma: no cover
         return "PolarsExpr"
 
     def _from_native_expr(self, expr: Any) -> Self:
-        return self.__class__(expr, dtypes=self._dtypes)
+        return self.__class__(
+            expr, dtypes=self._dtypes, backend_version=self._backend_version
+        )
 
     def __getattr__(self, attr: str) -> Any:
         def func(*args: Any, **kwargs: Any) -> Any:
@@ -45,6 +50,9 @@ def cast(self, dtype: DType) -> Self:
     def replace_strict(self, mapping: Mapping[Any, Any], *, return_dtype: DType) -> Self:
         expr = self._native_expr
         return_dtype = narwhals_to_native_dtype(return_dtype, self._dtypes)
+        if self._backend_version < (1,):
+            msg = f"`replace_strict` is only available in Polars>=1.0, found version {self._backend_version}"
+            raise NotImplementedError(msg)
         return self._from_native_expr(
             expr.replace_strict(mapping, return_dtype=return_dtype)
         )
diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py
index 4eb8451b7..4b28c50e4 100644
--- a/narwhals/_polars/namespace.py
+++ b/narwhals/_polars/namespace.py
@@ -33,7 +33,11 @@ def __getattr__(self, attr: str) -> Any:
 
         def func(*args: Any, **kwargs: Any) -> Any:
             args, kwargs = extract_args_kwargs(args, kwargs)  # type: ignore[assignment]
-            return PolarsExpr(getattr(pl, attr)(*args, **kwargs), dtypes=self._dtypes)
+            return PolarsExpr(
+                getattr(pl, attr)(*args, **kwargs),
+                dtypes=self._dtypes,
+                backend_version=self._backend_version,
+            )
 
         return func
 
@@ -45,7 +49,9 @@ def nth(self, *indices: int) -> PolarsExpr:
         if self._backend_version < (1, 0, 0):  # pragma: no cover
             msg = "`nth` is only supported for Polars>=1.0.0. Please use `col` for columns selection instead."
             raise AttributeError(msg)
-        return PolarsExpr(pl.nth(*indices), dtypes=self._dtypes)
+        return PolarsExpr(
+            pl.nth(*indices), dtypes=self._dtypes, backend_version=self._backend_version
+        )
 
     def len(self) -> PolarsExpr:
         import polars as pl  # ignore-banned-import()
@@ -53,8 +59,14 @@ def len(self) -> PolarsExpr:
         from narwhals._polars.expr import PolarsExpr
 
         if self._backend_version < (0, 20, 5):  # pragma: no cover
-            return PolarsExpr(pl.count().alias("len"), dtypes=self._dtypes)
-        return PolarsExpr(pl.len(), dtypes=self._dtypes)
+            return PolarsExpr(
+                pl.count().alias("len"),
+                dtypes=self._dtypes,
+                backend_version=self._backend_version,
+            )
+        return PolarsExpr(
+            pl.len(), dtypes=self._dtypes, backend_version=self._backend_version
+        )
 
     def concat(
         self,
@@ -86,8 +98,11 @@ def lit(self, value: Any, dtype: DType | None = None) -> PolarsExpr:
             return PolarsExpr(
                 pl.lit(value, dtype=narwhals_to_native_dtype(dtype, self._dtypes)),
                 dtypes=self._dtypes,
+                backend_version=self._backend_version,
             )
-        return PolarsExpr(pl.lit(value), dtypes=self._dtypes)
+        return PolarsExpr(
+            pl.lit(value), dtypes=self._dtypes, backend_version=self._backend_version
+        )
 
     def mean(self, *column_names: str) -> PolarsExpr:
         import polars as pl  # ignore-banned-import()
@@ -95,8 +110,16 @@ def mean(self, *column_names: str) -> PolarsExpr:
         from narwhals._polars.expr import PolarsExpr
 
         if self._backend_version < (0, 20, 4):  # pragma: no cover
-            return PolarsExpr(pl.mean([*column_names]), dtypes=self._dtypes)  # type: ignore[arg-type]
-        return PolarsExpr(pl.mean(*column_names), dtypes=self._dtypes)
+            return PolarsExpr(
+                pl.mean([*column_names]),  # type: ignore[arg-type]
+                dtypes=self._dtypes,
+                backend_version=self._backend_version,
+            )
+        return PolarsExpr(
+            pl.mean(*column_names),
+            dtypes=self._dtypes,
+            backend_version=self._backend_version,
+        )
 
     def mean_horizontal(self, *exprs: IntoPolarsExpr) -> PolarsExpr:
         import polars as pl  # ignore-banned-import()
@@ -110,11 +133,13 @@ def mean_horizontal(self, *exprs: IntoPolarsExpr) -> PolarsExpr:
                 pl.sum_horizontal(e._native_expr for e in polars_exprs)
                 / pl.sum_horizontal(1 - e.is_null()._native_expr for e in polars_exprs),
                 dtypes=self._dtypes,
+                backend_version=self._backend_version,
             )
 
         return PolarsExpr(
             pl.mean_horizontal(e._native_expr for e in polars_exprs),
             dtypes=self._dtypes,
+            backend_version=self._backend_version,
         )
 
     def concat_str(
@@ -163,8 +188,7 @@ def concat_str(
                 )
 
             return PolarsExpr(
-                result,
-                dtypes=self._dtypes,
+                result, dtypes=self._dtypes, backend_version=self._backend_version
             )
 
         return PolarsExpr(
@@ -174,16 +198,18 @@ def concat_str(
                 ignore_nulls=ignore_nulls,
             ),
             dtypes=self._dtypes,
+            backend_version=self._backend_version,
         )
 
     @property
     def selectors(self) -> PolarsSelectors:
-        return PolarsSelectors(self._dtypes)
+        return PolarsSelectors(self._dtypes, backend_version=self._backend_version)
 
 
 class PolarsSelectors:
-    def __init__(self, dtypes: DTypes) -> None:
+    def __init__(self, dtypes: DTypes, backend_version: tuple[int, ...]) -> None:
         self._dtypes = dtypes
+        self._backend_version = backend_version
 
     def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr:
         import polars as pl  # ignore-banned-import()
@@ -195,6 +221,7 @@ def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr:
                 [narwhals_to_native_dtype(dtype, self._dtypes) for dtype in dtypes]
             ),
             dtypes=self._dtypes,
+            backend_version=self._backend_version,
         )
 
     def numeric(self) -> PolarsExpr:
@@ -202,32 +229,50 @@ def numeric(self) -> PolarsExpr:
 
         from narwhals._polars.expr import PolarsExpr
 
-        return PolarsExpr(pl.selectors.numeric(), dtypes=self._dtypes)
+        return PolarsExpr(
+            pl.selectors.numeric(),
+            dtypes=self._dtypes,
+            backend_version=self._backend_version,
+        )
 
     def boolean(self) -> PolarsExpr:
         import polars as pl  # ignore-banned-import()
 
         from narwhals._polars.expr import PolarsExpr
 
-        return PolarsExpr(pl.selectors.boolean(), dtypes=self._dtypes)
+        return PolarsExpr(
+            pl.selectors.boolean(),
+            dtypes=self._dtypes,
+            backend_version=self._backend_version,
+        )
 
     def string(self) -> PolarsExpr:
         import polars as pl  # ignore-banned-import()
 
         from narwhals._polars.expr import PolarsExpr
 
-        return PolarsExpr(pl.selectors.string(), dtypes=self._dtypes)
+        return PolarsExpr(
+            pl.selectors.string(),
+            dtypes=self._dtypes,
+            backend_version=self._backend_version,
+        )
 
     def categorical(self) -> PolarsExpr:
         import polars as pl  # ignore-banned-import()
 
         from narwhals._polars.expr import PolarsExpr
 
-        return PolarsExpr(pl.selectors.categorical(), dtypes=self._dtypes)
+        return PolarsExpr(
+            pl.selectors.categorical(),
+            dtypes=self._dtypes,
+            backend_version=self._backend_version,
+        )
 
     def all(self) -> PolarsExpr:
         import polars as pl  # ignore-banned-import()
 
         from narwhals._polars.expr import PolarsExpr
 
-        return PolarsExpr(pl.selectors.all(), dtypes=self._dtypes)
+        return PolarsExpr(
+            pl.selectors.all(), dtypes=self._dtypes, backend_version=self._backend_version
+        )
diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py
index 74dbdbc39..fc0efae11 100644
--- a/narwhals/_polars/series.py
+++ b/narwhals/_polars/series.py
@@ -109,6 +109,9 @@ def cast(self, dtype: DType) -> Self:
     def replace_strict(self, mapping: Mapping[Any, Any], *, return_dtype: DType) -> Self:
         ser = self._native_series
         dtype = narwhals_to_native_dtype(return_dtype, self._dtypes)
+        if self._backend_version < (1,):
+            msg = f"`replace_strict` is only available in Polars>=1.0, found version {self._backend_version}"
+            raise NotImplementedError(msg)
         return self._from_native_series(ser.replace_strict(mapping, return_dtype=dtype))
 
     def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray:
diff --git a/tests/expr_and_series/replace_strict_test.py b/tests/expr_and_series/replace_strict_test.py
index 1a88d290a..47b4a5f95 100644
--- a/tests/expr_and_series/replace_strict_test.py
+++ b/tests/expr_and_series/replace_strict_test.py
@@ -4,11 +4,15 @@
 from polars.exceptions import PolarsError
 
 import narwhals.stable.v1 as nw
+from tests.utils import POLARS_VERSION
 from tests.utils import Constructor
 from tests.utils import ConstructorEager
 from tests.utils import assert_equal_data
 
 
+@pytest.mark.skipif(
+    POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
+)
 def test_replace_strict(constructor: Constructor) -> None:
     df = nw.from_native(constructor({"a": [1, 2, 3]}))
     result = df.select(
@@ -19,6 +23,9 @@ def test_replace_strict(constructor: Constructor) -> None:
     assert_equal_data(result, {"a": ["one", "two", "three"]})
 
 
+@pytest.mark.skipif(
+    POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
+)
 def test_replace_strict_series(constructor_eager: ConstructorEager) -> None:
     df = nw.from_native(constructor_eager({"a": [1, 2, 3]}))
     result = df.select(
@@ -27,6 +34,9 @@ def test_replace_strict_series(constructor_eager: ConstructorEager) -> None:
     assert_equal_data(result, {"a": ["one", "two", "three"]})
 
 
+@pytest.mark.skipif(
+    POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
+)
 def test_replace_with_default(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
@@ -34,7 +44,7 @@ def test_replace_with_default(
         request.applymarker(pytest.mark.xfail)
     df = nw.from_native(constructor({"a": [1, 2, 3]}))
     if "polars_lazy" in str(constructor):
-        with pytest.raises((ValueError, PolarsError)):
+        with pytest.raises(PolarsError):
             df.lazy().select(
                 nw.col("a").replace_strict({1: 3, 3: 4}, return_dtype=nw.Int64)
             ).collect()