Skip to content

Commit

Permalink
fix(python): improve ingest from numpy scalar values (pola-rs#12025)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Oct 26, 2023
1 parent c8074b6 commit b4b68d9
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 11 deletions.
1 change: 1 addition & 0 deletions py-polars/polars/datatypes/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def DTYPE_TO_PY_TYPE(self) -> dict[PolarsDataType, PythonDataType]:
def NUMPY_KIND_AND_ITEMSIZE_TO_DTYPE(self) -> dict[tuple[str, int], PolarsDataType]:
return {
# (np.dtype().kind, np.dtype().itemsize)
("b", 1): Boolean,
("i", 1): Int8,
("i", 2): Int16,
("i", 4): Int32,
Expand Down
24 changes: 16 additions & 8 deletions py-polars/polars/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,20 +207,28 @@ def _might_be(cls: type, type_: str) -> bool:
return False


def _check_for_numpy(obj: Any) -> bool:
return _NUMPY_AVAILABLE and _might_be(cast(Hashable, type(obj)), "numpy")
def _check_for_numpy(obj: Any, *, check_type: bool = True) -> bool:
return _NUMPY_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "numpy"
)


def _check_for_pandas(obj: Any) -> bool:
return _PANDAS_AVAILABLE and _might_be(cast(Hashable, type(obj)), "pandas")
def _check_for_pandas(obj: Any, *, check_type: bool = True) -> bool:
return _PANDAS_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "pandas"
)


def _check_for_pyarrow(obj: Any) -> bool:
return _PYARROW_AVAILABLE and _might_be(cast(Hashable, type(obj)), "pyarrow")
def _check_for_pyarrow(obj: Any, *, check_type: bool = True) -> bool:
return _PYARROW_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "pyarrow"
)


def _check_for_pydantic(obj: Any) -> bool:
return _PYDANTIC_AVAILABLE and _might_be(cast(Hashable, type(obj)), "pydantic")
def _check_for_pydantic(obj: Any, *, check_type: bool = True) -> bool:
return _PYDANTIC_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "pydantic"
)


__all__ = [
Expand Down
13 changes: 11 additions & 2 deletions py-polars/polars/utils/_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
Utf8,
dtype_to_py_type,
is_polars_dtype,
numpy_char_code_to_dtype,
py_type_to_dtype,
)
from polars.datatypes.constructor import (
Expand Down Expand Up @@ -554,9 +555,17 @@ def sequence_to_pyseries(
constructor = py_type_to_constructor(python_dtype)
if constructor == PySeries.new_object:
try:
return PySeries.new_from_anyvalues(name, values, strict)
# raised if we cannot convert to Wrap<AnyValue>
srs = PySeries.new_from_anyvalues(name, values, strict)
if _check_for_numpy(python_dtype, check_type=False) and isinstance(
np.bool_(True), np.generic
):
dtype = numpy_char_code_to_dtype(np.dtype(python_dtype).char)
return srs.cast(dtype, strict=strict)
else:
return srs

except RuntimeError:
# raised if we cannot convert to Wrap<AnyValue>
return sequence_from_anyvalue_or_object(name, values)

return _construct_series_with_fallbacks(
Expand Down
17 changes: 16 additions & 1 deletion py-polars/tests/unit/test_constructors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import sys
from collections import namedtuple
from collections import OrderedDict, namedtuple
from datetime import date, datetime, timedelta, timezone
from decimal import Decimal
from random import shuffle
Expand Down Expand Up @@ -611,6 +611,21 @@ def test_init_ndarray(monkeypatch: Any) -> None:
assert df2.rows() == [(1.0, 4.0), (2.5, None), (None, 6.5)]


def test_init_numpy_scalars() -> None:
df = pl.DataFrame(
{
"bool": [np.bool_(True), np.bool_(False)],
"i8": [np.int8(16), np.int8(64)],
"u32": [np.uint32(1234), np.uint32(9876)],
}
)
df_expected = pl.from_records(
data=[(True, 16, 1234), (False, 64, 9876)],
schema=OrderedDict([("bool", pl.Boolean), ("i8", pl.Int8), ("u32", pl.UInt32)]),
)
assert_frame_equal(df, df_expected)


def test_null_array_print_format() -> None:
pa_tbl_null = pa.table({"a": [None, None]})
df_null = pl.from_arrow(pa_tbl_null)
Expand Down

0 comments on commit b4b68d9

Please sign in to comment.