feat: pyarrow item, is_empty, rename, write_parquet (#520)
* feat: pyarrow item,is_empty,iter_rows,rename,write_parquet

* merge main, rollback iter_rows
FBruzzesi authored Jul 14, 2024
1 parent a9abd1b commit 5f0e701
Showing 10 changed files with 67 additions and 17 deletions.
31 changes: 31 additions & 0 deletions narwhals/_arrow/dataframe.py
@@ -12,6 +12,7 @@
from narwhals._expression_parsing import evaluate_into_exprs
from narwhals.dependencies import get_numpy
from narwhals.dependencies import get_pyarrow
from narwhals.dependencies import get_pyarrow_parquet
from narwhals.utils import flatten

if TYPE_CHECKING:
@@ -340,3 +341,33 @@ def collect(self) -> ArrowDataFrame:

def clone(self) -> Self:
raise NotImplementedError("clone is not yet supported on PyArrow tables")

def is_empty(self: Self) -> bool:
return self.shape[0] == 0

def item(self: Self, row: int | None = None, column: int | str | None = None) -> Any:
if row is None and column is None:
if self.shape != (1, 1):
msg = (
"can only call `.item()` if the dataframe is of shape (1, 1),"
" or if explicit row/col values are provided;"
f" frame has shape {self.shape!r}"
)
raise ValueError(msg)
return self._native_dataframe[0][0].as_py()

elif row is None or column is None:
msg = "cannot call `.item()` with only one of `row` or `column`"
raise ValueError(msg)

_col = self.columns.index(column) if isinstance(column, str) else column
return self._native_dataframe[_col][row].as_py()

def rename(self, mapping: dict[str, str]) -> Self:
df = self._native_dataframe
new_cols = [mapping.get(c, c) for c in df.column_names]
return self._from_native_dataframe(df.rename_columns(new_cols))

def write_parquet(self, file: Any) -> Any:
pp = get_pyarrow_parquet()
pp.write_table(self._native_dataframe, file)
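
(Aside, not part of the diff.) A minimal sketch of how these new methods surface through the narwhals API once a pyarrow table is wrapped, mirroring the tests added below; the table and output path are illustrative:

import pyarrow as pa
import narwhals as nw

tbl = pa.table({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
df = nw.from_native(tbl, eager_only=True)

print(df.is_empty())                            # False: the table has three rows
print(df.filter(nw.col("a") > 10).is_empty())   # True: no rows pass the filter
print(df.item(0, "z"))                          # 7.0, by row index and column name
print(df.select("a").head(1).item())            # 1: a (1, 1) frame needs no indices
print(df.rename({"a": "x", "b": "y"}).columns)  # ['x', 'y', 'z']
df.write_parquet("foo.parquet")                 # delegates to pyarrow.parquet.write_table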
11 changes: 11 additions & 0 deletions narwhals/_arrow/series.py
@@ -248,6 +248,17 @@ def tail(self, n: int) -> Self:
else:
return self._from_native_series(ser.slice(abs(n)))

def item(self: Self, index: int | None = None) -> Any:
if index is None:
if len(self) != 1:
msg = (
"can only call '.item()' if the Series is of length 1,"
f" or an explicit index is provided (Series is of length {len(self)})"
)
raise ValueError(msg)
return self._native_series[0].as_py()
return self._native_series[index].as_py()

@property
def shape(self) -> tuple[int]:
return (len(self._native_series),)
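
(Aside, not part of the diff.) A corresponding sketch for the Series-level item, mirroring the tests extended below; the table is illustrative:

import pyarrow as pa
import narwhals as nw

tbl = pa.table({"a": [1, 3, 2]})
s = nw.from_native(tbl["a"], series_only=True)

print(s.item(1))         # 3, by position
print(s.head(1).item())  # 1: a length-1 Series needs no index
# s.item() on the full Series raises ValueError, since its length is 3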
9 changes: 9 additions & 0 deletions narwhals/dependencies.py
@@ -51,6 +51,15 @@ def get_pyarrow_compute() -> Any: # pragma: no cover
return None


def get_pyarrow_parquet() -> Any: # pragma: no cover
"""Get pyarrow.parquet module (if pyarrow has already been imported - else return None)."""
if "pyarrow" in sys.modules:
import pyarrow.parquet as pp

return pp
return None


def get_numpy() -> Any:
"""Get numpy module (if already imported - else return None)."""
return sys.modules.get("numpy", None)
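
(Aside, not part of the diff.) This follows the same lazy-lookup pattern as the other get_* helpers: pyarrow.parquet is only imported once pyarrow itself is already in sys.modules, so merely importing narwhals never pulls in pyarrow. A rough illustration, assuming pyarrow is installed but has not yet been imported by anything else in the process:

from narwhals.dependencies import get_pyarrow_parquet

print(get_pyarrow_parquet())  # None, because pyarrow has not been imported yet
import pyarrow  # noqa: F401
print(get_pyarrow_parquet())  # now resolves to the pyarrow.parquet module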
4 changes: 2 additions & 2 deletions tests/frame/is_empty_test.py
@@ -8,9 +8,9 @@


@pytest.mark.parametrize(("threshold", "expected"), [(0, False), (10, True)])
def test_is_empty(constructor: Any, threshold: Any, expected: Any) -> None:
def test_is_empty(constructor_with_pyarrow: Any, threshold: Any, expected: Any) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_raw = constructor(data)
df_raw = constructor_with_pyarrow(data)
df = nw.from_native(df_raw, eager_only=True)
result = df.filter(nw.col("a") > threshold).is_empty()
assert result == expected
7 changes: 3 additions & 4 deletions tests/frame/rename_test.py
@@ -4,10 +4,9 @@
from tests.utils import compare_dicts


def test_rename(constructor: Any) -> None:
def test_rename(constructor_with_pyarrow: Any) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df = nw.from_native(constructor(data), eager_only=True)
df = nw.from_native(constructor_with_pyarrow(data), eager_only=True)
result = df.rename({"a": "x", "b": "y"})
result_native = nw.to_native(result)
expected = {"x": [1, 3, 2], "y": [4, 4, 6], "z": [7.0, 8, 9]}
compare_dicts(result_native, expected)
compare_dicts(result, expected)
1 change: 1 addition & 0 deletions tests/frame/rows_test.py
@@ -97,6 +97,7 @@ def test_rows(
):
df.rows(named=named)
return

result = df.rows(named=named)
assert result == expected

4 changes: 2 additions & 2 deletions tests/frame/test_common.py
@@ -334,7 +334,7 @@ def test_library(df_raw: Any, df_raw_right: Any) -> None:
df_left.join(df_right, left_on=["a"], right_on=["a"], how="inner")


@pytest.mark.parametrize("df_raw", [df_pandas, df_polars])
@pytest.mark.parametrize("df_raw", [df_pandas, df_polars, df_pa])
@pytest.mark.parametrize(
("row", "column", "expected"),
[(0, 2, 7), (1, "z", 8)],
@@ -350,7 +350,7 @@ def test_item(
assert df.select("a").head(1).item() == 1


@pytest.mark.parametrize("df_raw", [df_pandas, df_polars])
@pytest.mark.parametrize("df_raw", [df_pandas, df_polars, df_pa])
@pytest.mark.parametrize(
("row", "column", "err_msg"),
[
6 changes: 4 additions & 2 deletions tests/frame/write_parquet_test.py
@@ -15,7 +15,9 @@
@pytest.mark.skipif(
parse_version(pd.__version__) < parse_version("2.0.0"), reason="too old for pyarrow"
)
def test_write_parquet(constructor: Any, tmpdir: pytest.TempdirFactory) -> None:
def test_write_parquet(
constructor_with_pyarrow: Any, tmpdir: pytest.TempdirFactory
) -> None:
path = str(tmpdir / "foo.parquet") # type: ignore[operator]
nw.from_native(constructor(data), eager_only=True).write_parquet(path)
nw.from_native(constructor_with_pyarrow(data), eager_only=True).write_parquet(path)
assert os.path.exists(path)
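
(Aside, not part of the diff.) A quick way to sanity-check the written file is to read it back with pyarrow.parquet; a sketch assuming pyarrow is installed, with an illustrative path in place of the pytest tmpdir:

import pyarrow as pa
import pyarrow.parquet as pq
import narwhals as nw

data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
path = "foo.parquet"
nw.from_native(pa.table(data), eager_only=True).write_parquet(path)
assert pq.read_table(path).to_pydict() == data  # the ints in "z" compare equal to the read-back floats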
6 changes: 4 additions & 2 deletions tests/series/test_common.py
@@ -7,6 +7,7 @@
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pytest
from numpy.testing import assert_array_equal
from pandas.testing import assert_series_equal
@@ -40,6 +41,7 @@
df_pandas_nullable = df_pandas
df_polars = pl.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
df_lazy = pl.LazyFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
df_pa = pa.table({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})


@pytest.mark.parametrize(
@@ -255,13 +257,13 @@ def test_cast_string() -> None:
df_pandas = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})


@pytest.mark.parametrize("df_raw", [df_pandas, df_polars])
@pytest.mark.parametrize("df_raw", [df_pandas, df_polars, df_pa])
@pytest.mark.parametrize(("index", "expected"), [(0, 1), (1, 3)])
def test_item(df_raw: Any, index: int, expected: int) -> None:
s = nw.from_native(df_raw["a"], series_only=True)
result = s.item(index)
assert result == expected
assert nw.from_native(df_raw["a"].head(1), series_only=True).item() == 1
assert s.head(1).item() == 1

with pytest.raises(
ValueError,
5 changes: 0 additions & 5 deletions utils/check_backend_completeness.py
@@ -13,14 +13,10 @@

MISSING = [
"DataFrame.is_duplicated",
"DataFrame.is_empty",
"DataFrame.is_unique",
"DataFrame.item",
"DataFrame.iter_rows",
"DataFrame.pipe",
"DataFrame.rename",
"DataFrame.unique",
"DataFrame.write_parquet",
"Series.drop_nulls",
"Series.fill_null",
"Series.from_iterable",
@@ -32,7 +28,6 @@
"Series.is_null",
"Series.is_sorted",
"Series.is_unique",
"Series.item",
"Series.len",
"Series.n_unique",
"Series.quantile",
