Skip to content

Commit

Permalink
feat(python): Automatically wrap NumPy array as lit (pola-rs#12709)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Nov 27, 2023
1 parent 8aebb03 commit 941e9c1
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 2 deletions.
2 changes: 1 addition & 1 deletion py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2380,7 +2380,7 @@ def gather(
):
indices_lit = F.lit(pl.Series("", indices, dtype=UInt32))._pyexpr
else:
indices_lit = parse_as_expression(indices) # type: ignore[arg-type]
indices_lit = parse_as_expression(indices)
return self._from_pyexpr(self._pyexpr.gather(indices_lit))

def get(self, index: int | Expr) -> Self:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/type_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
NumericLiteral, TemporalLiteral, str, bool, bytes, List[Any]
]
# Inputs that can convert into a `col` expression
IntoExprColumn: TypeAlias = Union["Expr", "Series", str]
IntoExprColumn: TypeAlias = Union["Expr", "Series", str, "np.ndarray"]
# Inputs that can convert into an expression
IntoExpr: TypeAlias = Union[PythonLiteral, IntoExprColumn, None]

Expand Down
6 changes: 6 additions & 0 deletions py-polars/polars/utils/_parse_expr_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from polars.polars import PyExpr
from polars.type_aliases import IntoExpr

from polars.dependencies import _check_for_numpy
from polars.dependencies import numpy as np


def parse_as_list_of_expressions(
*inputs: IntoExpr | Iterable[IntoExpr],
Expand Down Expand Up @@ -116,6 +119,9 @@ def parse_as_expression(
elif isinstance(input, (list, tuple)):
expr = F.lit(pl.Series("literal", [input]))
structify = False
elif _check_for_numpy(input) and isinstance(input, np.ndarray):
expr = F.lit(pl.Series("literal", input))
structify = False
else:
raise TypeError(
f"did not expect value {input!r} of type {type(input).__name__!r}"
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/interop/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,11 @@ def test_view_deprecated() -> None:
result = s.view()
assert isinstance(result, np.ndarray)
assert np.all(result == np.array([1.0, 2.5, 3.0]))


def test_numpy_disambiguation() -> None:
a = np.array([1, 2])
assert pl.DataFrame({"a": a}).with_columns(b=a).to_dict(as_series=False) == {
"a": [1, 2],
"b": [1, 2],
}

0 comments on commit 941e9c1

Please sign in to comment.