Skip to content

Commit

Permalink
old polars versions
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Nov 6, 2024
1 parent c054c2f commit 6c34615
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 20 deletions.
3 changes: 2 additions & 1 deletion narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,9 +668,10 @@ def replace_strict(
self, mapping: Mapping[Any, Any], *, return_dtype: DType
) -> ArrowSeries:
import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

# https://stackoverflow.com/a/79111029/4451315
idxs = pa.compute.index_in(self._native_series, pa.array(list(mapping.keys())))
idxs = pc.index_in(self._native_series, pa.array(list(mapping.keys())))
result_native = pa.compute.take(pa.array(list(mapping.values())), idxs).cast(
narwhals_to_native_dtype(return_dtype, self._dtypes)
)
Expand Down
12 changes: 10 additions & 2 deletions narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,21 @@


class PolarsExpr:
def __init__(self, expr: Any, dtypes: DTypes) -> None:
def __init__(
self, expr: Any, dtypes: DTypes, backend_version: tuple[int, ...]
) -> None:
self._native_expr = expr
self._implementation = Implementation.POLARS
self._dtypes = dtypes
self._backend_version = backend_version

def __repr__(self) -> str: # pragma: no cover
return "PolarsExpr"

def _from_native_expr(self, expr: Any) -> Self:
return self.__class__(expr, dtypes=self._dtypes)
return self.__class__(
expr, dtypes=self._dtypes, backend_version=self._backend_version
)

def __getattr__(self, attr: str) -> Any:
def func(*args: Any, **kwargs: Any) -> Any:
Expand All @@ -45,6 +50,9 @@ def cast(self, dtype: DType) -> Self:
def replace_strict(self, mapping: Mapping[Any, Any], *, return_dtype: DType) -> Self:
expr = self._native_expr
return_dtype = narwhals_to_native_dtype(return_dtype, self._dtypes)
if self._backend_version < (1,):
msg = f"`replace_strict` is only available in Polars>=1.0, found version {self._backend_version}"
raise NotImplementedError(msg)
return self._from_native_expr(
expr.replace_strict(mapping, return_dtype=return_dtype)
)
Expand Down
77 changes: 61 additions & 16 deletions narwhals/_polars/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ def __getattr__(self, attr: str) -> Any:

def func(*args: Any, **kwargs: Any) -> Any:
args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment]
return PolarsExpr(getattr(pl, attr)(*args, **kwargs), dtypes=self._dtypes)
return PolarsExpr(
getattr(pl, attr)(*args, **kwargs),
dtypes=self._dtypes,
backend_version=self._backend_version,
)

return func

Expand All @@ -45,16 +49,24 @@ def nth(self, *indices: int) -> PolarsExpr:
if self._backend_version < (1, 0, 0): # pragma: no cover
msg = "`nth` is only supported for Polars>=1.0.0. Please use `col` for columns selection instead."
raise AttributeError(msg)
return PolarsExpr(pl.nth(*indices), dtypes=self._dtypes)
return PolarsExpr(
pl.nth(*indices), dtypes=self._dtypes, backend_version=self._backend_version
)

def len(self) -> PolarsExpr:
import polars as pl # ignore-banned-import()

from narwhals._polars.expr import PolarsExpr

if self._backend_version < (0, 20, 5): # pragma: no cover
return PolarsExpr(pl.count().alias("len"), dtypes=self._dtypes)
return PolarsExpr(pl.len(), dtypes=self._dtypes)
return PolarsExpr(
pl.count().alias("len"),
dtypes=self._dtypes,
backend_version=self._backend_version,
)
return PolarsExpr(
pl.len(), dtypes=self._dtypes, backend_version=self._backend_version
)

def concat(
self,
Expand Down Expand Up @@ -86,17 +98,28 @@ def lit(self, value: Any, dtype: DType | None = None) -> PolarsExpr:
return PolarsExpr(
pl.lit(value, dtype=narwhals_to_native_dtype(dtype, self._dtypes)),
dtypes=self._dtypes,
backend_version=self._backend_version,
)
return PolarsExpr(pl.lit(value), dtypes=self._dtypes)
return PolarsExpr(
pl.lit(value), dtypes=self._dtypes, backend_version=self._backend_version
)

def mean(self, *column_names: str) -> PolarsExpr:
import polars as pl # ignore-banned-import()

from narwhals._polars.expr import PolarsExpr

if self._backend_version < (0, 20, 4): # pragma: no cover
return PolarsExpr(pl.mean([*column_names]), dtypes=self._dtypes) # type: ignore[arg-type]
return PolarsExpr(pl.mean(*column_names), dtypes=self._dtypes)
return PolarsExpr(
pl.mean([*column_names]), # type: ignore[arg-type]
dtypes=self._dtypes,
backend_version=self._backend_version,
)
return PolarsExpr(
pl.mean(*column_names),
dtypes=self._dtypes,
backend_version=self._backend_version,
)

def mean_horizontal(self, *exprs: IntoPolarsExpr) -> PolarsExpr:
import polars as pl # ignore-banned-import()
Expand All @@ -110,11 +133,13 @@ def mean_horizontal(self, *exprs: IntoPolarsExpr) -> PolarsExpr:
pl.sum_horizontal(e._native_expr for e in polars_exprs)
/ pl.sum_horizontal(1 - e.is_null()._native_expr for e in polars_exprs),
dtypes=self._dtypes,
backend_version=self._backend_version,
)

return PolarsExpr(
pl.mean_horizontal(e._native_expr for e in polars_exprs),
dtypes=self._dtypes,
backend_version=self._backend_version,
)

def concat_str(
Expand Down Expand Up @@ -163,8 +188,7 @@ def concat_str(
)

return PolarsExpr(
result,
dtypes=self._dtypes,
result, dtypes=self._dtypes, backend_version=self._backend_version
)

return PolarsExpr(
Expand All @@ -174,16 +198,18 @@ def concat_str(
ignore_nulls=ignore_nulls,
),
dtypes=self._dtypes,
backend_version=self._backend_version,
)

@property
def selectors(self) -> PolarsSelectors:
return PolarsSelectors(self._dtypes)
return PolarsSelectors(self._dtypes, backend_version=self._backend_version)


class PolarsSelectors:
def __init__(self, dtypes: DTypes) -> None:
def __init__(self, dtypes: DTypes, backend_version: tuple[int, ...]) -> None:
self._dtypes = dtypes
self._backend_version = backend_version

def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr:
import polars as pl # ignore-banned-import()
Expand All @@ -195,39 +221,58 @@ def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr:
[narwhals_to_native_dtype(dtype, self._dtypes) for dtype in dtypes]
),
dtypes=self._dtypes,
backend_version=self._backend_version,
)

def numeric(self) -> PolarsExpr:
import polars as pl # ignore-banned-import()

from narwhals._polars.expr import PolarsExpr

return PolarsExpr(pl.selectors.numeric(), dtypes=self._dtypes)
return PolarsExpr(
pl.selectors.numeric(),
dtypes=self._dtypes,
backend_version=self._backend_version,
)

def boolean(self) -> PolarsExpr:
import polars as pl # ignore-banned-import()

from narwhals._polars.expr import PolarsExpr

return PolarsExpr(pl.selectors.boolean(), dtypes=self._dtypes)
return PolarsExpr(
pl.selectors.boolean(),
dtypes=self._dtypes,
backend_version=self._backend_version,
)

def string(self) -> PolarsExpr:
import polars as pl # ignore-banned-import()

from narwhals._polars.expr import PolarsExpr

return PolarsExpr(pl.selectors.string(), dtypes=self._dtypes)
return PolarsExpr(
pl.selectors.string(),
dtypes=self._dtypes,
backend_version=self._backend_version,
)

def categorical(self) -> PolarsExpr:
import polars as pl # ignore-banned-import()

from narwhals._polars.expr import PolarsExpr

return PolarsExpr(pl.selectors.categorical(), dtypes=self._dtypes)
return PolarsExpr(
pl.selectors.categorical(),
dtypes=self._dtypes,
backend_version=self._backend_version,
)

def all(self) -> PolarsExpr:
import polars as pl # ignore-banned-import()

from narwhals._polars.expr import PolarsExpr

return PolarsExpr(pl.selectors.all(), dtypes=self._dtypes)
return PolarsExpr(
pl.selectors.all(), dtypes=self._dtypes, backend_version=self._backend_version
)
3 changes: 3 additions & 0 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ def cast(self, dtype: DType) -> Self:
def replace_strict(self, mapping: Mapping[Any, Any], *, return_dtype: DType) -> Self:
ser = self._native_series
dtype = narwhals_to_native_dtype(return_dtype, self._dtypes)
if self._backend_version < (1,):
msg = f"`replace_strict` is only available in Polars>=1.0, found version {self._backend_version}"
raise NotImplementedError(msg)
return self._from_native_series(ser.replace_strict(mapping, return_dtype=dtype))

def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray:
Expand Down
12 changes: 11 additions & 1 deletion tests/expr_and_series/replace_strict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
from polars.exceptions import PolarsError

import narwhals.stable.v1 as nw
from tests.utils import POLARS_VERSION
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data


@pytest.mark.skipif(
POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
)
def test_replace_strict(constructor: Constructor) -> None:
df = nw.from_native(constructor({"a": [1, 2, 3]}))
result = df.select(
Expand All @@ -19,6 +23,9 @@ def test_replace_strict(constructor: Constructor) -> None:
assert_equal_data(result, {"a": ["one", "two", "three"]})


@pytest.mark.skipif(
POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
)
def test_replace_strict_series(constructor_eager: ConstructorEager) -> None:
df = nw.from_native(constructor_eager({"a": [1, 2, 3]}))
result = df.select(
Expand All @@ -27,14 +34,17 @@ def test_replace_strict_series(constructor_eager: ConstructorEager) -> None:
assert_equal_data(result, {"a": ["one", "two", "three"]})


@pytest.mark.skipif(
POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
)
def test_replace_with_default(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor({"a": [1, 2, 3]}))
if "polars_lazy" in str(constructor):
with pytest.raises((ValueError, PolarsError)):
with pytest.raises(PolarsError):
df.lazy().select(
nw.col("a").replace_strict({1: 3, 3: 4}, return_dtype=nw.Int64)
).collect()
Expand Down

0 comments on commit 6c34615

Please sign in to comment.