Skip to content

Commit

Permalink
feat: add Expr.any, Expr.all, Series.any, and Series.all for PyArrow … (
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jul 2, 2024
1 parent bfe21e6 commit 7b9eff6
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 5 deletions.
6 changes: 6 additions & 0 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ def __narwhals_namespace__(self) -> ArrowNamespace:
def cum_sum(self) -> Self:
return reuse_series_implementation(self, "cum_sum") # type: ignore[type-var]

def any(self) -> Self:
return reuse_series_implementation(self, "any", returns_scalar=True) # type: ignore[type-var]

def all(self) -> Self:
return reuse_series_implementation(self, "all", returns_scalar=True) # type: ignore[type-var]

@property
def dt(self) -> ArrowExprDateTimeNamespace:
return ArrowExprDateTimeNamespace(self)
Expand Down
9 changes: 8 additions & 1 deletion narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable

from narwhals import dtypes
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.series import ArrowSeries
from narwhals.utils import flatten

if TYPE_CHECKING:
from typing import Callable

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.series import ArrowSeries


class ArrowNamespace:
Expand Down Expand Up @@ -66,6 +67,12 @@ def _create_expr_from_series(self, series: ArrowSeries) -> ArrowExpr:
output_names=None,
)

def _create_series_from_scalar(self, value: Any, series: ArrowSeries) -> ArrowSeries:
return ArrowSeries.from_iterable(
[value],
name=series.name,
)

# --- not in spec ---
def __init__(self) -> None: ...

Expand Down
19 changes: 19 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable

from narwhals._arrow.utils import translate_dtype
from narwhals._pandas_like.utils import native_series_from_iterable
from narwhals.dependencies import get_pyarrow_compute

if TYPE_CHECKING:
Expand All @@ -29,6 +31,15 @@ def _from_series(self, series: Any) -> Self:
name=self._name,
)

@classmethod
def from_iterable(cls: type[Self], data: Iterable[Any], name: str) -> Self:
return cls(
native_series_from_iterable(
data, name=name, index=None, implementation="arrow"
),
name=name,
)

def __len__(self) -> int:
return len(self._series)

Expand Down Expand Up @@ -62,6 +73,14 @@ def cum_sum(self) -> Self:
pc = get_pyarrow_compute()
return self._from_series(pc.cumulative_sum(self._series))

def any(self) -> bool:
pc = get_pyarrow_compute()
return pc.any(self._series) # type: ignore[no-any-return]

def all(self) -> bool:
pc = get_pyarrow_compute()
return pc.all(self._series) # type: ignore[no-any-return]

@property
def shape(self) -> tuple[int]:
return (len(self._series),)
Expand Down
4 changes: 4 additions & 0 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from narwhals.dependencies import get_modin
from narwhals.dependencies import get_numpy
from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_pyarrow
from narwhals.utils import flatten
from narwhals.utils import isinstance_or_issubclass
from narwhals.utils import parse_version
Expand Down Expand Up @@ -365,6 +366,9 @@ def native_series_from_iterable(
mpd = get_modin()

return mpd.Series(data, name=name, index=index)
if implementation == "arrow":
pa = get_pyarrow()
return pa.chunked_array([data])
msg = f"Unknown implementation: {implementation}" # pragma: no cover
raise TypeError(msg) # pragma: no cover

Expand Down
16 changes: 12 additions & 4 deletions tests/expr/any_all_test.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
from typing import Any

import pyarrow as pa
import pytest

import narwhals as nw
from narwhals.utils import parse_version
from tests.utils import compare_dicts


def test_any_all(constructor: Any) -> None:
def test_any_all(constructor_with_pyarrow: Any, request: Any) -> None:
if "table" in str(constructor_with_pyarrow) and parse_version(
pa.__version__
) < parse_version("12.0.0"): # pragma: no cover
request.applymarker(pytest.mark.xfail)
df = nw.from_native(
constructor(
constructor_with_pyarrow(
{
"a": [True, False, True],
"b": [True, True, True],
"c": [False, False, False],
}
)
)
result = nw.to_native(df.select(nw.all().all()))
result = df.select(nw.all().all())
expected = {"a": [False], "b": [True], "c": [False]}
compare_dicts(result, expected)
result = nw.to_native(df.select(nw.all().any()))
result = df.select(nw.all().any())
expected = {"a": [True], "b": [True], "c": [False]}
compare_dicts(result, expected)

0 comments on commit 7b9eff6

Please sign in to comment.