Skip to content

Commit

Permalink
fix(python): DataFrame init from collections.namedtuple values
Browse files: browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Sep 25, 2023
1 parent 26ba7ae commit 7617fd2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
7 changes: 4 additions & 3 deletions py-polars/polars/utils/_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def sequence_to_pyseries(
if (
dataclasses.is_dataclass(value)
or is_pydantic_model(value)
or is_namedtuple(value.__class__, annotated=True)
or is_namedtuple(value.__class__)
):
return pl.DataFrame(values).to_struct(name)._s
elif isinstance(value, range):
Expand Down Expand Up @@ -1080,12 +1080,13 @@ def _sequence_of_tuple_to_pydf(
if is_namedtuple(first_element.__class__):
if schema is None:
schema = first_element._fields # type: ignore[attr-defined]
if len(first_element.__annotations__) == len(schema):
annotations = getattr(first_element, "__annotations__", None)
if annotations and len(annotations) == len(schema):
schema = [
(name, py_type_to_dtype(tp, raise_unmatched=False))
for name, tp in first_element.__annotations__.items()
]
elif orient is None:
if orient is None:
orient = "row"

# ...then defer to generic sequence processing
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/unit/test_constructors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import sys
from collections import namedtuple
from datetime import date, datetime, timedelta, timezone
from decimal import Decimal
from random import shuffle
Expand Down Expand Up @@ -462,6 +463,24 @@ class ABC:
assert dataclasses.asdict(abc) == df.rows(named=True)[0]


def test_collections_namedtuple() -> None:
    """Check DataFrame init from plain ``collections.namedtuple`` values.

    The namedtuple here is built with ``collections.namedtuple`` (field names
    only, no type annotations), as opposed to an annotated
    ``typing.NamedTuple``. Two construction paths are exercised:

    * a sequence of namedtuples passed directly as the frame data;
    * a sequence of namedtuples used as one value of a dict of columns.
    """
    TestData = namedtuple("TestData", ["id", "info"])
    nt_data = [TestData(1, "a"), TestData(2, "b"), TestData(3, "c")]

    # Passed directly: each namedtuple is one row, the field names are the
    # column names.
    df1 = pl.DataFrame(nt_data)
    # `as_series` is passed by keyword so the intent of the flag is visible
    # at the call site and the call does not rely on its position.
    assert df1.to_dict(as_series=False) == {
        "id": [1, 2, 3],
        "info": ["a", "b", "c"],
    }

    # Passed as a dict value: each namedtuple comes back as a dict keyed by
    # its field names, alongside the ordinary "misc" column.
    df2 = pl.DataFrame({"data": nt_data, "misc": ["x", "y", "z"]})
    assert df2.to_dict(as_series=False) == {
        "data": [
            {"id": 1, "info": "a"},
            {"id": 2, "info": "b"},
            {"id": 3, "info": "c"},
        ],
        "misc": ["x", "y", "z"],
    }


def test_init_ndarray(monkeypatch: Any) -> None:
# Empty array
df = pl.DataFrame(np.array([]))
Expand Down

0 comments on commit 7617fd2

Please sign in to comment.