From 4ad1e7e8fd2246984e2ab8d3f0a53b526795ea59 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 14 Jul 2024 22:32:28 +0200 Subject: [PATCH] more generic --- narwhals/_arrow/dataframe.py | 7 +++++-- narwhals/_arrow/namespace.py | 4 ++-- narwhals/_arrow/utils.py | 7 +++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 7bfa3c7a9..10696d6d4 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -175,11 +175,12 @@ def with_columns( to_concat = [] output_names = [] # Make sure to preserve column order + length = len(self) for name in self.columns: if name in new_column_name_to_new_column_map: to_concat.append( validate_dataframe_comparand( - new_column_name_to_new_column_map.pop(name) + length=length, other=new_column_name_to_new_column_map.pop(name) ) ) else: @@ -187,7 +188,9 @@ def with_columns( output_names.append(name) for s in new_column_name_to_new_column_map: to_concat.append( - validate_dataframe_comparand(new_column_name_to_new_column_map[s]) + validate_dataframe_comparand( + length=length, other=new_column_name_to_new_column_map[s] + ) ) output_names.append(s) df = self._native_dataframe.__class__.from_arrays(to_concat, names=output_names) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index e479ec463..d6fe38ba4 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -126,9 +126,9 @@ def all(self) -> ArrowExpr: ) def lit(self, value: Any, dtype: dtypes.DType | None) -> ArrowExpr: - def _lit_arrow_series(df: ArrowDataFrame) -> ArrowSeries: + def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: arrow_series = ArrowSeries._from_iterable( - data=[value] * len(df), + data=[value], name="lit", backend_version=self._backend_version, ) diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index a031c1569..6bda7e8d6 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -123,7 +123,7 @@ def validate_column_comparand(other: Any) -> Any: return other -def validate_dataframe_comparand(other: Any) -> Any: +def validate_dataframe_comparand(length: int, other: Any) -> Any: """Validate RHS of binary operation. If the comparison isn't supported, return `NotImplemented` so that the @@ -136,9 +136,8 @@ def validate_dataframe_comparand(other: Any) -> Any: return NotImplemented if isinstance(other, ArrowSeries): if len(other) == 1: - # broadcast - msg = "not implemented yet" # pragma: no cover - raise NotImplementedError(msg) + pa = get_pyarrow() + return pa.chunked_array([[other.item()] * length]) return other._native_series msg = "Please report a bug" # pragma: no cover raise AssertionError(msg)