Skip to content

Commit

Permalink
feat: add more methods to the pyarrow backend (#359)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jun 30, 2024
1 parent b4df55b commit 4581ddd
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 24 deletions.
11 changes: 11 additions & 0 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def _from_dataframe(self, df: Any) -> Self:
def shape(self) -> tuple[int, int]:
return self._dataframe.shape # type: ignore[no-any-return]

def __len__(self) -> int:
return len(self._dataframe)

def rows(
self, *, named: bool = False
) -> list[tuple[Any, ...]] | list[dict[str, Any]]:
if not named:
msg = "Unnamed rows are not yet supported on PyArrow tables"
raise NotImplementedError(msg)
return self._dataframe.to_pylist() # type: ignore[no-any-return]

@overload
def __getitem__(self, item: str) -> ArrowSeries: ...

Expand Down
7 changes: 7 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
from typing import TYPE_CHECKING
from typing import Any

from narwhals._arrow.utils import translate_dtype
from narwhals.dependencies import get_pyarrow_compute

if TYPE_CHECKING:
from typing_extensions import Self

from narwhals.dtypes import DType


class ArrowSeries:
def __init__(
Expand Down Expand Up @@ -51,6 +54,10 @@ def alias(self, name: str) -> Self:
name=name,
)

@property
def dtype(self) -> DType:
return translate_dtype(self._series.type)

def cum_sum(self) -> Self:
pc = get_pyarrow_compute()
return self._from_series(pc.cumulative_sum(self._series))
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def rows(
if not named:
return list(self._dataframe.itertuples(index=False, name=None))

return self._dataframe.to_dict("records") # type: ignore[no-any-return]
return self._dataframe.to_dict(orient="records") # type: ignore[no-any-return]

def iter_rows(
self,
Expand Down
8 changes: 4 additions & 4 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,27 +640,27 @@ def rows(
self,
*,
named: Literal[False],
) -> tuple[Any, ...]: ...
) -> list[tuple[Any, ...]]: ...

@overload
def rows(
self,
*,
named: Literal[True],
) -> dict[str, Any]: ...
) -> list[dict[str, Any]]: ...

@overload
def rows(
self,
*,
named: bool,
) -> tuple[Any, ...] | dict[str, Any]: ...
) -> list[tuple[Any, ...]] | list[dict[str, Any]]: ...

def rows(
self,
*,
named: bool = False,
) -> tuple[Any, ...] | dict[str, Any]:
) -> list[tuple[Any, ...]] | list[dict[str, Any]]:
"""
Returns all data in the DataFrame as a list of rows of python-native values.
Expand Down
4 changes: 1 addition & 3 deletions narwhals/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
from narwhals.series import Series


def to_native(
narwhals_object: LazyFrame | DataFrame | Series, *, strict: bool = True
) -> Any:
def to_native(narwhals_object: Any, *, strict: bool = True) -> Any:
"""
Convert Narwhals object to native one.
Expand Down
3 changes: 2 additions & 1 deletion tests/frame/len_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pandas as pd
import polars as pl
import pyarrow as pa
import pytest

import narwhals as nw
Expand All @@ -12,7 +13,7 @@
}


@pytest.mark.parametrize("constructor", [pd.DataFrame, pl.DataFrame])
@pytest.mark.parametrize("constructor", [pd.DataFrame, pl.DataFrame, pa.table])
def test_drop_nulls(constructor: Any) -> None:
result = len(nw.from_native(constructor(data)))
assert result == 4
48 changes: 37 additions & 11 deletions tests/frame/rows_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

import pandas as pd
import polars as pl
import pyarrow as pa
import pytest

import narwhals as nw
from narwhals.utils import parse_version
from tests.utils import maybe_get_modin_df

df_pandas = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
df_pa = pa.table({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
if parse_version(pd.__version__) >= parse_version("1.5.0"):
df_pandas_pyarrow = pd.DataFrame(
{"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
Expand All @@ -34,15 +35,13 @@
df_pandas_pyarrow = df_pandas
df_pandas_nullable = df_pandas
df_polars = pl.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
df_mpd = maybe_get_modin_df(df_pandas)

df_pandas_na = pd.DataFrame({"a": [None, 3, 2], "b": [4, 4, 6], "z": [7.0, None, 9]})
df_polars_na = pl.DataFrame({"a": [None, 3, 2], "b": [4, 4, 6], "z": [7.0, None, 9]})


@pytest.mark.parametrize("method_name", ["iter_rows", "rows"])
@pytest.mark.parametrize(
"df_raw", [df_pandas, df_pandas_nullable, df_pandas_pyarrow, df_mpd, df_polars]
"df_raw", [df_pandas, df_pandas_nullable, df_pandas_pyarrow, df_polars]
)
@pytest.mark.parametrize(
("named", "expected"),
Expand All @@ -58,20 +57,47 @@
),
],
)
@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_rows(
method_name: str,
def test_iter_rows(
df_raw: Any,
named: bool, # noqa: FBT001
expected: list[tuple[Any, ...]] | list[dict[str, Any]],
) -> None:
# GIVEN
df = nw.DataFrame(df_raw)
result = list(df.iter_rows(named=named))
assert result == expected

# WHEN
result = list(getattr(df, method_name)(named=named))

# THEN
@pytest.mark.parametrize(
"df_raw", [df_pandas, df_pandas_nullable, df_pandas_pyarrow, df_polars, df_pa]
)
@pytest.mark.parametrize(
("named", "expected"),
[
(False, [(1, 4, 7.0), (3, 4, 8.0), (2, 6, 9.0)]),
(
True,
[
{"a": 1, "b": 4, "z": 7.0},
{"a": 3, "b": 4, "z": 8.0},
{"a": 2, "b": 6, "z": 9.0},
],
),
],
)
def test_rows(
df_raw: Any,
named: bool, # noqa: FBT001
expected: list[tuple[Any, ...]] | list[dict[str, Any]],
) -> None:
df = nw.DataFrame(df_raw)
if isinstance(df_raw, pa.Table) and not named:
with pytest.raises(
NotImplementedError,
match="Unnamed rows are not yet supported on PyArrow tables",
):
df.rows(named=named)
return
result = df.rows(named=named)
assert result == expected


Expand Down
16 changes: 12 additions & 4 deletions tests/frame/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Foo: ...
parse_version(pd.__version__) < parse_version("2.0.0"), reason="too old"
)
def test_dtypes() -> None:
df = pl.DataFrame(
df_pl = pl.DataFrame(
{
"a": [1],
"b": [1],
Expand Down Expand Up @@ -85,7 +85,8 @@ def test_dtypes() -> None:
"p": pl.Categorical,
},
)
result = nw.DataFrame(df).schema
df = nw.DataFrame(df_pl)
result = df.schema
expected = {
"a": nw.Int64,
"b": nw.Int32,
Expand All @@ -105,7 +106,14 @@ def test_dtypes() -> None:
"p": nw.Categorical,
}
assert result == expected
result_pd = nw.DataFrame(df.to_pandas(use_pyarrow_extension_array=True)).schema
assert {name: df[name].dtype for name in df.columns} == expected
df_pd = df_pl.to_pandas(use_pyarrow_extension_array=True)
df = nw.DataFrame(df_pd)
result_pd = df.schema
assert result_pd == expected
result_pa = nw.DataFrame(df.to_arrow()).schema
assert {name: df[name].dtype for name in df.columns} == expected
df_pa = df_pl.to_arrow()
df = nw.DataFrame(df_pa)
result_pa = df.schema
assert result_pa == expected
assert {name: df[name].dtype for name in df.columns} == expected

0 comments on commit 4581ddd

Please sign in to comment.