Skip to content

Commit

Permalink
feat: pyarrow lit
Browse files: browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Jul 14, 2024
1 parent 1a4a290 commit b16ab5e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 11 deletions.
23 changes: 21 additions & 2 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from narwhals import dtypes
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.series import ArrowSeries
from narwhals._expression_parsing import parse_into_exprs
from narwhals.dependencies import get_pyarrow
from narwhals.utils import flatten
Expand All @@ -15,8 +16,6 @@
from typing import Callable

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import IntoArrowExpr


Expand Down Expand Up @@ -126,6 +125,26 @@ def all(self) -> ArrowExpr:
backend_version=self._backend_version,
)

def lit(self, value: Any, dtype: dtypes.DType | None) -> ArrowExpr:
    """Build an expression that yields a constant column named ``"lit"``.

    Args:
        value: The literal to repeat; it is replicated ``len(df)`` times for
            the frame the expression is evaluated against.
        dtype: Optional narwhals dtype. When provided, the generated series
            is cast to it; when ``None``, the series is returned as built by
            ``ArrowSeries._from_iterable``.

    Returns:
        An ``ArrowExpr`` whose evaluation produces a one-element list holding
        the literal ``ArrowSeries``.
    """

    def _lit_arrow_series(df: ArrowDataFrame) -> ArrowSeries:
        # The series is built only when the expression is evaluated, because
        # its length has to match the frame it is applied to.
        arrow_series = ArrowSeries._from_iterable(
            data=[value] * len(df),
            name="lit",
            backend_version=self._backend_version,
        )
        # Explicit None check: the decision to cast should depend on whether a
        # dtype was supplied, not on the truthiness of the dtype object.
        if dtype is not None:
            return arrow_series.cast(dtype)
        return arrow_series

    return ArrowExpr(
        lambda df: [_lit_arrow_series(df)],
        depth=0,
        function_name="lit",
        root_names=None,
        output_names=["lit"],
        backend_version=self._backend_version,
    )

def all_horizontal(self, *exprs: IntoArrowExpr) -> ArrowExpr:
return reduce(
lambda x, y: x & y,
Expand Down
4 changes: 3 additions & 1 deletion narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Iterable
from typing import Sequence

from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.utils import reverse_translate_dtype
from narwhals._arrow.utils import translate_dtype
from narwhals._arrow.utils import validate_column_comparand
Expand All @@ -15,6 +14,7 @@
if TYPE_CHECKING:
from typing_extensions import Self

from narwhals._arrow.namespace import ArrowNamespace
from narwhals.dtypes import DType


Expand Down Expand Up @@ -164,6 +164,8 @@ def count(self) -> int:
return pc.count(self._native_series) # type: ignore[no-any-return]

def __narwhals_namespace__(self) -> ArrowNamespace:
    """Return an ``ArrowNamespace`` sharing this series' backend version."""
    # Deferred import: the namespace module imports this module at load time,
    # so importing it at module level here would be circular.
    from narwhals._arrow.namespace import ArrowNamespace

    namespace = ArrowNamespace(backend_version=self._backend_version)
    return namespace

@property
Expand Down
15 changes: 7 additions & 8 deletions tests/frame/lit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,25 @@
[(None, [2, 2, 2]), (nw.String, ["2", "2", "2"]), (nw.Float32, [2.0, 2.0, 2.0])],
)
def test_lit(
constructor_with_lazy: Any, dtype: DType | None, expected_lit: list[Any]
constructor_with_pyarrow: Any, dtype: DType | None, expected_lit: list[Any]
) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_raw = constructor_with_lazy(data)
df = nw.from_native(df_raw)
df_raw = constructor_with_pyarrow(data)
df = nw.from_native(df_raw).lazy()
result = df.with_columns(nw.lit(2, dtype).alias("lit"))
result_native = nw.to_native(result)
expected = {
"a": [1, 3, 2],
"b": [4, 4, 6],
"z": [7.0, 8.0, 9.0],
"lit": expected_lit,
}
compare_dicts(result_native, expected)
compare_dicts(result, expected)


def test_lit_error(constructor_with_lazy: Any) -> None:
def test_lit_error(constructor_with_pyarrow: Any) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_raw = constructor_with_lazy(data)
df = nw.from_native(df_raw)
df_raw = constructor_with_pyarrow(data)
df = nw.from_native(df_raw).lazy()
with pytest.raises(
ValueError, match="numpy arrays are not supported as literal values"
):
Expand Down

0 comments on commit b16ab5e

Please sign in to comment.