Skip to content

Commit

Permalink
merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Jul 15, 2024
2 parents a6be7bc + 3da5392 commit cfdda8c
Show file tree
Hide file tree
Showing 89 changed files with 667 additions and 576 deletions.
7 changes: 5 additions & 2 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,19 +175,22 @@ def with_columns(
to_concat = []
output_names = []
# Make sure to preserve column order
length = len(self)
for name in self.columns:
if name in new_column_name_to_new_column_map:
to_concat.append(
validate_dataframe_comparand(
new_column_name_to_new_column_map.pop(name)
length=length, other=new_column_name_to_new_column_map.pop(name)
)
)
else:
to_concat.append(self._native_dataframe[name])
output_names.append(name)
for s in new_column_name_to_new_column_map:
to_concat.append(
validate_dataframe_comparand(new_column_name_to_new_column_map[s])
validate_dataframe_comparand(
length=length, other=new_column_name_to_new_column_map[s]
)
)
output_names.append(s)
df = self._native_dataframe.__class__.from_arrays(to_concat, names=output_names)
Expand Down
23 changes: 21 additions & 2 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from narwhals import dtypes
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import horizontal_concat
from narwhals._arrow.utils import vertical_concat
from narwhals._expression_parsing import parse_into_exprs
Expand All @@ -17,8 +18,6 @@
if TYPE_CHECKING:
from typing import Callable

from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import IntoArrowExpr


Expand Down Expand Up @@ -128,6 +127,26 @@ def all(self) -> ArrowExpr:
backend_version=self._backend_version,
)

def lit(self, value: Any, dtype: dtypes.DType | None) -> ArrowExpr:
def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
arrow_series = ArrowSeries._from_iterable(
data=[value],
name="lit",
backend_version=self._backend_version,
)
if dtype:
return arrow_series.cast(dtype)
return arrow_series

return ArrowExpr(
lambda df: [_lit_arrow_series(df)],
depth=0,
function_name="lit",
root_names=None,
output_names=["lit"],
backend_version=self._backend_version,
)

def all_horizontal(self, *exprs: IntoArrowExpr) -> ArrowExpr:
return reduce(
lambda x, y: x & y,
Expand Down
4 changes: 3 additions & 1 deletion narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Iterable
from typing import Sequence

from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.utils import reverse_translate_dtype
from narwhals._arrow.utils import translate_dtype
from narwhals._arrow.utils import validate_column_comparand
Expand All @@ -15,6 +14,7 @@
if TYPE_CHECKING:
from typing_extensions import Self

from narwhals._arrow.namespace import ArrowNamespace
from narwhals.dtypes import DType


Expand Down Expand Up @@ -164,6 +164,8 @@ def count(self) -> int:
return pc.count(self._native_series) # type: ignore[no-any-return]

def __narwhals_namespace__(self) -> ArrowNamespace:
from narwhals._arrow.namespace import ArrowNamespace

return ArrowNamespace(backend_version=self._backend_version)

@property
Expand Down
7 changes: 3 additions & 4 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def validate_column_comparand(other: Any) -> Any:
return other


def validate_dataframe_comparand(other: Any) -> Any:
def validate_dataframe_comparand(length: int, other: Any) -> Any:
"""Validate RHS of binary operation.
If the comparison isn't supported, return `NotImplemented` so that the
Expand All @@ -136,9 +136,8 @@ def validate_dataframe_comparand(other: Any) -> Any:
return NotImplemented
if isinstance(other, ArrowSeries):
if len(other) == 1:
# broadcast
msg = "not implemented yet" # pragma: no cover
raise NotImplementedError(msg)
pa = get_pyarrow()
return pa.chunked_array([[other.item()] * length])
return other._native_series
msg = "Please report a bug" # pragma: no cover
raise AssertionError(msg)
Expand Down
50 changes: 24 additions & 26 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,40 +48,43 @@ def modin_constructor(obj: Any) -> IntoDataFrame: # pragma: no cover
return mpd.DataFrame(obj).convert_dtypes(dtype_backend="pyarrow") # type: ignore[no-any-return]


def polars_constructor(obj: Any) -> IntoDataFrame:
def polars_eager_constructor(obj: Any) -> IntoDataFrame:
return pl.DataFrame(obj)


def polars_lazy_constructor(obj: Any) -> pl.LazyFrame:
return pl.LazyFrame(obj)


def pyarrow_table_constructor(obj: Any) -> IntoDataFrame:
return pa.table(obj) # type: ignore[no-any-return]


if parse_version(pd.__version__) >= parse_version("2.0.0"):
params = [pandas_constructor, pandas_nullable_constructor, pandas_pyarrow_constructor]
eager_constructors = [
pandas_constructor,
pandas_nullable_constructor,
pandas_pyarrow_constructor,
]
else: # pragma: no cover
params = [pandas_constructor]
params.append(polars_constructor)
eager_constructors = [pandas_constructor]

eager_constructors.extend([polars_eager_constructor, pyarrow_table_constructor])

if get_modin() is not None: # pragma: no cover
params.append(modin_constructor)
eager_constructors.append(modin_constructor)


@pytest.fixture(params=params)
@pytest.fixture(params=eager_constructors)
def constructor(request: Any) -> Callable[[Any], IntoDataFrame]:
return request.param # type: ignore[no-any-return]


@pytest.fixture(params=[*params, polars_lazy_constructor])
@pytest.fixture(params=[*eager_constructors, polars_lazy_constructor])
def constructor_with_lazy(request: Any) -> Callable[[Any], Any]:
return request.param # type: ignore[no-any-return]


# TODO(Unassigned): once pyarrow has complete coverage, we can remove this one,
# and just put `pa.table` into `constructor`
@pytest.fixture(params=[*params, pa.table])
def constructor_with_pyarrow(request: Any) -> Callable[[Any], IntoDataFrame]:
return request.param # type: ignore[no-any-return]


def pandas_series_constructor(obj: Any) -> Any:
return pd.Series(obj)

Expand All @@ -103,6 +106,10 @@ def polars_series_constructor(obj: Any) -> Any:
return pl.Series(obj)


def pyarrow_series_constructor(obj: Any) -> Any:
return pa.chunked_array([obj])


if parse_version(pd.__version__) >= parse_version("2.0.0"):
params_series = [
pandas_series_constructor,
Expand All @@ -111,22 +118,13 @@ def polars_series_constructor(obj: Any) -> Any:
]
else: # pragma: no cover
params_series = [pandas_series_constructor]
params_series.append(polars_series_constructor)
if get_modin() is not None: # pragma: no cover
params_series.append(modin_series_constructor)


@pytest.fixture(params=params_series)
def constructor_series(request: Any) -> Callable[[Any], Any]:
return request.param # type: ignore[no-any-return]
params_series.extend([polars_series_constructor, pyarrow_series_constructor])


def pyarrow_chunked_array_constructor(obj: Any) -> Any:
return pa.chunked_array([obj])


# TODO(Unassigned): once pyarrow has complete coverage, we can remove this one,
# and just put `pa.table` into `constructor`
@pytest.fixture(params=[*params_series, pyarrow_chunked_array_constructor])
def constructor_series_with_pyarrow(request: Any) -> Callable[[Any], Any]:
@pytest.fixture(params=params_series)
def constructor_series(request: Any) -> Callable[[Any], Any]:
return request.param # type: ignore[no-any-return]
6 changes: 2 additions & 4 deletions tests/expr/abs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
from tests.utils import compare_dicts


def test_abs(constructor_with_pyarrow: Any) -> None:
df = nw.from_native(
constructor_with_pyarrow({"a": [1, 2, 3, -4, 5]}), eager_only=True
)
def test_abs(constructor: Any) -> None:
df = nw.from_native(constructor({"a": [1, 2, 3, -4, 5]}), eager_only=True)
result = df.select(b=nw.col("a").abs())
expected = {"b": [1, 2, 3, 4, 5]}
compare_dicts(result, expected)
Expand Down
4 changes: 2 additions & 2 deletions tests/expr/all_horizontal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@


@pytest.mark.parametrize("col_expr", [np.array([False, False, True]), nw.col("a"), "a"])
def test_allh(constructor_with_pyarrow: Any, col_expr: Any) -> None:
def test_allh(constructor: Any, col_expr: Any) -> None:
data = {
"a": [False, False, True],
"b": [False, True, True],
}
df = nw.from_native(constructor_with_pyarrow(data), eager_only=True)
df = nw.from_native(constructor(data), eager_only=True)
result = df.select(all=nw.all_horizontal(col_expr, nw.col("b")))

expected = {"all": [False, False, True]}
Expand Down
4 changes: 2 additions & 2 deletions tests/expr/any_all_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from tests.utils import compare_dicts


def test_any_all(constructor_with_pyarrow: Any) -> None:
def test_any_all(constructor: Any) -> None:
df = nw.from_native(
constructor_with_pyarrow(
constructor(
{
"a": [True, False, True],
"b": [True, True, True],
Expand Down
16 changes: 8 additions & 8 deletions tests/expr/arithmetic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,21 @@
],
)
def test_arithmetic(
attr: str, rhs: Any, expected: list[Any], constructor_with_pyarrow: Any, request: Any
attr: str, rhs: Any, expected: list[Any], constructor: Any, request: Any
) -> None:
if "pandas_pyarrow" in str(constructor_with_pyarrow) and attr == "__mod__":
if "pandas_pyarrow" in str(constructor) and attr == "__mod__":
request.applymarker(pytest.mark.xfail)

# pyarrow case
if "table" in str(constructor_with_pyarrow) and attr in {
if "pyarrow_table" in str(constructor) and attr in {
"__truediv__",
"__floordiv__",
"__mod__",
}:
request.applymarker(pytest.mark.xfail)

data = {"a": [1, 2, 3]}
df = nw.from_native(constructor_with_pyarrow(data))
df = nw.from_native(constructor(data))
result = df.select(getattr(nw.col("a"), attr)(rhs))
compare_dicts(result, {"a": expected})

Expand All @@ -53,21 +53,21 @@ def test_arithmetic(
],
)
def test_right_arithmetic(
attr: str, rhs: Any, expected: list[Any], constructor_with_pyarrow: Any, request: Any
attr: str, rhs: Any, expected: list[Any], constructor: Any, request: Any
) -> None:
if "pandas_pyarrow" in str(constructor_with_pyarrow) and attr in {"__rmod__"}:
if "pandas_pyarrow" in str(constructor) and attr in {"__rmod__"}:
request.applymarker(pytest.mark.xfail)

# pyarrow case
if "table" in str(constructor_with_pyarrow) and attr in {
if "table" in str(constructor) and attr in {
"__rtruediv__",
"__rfloordiv__",
"__rmod__",
}:
request.applymarker(pytest.mark.xfail)

data = {"a": [1, 2, 3]}
df = nw.from_native(constructor_with_pyarrow(data))
df = nw.from_native(constructor(data))
result = df.select(a=getattr(nw.col("a"), attr)(rhs))
compare_dicts(result, {"a": expected})
result = df.select(a=getattr(df["a"], attr)(rhs))
Expand Down
10 changes: 5 additions & 5 deletions tests/expr/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@


@pytest.mark.filterwarnings("ignore:casting period[M] values to int64:FutureWarning")
def test_cast(constructor_with_pyarrow: Any, request: Any) -> None:
if "table" in str(constructor_with_pyarrow) and parse_version(pa.__version__) <= (
15,
): # pragma: no cover
def test_cast(constructor: Any, request: Any) -> None:
if "pyarrow_table_constructor" in str(constructor) and parse_version(
pa.__version__
) <= (15,): # pragma: no cover
request.applymarker(pytest.mark.xfail)
data = {
"a": [1],
Expand Down Expand Up @@ -49,7 +49,7 @@ def test_cast(constructor_with_pyarrow: Any, request: Any) -> None:
"o": nw.Categorical,
"p": nw.Int64,
}
df = nw.from_native(constructor_with_pyarrow(data), eager_only=True).select(
df = nw.from_native(constructor(data), eager_only=True).select(
nw.col(key).cast(value) for key, value in schema.items()
)
result = df.select(
Expand Down
8 changes: 7 additions & 1 deletion tests/expr/cat/get_categories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,25 @@
from typing import Any

import pyarrow as pa
import pytest

import narwhals.stable.v1 as nw
from tests.utils import compare_dicts

data = {"a": ["one", "two", "two"]}


def test_get_categories(constructor: Any) -> None:
def test_get_categories(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data), eager_only=True)
df = df.select(nw.col("a").cast(nw.Categorical))

result = df.select(nw.col("a").cat.get_categories())
expected = {"a": ["one", "two"]}
compare_dicts(result, expected)

result = df.select(df["a"].cat.get_categories())
compare_dicts(result, expected)

Expand Down
4 changes: 2 additions & 2 deletions tests/expr/count_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from tests.utils import compare_dicts


def test_count(constructor_with_pyarrow: Any) -> None:
def test_count(constructor: Any) -> None:
data = {"a": [1, 3, 2], "b": [4, None, 6], "z": [7.0, None, None]}
df = nw.from_native(constructor_with_pyarrow(data), eager_only=True)
df = nw.from_native(constructor(data), eager_only=True)
result = df.select(nw.col("a", "b", "z").count())
expected = {"a": [3], "b": [2], "z": [1]}
compare_dicts(result, expected)
4 changes: 2 additions & 2 deletions tests/expr/cum_sum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
}


def test_cum_sum_simple(constructor_with_pyarrow: Any) -> None:
df = nw.from_native(constructor_with_pyarrow(data), eager_only=True)
def test_cum_sum_simple(constructor: Any) -> None:
df = nw.from_native(constructor(data), eager_only=True)
result = df.select(nw.all().cum_sum())
expected = {
"a": [0, 1, 3, 6, 10],
Expand Down
8 changes: 5 additions & 3 deletions tests/expr/diff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
}


def test_diff(constructor_with_pyarrow: Any, request: Any) -> None:
if "table" in str(constructor_with_pyarrow) and parse_version(pa.__version__) < (13,):
def test_diff(constructor: Any, request: Any) -> None:
if "pyarrow_table_constructor" in str(constructor) and parse_version(
pa.__version__
) < (13,):
# pc.pairwisediff is available since pyarrow 13.0.0
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor_with_pyarrow(data), eager_only=True)
df = nw.from_native(constructor(data), eager_only=True)
result = df.with_columns(c_diff=nw.col("c").diff())[1:]
expected = {
"i": [1, 2, 3, 4],
Expand Down
Loading

0 comments on commit cfdda8c

Please sign in to comment.