Skip to content

Commit

Permalink
feat: pyarrow lit (#524)
Browse files Browse the repository at this point in the history
* feat: pyarrow lit

* more generic
  • Loading branch information
FBruzzesi authored Jul 15, 2024
1 parent 1a4a290 commit 57cdcd0
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 17 deletions.
7 changes: 5 additions & 2 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,19 +175,22 @@ def with_columns(
to_concat = []
output_names = []
# Make sure to preserve column order
length = len(self)
for name in self.columns:
if name in new_column_name_to_new_column_map:
to_concat.append(
validate_dataframe_comparand(
new_column_name_to_new_column_map.pop(name)
length=length, other=new_column_name_to_new_column_map.pop(name)
)
)
else:
to_concat.append(self._native_dataframe[name])
output_names.append(name)
for s in new_column_name_to_new_column_map:
to_concat.append(
validate_dataframe_comparand(new_column_name_to_new_column_map[s])
validate_dataframe_comparand(
length=length, other=new_column_name_to_new_column_map[s]
)
)
output_names.append(s)
df = self._native_dataframe.__class__.from_arrays(to_concat, names=output_names)
Expand Down
23 changes: 21 additions & 2 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from narwhals import dtypes
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.series import ArrowSeries
from narwhals._expression_parsing import parse_into_exprs
from narwhals.dependencies import get_pyarrow
from narwhals.utils import flatten
Expand All @@ -15,8 +16,6 @@
from typing import Callable

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import IntoArrowExpr


Expand Down Expand Up @@ -126,6 +125,26 @@ def all(self) -> ArrowExpr:
backend_version=self._backend_version,
)

def lit(self, value: Any, dtype: dtypes.DType | None) -> ArrowExpr:
def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
arrow_series = ArrowSeries._from_iterable(
data=[value],
name="lit",
backend_version=self._backend_version,
)
if dtype:
return arrow_series.cast(dtype)
return arrow_series

return ArrowExpr(
lambda df: [_lit_arrow_series(df)],
depth=0,
function_name="lit",
root_names=None,
output_names=["lit"],
backend_version=self._backend_version,
)

def all_horizontal(self, *exprs: IntoArrowExpr) -> ArrowExpr:
return reduce(
lambda x, y: x & y,
Expand Down
4 changes: 3 additions & 1 deletion narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Iterable
from typing import Sequence

from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.utils import reverse_translate_dtype
from narwhals._arrow.utils import translate_dtype
from narwhals._arrow.utils import validate_column_comparand
Expand All @@ -15,6 +14,7 @@
if TYPE_CHECKING:
from typing_extensions import Self

from narwhals._arrow.namespace import ArrowNamespace
from narwhals.dtypes import DType


Expand Down Expand Up @@ -164,6 +164,8 @@ def count(self) -> int:
return pc.count(self._native_series) # type: ignore[no-any-return]

def __narwhals_namespace__(self) -> ArrowNamespace:
from narwhals._arrow.namespace import ArrowNamespace

return ArrowNamespace(backend_version=self._backend_version)

@property
Expand Down
7 changes: 3 additions & 4 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def validate_column_comparand(other: Any) -> Any:
return other


def validate_dataframe_comparand(other: Any) -> Any:
def validate_dataframe_comparand(length: int, other: Any) -> Any:
"""Validate RHS of binary operation.
If the comparison isn't supported, return `NotImplemented` so that the
Expand All @@ -136,9 +136,8 @@ def validate_dataframe_comparand(other: Any) -> Any:
return NotImplemented
if isinstance(other, ArrowSeries):
if len(other) == 1:
# broadcast
msg = "not implemented yet" # pragma: no cover
raise NotImplementedError(msg)
pa = get_pyarrow()
return pa.chunked_array([[other.item()] * length])
return other._native_series
msg = "Please report a bug" # pragma: no cover
raise AssertionError(msg)
15 changes: 7 additions & 8 deletions tests/frame/lit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,25 @@
[(None, [2, 2, 2]), (nw.String, ["2", "2", "2"]), (nw.Float32, [2.0, 2.0, 2.0])],
)
def test_lit(
constructor_with_lazy: Any, dtype: DType | None, expected_lit: list[Any]
constructor_with_pyarrow: Any, dtype: DType | None, expected_lit: list[Any]
) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_raw = constructor_with_lazy(data)
df = nw.from_native(df_raw)
df_raw = constructor_with_pyarrow(data)
df = nw.from_native(df_raw).lazy()
result = df.with_columns(nw.lit(2, dtype).alias("lit"))
result_native = nw.to_native(result)
expected = {
"a": [1, 3, 2],
"b": [4, 4, 6],
"z": [7.0, 8.0, 9.0],
"lit": expected_lit,
}
compare_dicts(result_native, expected)
compare_dicts(result, expected)


def test_lit_error(constructor_with_lazy: Any) -> None:
def test_lit_error(constructor_with_pyarrow: Any) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_raw = constructor_with_lazy(data)
df = nw.from_native(df_raw)
df_raw = constructor_with_pyarrow(data)
df = nw.from_native(df_raw).lazy()
with pytest.raises(
ValueError, match="numpy arrays are not supported as literal values"
):
Expand Down

0 comments on commit 57cdcd0

Please sign in to comment.