Skip to content

Commit

Permalink
more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Jul 14, 2024
1 parent b16ab5e commit 4ad1e7e
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
7 changes: 5 additions & 2 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,19 +175,22 @@ def with_columns(
to_concat = []
output_names = []
# Make sure to preserve column order
length = len(self)
for name in self.columns:
if name in new_column_name_to_new_column_map:
to_concat.append(
validate_dataframe_comparand(
new_column_name_to_new_column_map.pop(name)
length=length, other=new_column_name_to_new_column_map.pop(name)
)
)
else:
to_concat.append(self._native_dataframe[name])
output_names.append(name)
for s in new_column_name_to_new_column_map:
to_concat.append(
validate_dataframe_comparand(new_column_name_to_new_column_map[s])
validate_dataframe_comparand(
length=length, other=new_column_name_to_new_column_map[s]
)
)
output_names.append(s)
df = self._native_dataframe.__class__.from_arrays(to_concat, names=output_names)
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ def all(self) -> ArrowExpr:
)

def lit(self, value: Any, dtype: dtypes.DType | None) -> ArrowExpr:
def _lit_arrow_series(df: ArrowDataFrame) -> ArrowSeries:
def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
arrow_series = ArrowSeries._from_iterable(
data=[value] * len(df),
data=[value],
name="lit",
backend_version=self._backend_version,
)
Expand Down
7 changes: 3 additions & 4 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def validate_column_comparand(other: Any) -> Any:
return other


def validate_dataframe_comparand(other: Any) -> Any:
def validate_dataframe_comparand(length: int, other: Any) -> Any:
"""Validate RHS of binary operation.
If the comparison isn't supported, return `NotImplemented` so that the
Expand All @@ -136,9 +136,8 @@ def validate_dataframe_comparand(other: Any) -> Any:
return NotImplemented
if isinstance(other, ArrowSeries):
if len(other) == 1:
# broadcast
msg = "not implemented yet" # pragma: no cover
raise NotImplementedError(msg)
pa = get_pyarrow()
return pa.chunked_array([[other.item()] * length])
return other._native_series
msg = "Please report a bug" # pragma: no cover
raise AssertionError(msg)

0 comments on commit 4ad1e7e

Please sign in to comment.