From b16ab5ea50fce1e20b4bea895f53dedf22799981 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 14 Jul 2024 22:13:08 +0200 Subject: [PATCH] feat: pyarrow lit --- narwhals/_arrow/namespace.py | 23 +++++++++++++++++++++-- narwhals/_arrow/series.py | 4 +++- tests/frame/lit_test.py | 15 +++++++-------- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 306ecb773..e479ec463 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -7,6 +7,7 @@ from narwhals import dtypes from narwhals._arrow.expr import ArrowExpr +from narwhals._arrow.series import ArrowSeries from narwhals._expression_parsing import parse_into_exprs from narwhals.dependencies import get_pyarrow from narwhals.utils import flatten @@ -15,8 +16,6 @@ from typing import Callable from narwhals._arrow.dataframe import ArrowDataFrame - from narwhals._arrow.expr import ArrowExpr - from narwhals._arrow.series import ArrowSeries from narwhals._arrow.typing import IntoArrowExpr @@ -126,6 +125,26 @@ def all(self) -> ArrowExpr: backend_version=self._backend_version, ) + def lit(self, value: Any, dtype: dtypes.DType | None) -> ArrowExpr: + def _lit_arrow_series(df: ArrowDataFrame) -> ArrowSeries: + arrow_series = ArrowSeries._from_iterable( + data=[value] * len(df), + name="lit", + backend_version=self._backend_version, + ) + if dtype: + return arrow_series.cast(dtype) + return arrow_series + + return ArrowExpr( + lambda df: [_lit_arrow_series(df)], + depth=0, + function_name="lit", + root_names=None, + output_names=["lit"], + backend_version=self._backend_version, + ) + def all_horizontal(self, *exprs: IntoArrowExpr) -> ArrowExpr: return reduce( lambda x, y: x & y, diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index dc29de96d..af9fc7c00 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -5,7 +5,6 @@ from typing import Iterable from typing import Sequence -from narwhals._arrow.namespace import ArrowNamespace from narwhals._arrow.utils import reverse_translate_dtype from narwhals._arrow.utils import translate_dtype from narwhals._arrow.utils import validate_column_comparand @@ -15,6 +14,7 @@ if TYPE_CHECKING: from typing_extensions import Self + from narwhals._arrow.namespace import ArrowNamespace from narwhals.dtypes import DType @@ -164,6 +164,8 @@ def count(self) -> int: return pc.count(self._native_series) # type: ignore[no-any-return] def __narwhals_namespace__(self) -> ArrowNamespace: + from narwhals._arrow.namespace import ArrowNamespace + return ArrowNamespace(backend_version=self._backend_version) @property diff --git a/tests/frame/lit_test.py b/tests/frame/lit_test.py index 1a79b3440..da2999373 100644 --- a/tests/frame/lit_test.py +++ b/tests/frame/lit_test.py @@ -18,26 +18,25 @@ [(None, [2, 2, 2]), (nw.String, ["2", "2", "2"]), (nw.Float32, [2.0, 2.0, 2.0])], ) def test_lit( - constructor_with_lazy: Any, dtype: DType | None, expected_lit: list[Any] + constructor_with_pyarrow: Any, dtype: DType | None, expected_lit: list[Any] ) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df_raw = constructor_with_lazy(data) - df = nw.from_native(df_raw) + df_raw = constructor_with_pyarrow(data) + df = nw.from_native(df_raw).lazy() result = df.with_columns(nw.lit(2, dtype).alias("lit")) - result_native = nw.to_native(result) expected = { "a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0], "lit": expected_lit, } - compare_dicts(result_native, expected) + compare_dicts(result, expected) -def test_lit_error(constructor_with_lazy: Any) -> None: +def test_lit_error(constructor_with_pyarrow: Any) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df_raw = constructor_with_lazy(data) - df = nw.from_native(df_raw) + df_raw = constructor_with_pyarrow(data) + df = nw.from_native(df_raw).lazy() with pytest.raises( ValueError, match="numpy arrays are not supported as literal values" ):