diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 7bfa3c7a9..10696d6d4 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -175,11 +175,12 @@ def with_columns( to_concat = [] output_names = [] # Make sure to preserve column order + length = len(self) for name in self.columns: if name in new_column_name_to_new_column_map: to_concat.append( validate_dataframe_comparand( - new_column_name_to_new_column_map.pop(name) + length=length, other=new_column_name_to_new_column_map.pop(name) ) ) else: @@ -187,7 +188,9 @@ def with_columns( output_names.append(name) for s in new_column_name_to_new_column_map: to_concat.append( - validate_dataframe_comparand(new_column_name_to_new_column_map[s]) + validate_dataframe_comparand( + length=length, other=new_column_name_to_new_column_map[s] + ) ) output_names.append(s) df = self._native_dataframe.__class__.from_arrays(to_concat, names=output_names) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 306ecb773..d6fe38ba4 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -7,6 +7,7 @@ from narwhals import dtypes from narwhals._arrow.expr import ArrowExpr +from narwhals._arrow.series import ArrowSeries from narwhals._expression_parsing import parse_into_exprs from narwhals.dependencies import get_pyarrow from narwhals.utils import flatten @@ -15,8 +16,6 @@ from typing import Callable from narwhals._arrow.dataframe import ArrowDataFrame - from narwhals._arrow.expr import ArrowExpr - from narwhals._arrow.series import ArrowSeries from narwhals._arrow.typing import IntoArrowExpr @@ -126,6 +125,26 @@ def all(self) -> ArrowExpr: backend_version=self._backend_version, ) + def lit(self, value: Any, dtype: dtypes.DType | None) -> ArrowExpr: + def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: + arrow_series = ArrowSeries._from_iterable( + data=[value], + name="lit", + backend_version=self._backend_version, + ) + if dtype: + return arrow_series.cast(dtype) + return arrow_series + + return ArrowExpr( + lambda df: [_lit_arrow_series(df)], + depth=0, + function_name="lit", + root_names=None, + output_names=["lit"], + backend_version=self._backend_version, + ) + def all_horizontal(self, *exprs: IntoArrowExpr) -> ArrowExpr: return reduce( lambda x, y: x & y, diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index dc29de96d..af9fc7c00 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -5,7 +5,6 @@ from typing import Iterable from typing import Sequence -from narwhals._arrow.namespace import ArrowNamespace from narwhals._arrow.utils import reverse_translate_dtype from narwhals._arrow.utils import translate_dtype from narwhals._arrow.utils import validate_column_comparand @@ -15,6 +14,7 @@ if TYPE_CHECKING: from typing_extensions import Self + from narwhals._arrow.namespace import ArrowNamespace from narwhals.dtypes import DType @@ -164,6 +164,8 @@ def count(self) -> int: return pc.count(self._native_series) # type: ignore[no-any-return] def __narwhals_namespace__(self) -> ArrowNamespace: + from narwhals._arrow.namespace import ArrowNamespace + return ArrowNamespace(backend_version=self._backend_version) @property diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index a031c1569..6bda7e8d6 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -123,7 +123,7 @@ def validate_column_comparand(other: Any) -> Any: return other -def validate_dataframe_comparand(other: Any) -> Any: +def validate_dataframe_comparand(length: int, other: Any) -> Any: """Validate RHS of binary operation. If the comparison isn't supported, return `NotImplemented` so that the @@ -136,9 +136,8 @@ def validate_dataframe_comparand(other: Any) -> Any: return NotImplemented if isinstance(other, ArrowSeries): if len(other) == 1: - # broadcast - msg = "not implemented yet" # pragma: no cover - raise NotImplementedError(msg) + pa = get_pyarrow() + return pa.chunked_array([[other.item()] * length]) return other._native_series msg = "Please report a bug" # pragma: no cover raise AssertionError(msg) diff --git a/tests/frame/lit_test.py b/tests/frame/lit_test.py index 1a79b3440..da2999373 100644 --- a/tests/frame/lit_test.py +++ b/tests/frame/lit_test.py @@ -18,26 +18,25 @@ [(None, [2, 2, 2]), (nw.String, ["2", "2", "2"]), (nw.Float32, [2.0, 2.0, 2.0])], ) def test_lit( - constructor_with_lazy: Any, dtype: DType | None, expected_lit: list[Any] + constructor_with_pyarrow: Any, dtype: DType | None, expected_lit: list[Any] ) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df_raw = constructor_with_lazy(data) - df = nw.from_native(df_raw) + df_raw = constructor_with_pyarrow(data) + df = nw.from_native(df_raw).lazy() result = df.with_columns(nw.lit(2, dtype).alias("lit")) - result_native = nw.to_native(result) expected = { "a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0], "lit": expected_lit, } - compare_dicts(result_native, expected) + compare_dicts(result, expected) -def test_lit_error(constructor_with_lazy: Any) -> None: +def test_lit_error(constructor_with_pyarrow: Any) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df_raw = constructor_with_lazy(data) - df = nw.from_native(df_raw) + df_raw = constructor_with_pyarrow(data) + df = nw.from_native(df_raw).lazy() with pytest.raises( ValueError, match="numpy arrays are not supported as literal values" ):