diff --git a/py-polars/polars/utils/_construction.py b/py-polars/polars/utils/_construction.py index 43069ed2977c..542da78b69a0 100644 --- a/py-polars/polars/utils/_construction.py +++ b/py-polars/polars/utils/_construction.py @@ -4,8 +4,9 @@ import warnings from datetime import date, datetime, time, timedelta from decimal import Decimal as PyDecimal -from functools import lru_cache, partial, singledispatch +from functools import lru_cache, singledispatch from itertools import islice, zip_longest +from operator import itemgetter from sys import version_info from typing import ( TYPE_CHECKING, @@ -64,7 +65,13 @@ from polars.exceptions import ComputeError, ShapeError, TimeZoneAwareConstructorWarning from polars.utils._wrap import wrap_df, wrap_s from polars.utils.meta import get_index_type, threadpool_size -from polars.utils.various import _is_generator, arrlen, find_stacklevel, range_to_series +from polars.utils.various import ( + _is_generator, + arrlen, + find_stacklevel, + parse_version, + range_to_series, +) with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyDataFrame, PySeries @@ -972,10 +979,10 @@ def _sequence_to_pydf_dispatcher( to_pydf = _sequence_of_pandas_to_pydf elif dataclasses.is_dataclass(first_element): - to_pydf = _dataclasses_or_models_to_pydf + to_pydf = _dataclasses_to_pydf elif is_pydantic_model(first_element): - to_pydf = partial(_dataclasses_or_models_to_pydf, pydantic_model=True) + to_pydf = _pydantic_models_to_pydf else: to_pydf = _sequence_of_elements_to_pydf @@ -1179,72 +1186,131 @@ def _sequence_of_pandas_to_pydf( return PyDataFrame(data_series) -def _dataclasses_or_models_to_pydf( +def _establish_dataclass_or_model_schema( first_element: Any, - data: Sequence[Any], schema: SchemaDefinition | None, schema_overrides: SchemaDict | None, - infer_schema_length: int | None, - **kwargs: Any, -) -> PyDataFrame: - """Initialise DataFrame from python dataclass and/or pydantic model objects.""" - from dataclasses import asdict, astuple + model_fields: list[str] | None, +) -> tuple[bool, list[str], SchemaDict, SchemaDict]: + """Shared utility code for establishing dataclasses/pydantic model cols/schema.""" + from dataclasses import asdict - from_model = kwargs.get("pydantic_model") unpack_nested = False if schema: column_names, schema_overrides = _unpack_schema( schema, schema_overrides=schema_overrides ) - schema_override = { - col: schema_overrides.get(col, Unknown) for col in column_names - } + overrides = {col: schema_overrides.get(col, Unknown) for col in column_names} else: column_names = [] - schema_override = { + overrides = { col: (py_type_to_dtype(tp, raise_unmatched=False) or Unknown) for col, tp in type_hints(first_element.__class__).items() - if col not in ("__slots__", "__pydantic_root_model__") + if ((col in model_fields) if model_fields else (col != "__slots__")) } if schema_overrides: - schema_override.update(schema_overrides) - elif not from_model: + overrides.update(schema_overrides) + elif not model_fields: dc_fields = set(asdict(first_element)) - schema_overrides = schema_override = { - nm: tp for nm, tp in schema_override.items() if nm in dc_fields + schema_overrides = overrides = { + nm: tp for nm, tp in overrides.items() if nm in dc_fields } else: - schema_overrides = schema_override + schema_overrides = overrides - for col, tp in schema_override.items(): + for col, tp in overrides.items(): if tp == Categorical: - schema_override[col] = Utf8 + overrides[col] = Utf8 elif not unpack_nested and (tp.base_type() in (Unknown, Struct)): unpack_nested = contains_nested( getattr(first_element, col, None), - is_pydantic_model if from_model else dataclasses.is_dataclass, # type: ignore[arg-type] + is_pydantic_model if model_fields else dataclasses.is_dataclass, # type: ignore[arg-type] ) + if model_fields and len(model_fields) == len(overrides): + overrides = dict(zip(model_fields, overrides.values())) + + return unpack_nested, column_names, schema_overrides, overrides + + +def _dataclasses_to_pydf( + first_element: Any, + data: Sequence[Any], + schema: SchemaDefinition | None, + schema_overrides: SchemaDict | None, + infer_schema_length: int | None, + **kwargs: Any, +) -> PyDataFrame: + """Initialise DataFrame from python dataclasses.""" + from dataclasses import asdict, astuple + + ( + unpack_nested, + column_names, + schema_overrides, + overrides, + ) = _establish_dataclass_or_model_schema( + first_element, schema, schema_overrides, model_fields=None + ) if unpack_nested: - if from_model: - dicts = ( - [md.model_dump(mode="python") for md in data] - if hasattr(first_element, "model_dump") - else [md.dict() for md in data] - ) - else: - dicts = [asdict(md) for md in data] + dicts = [asdict(md) for md in data] pydf = PyDataFrame.read_dicts(dicts, infer_schema_length) else: - rows = ( - [tuple(md.__dict__.values()) for md in data] - if from_model - else [astuple(dc) for dc in data] + rows = [astuple(dc) for dc in data] + pydf = PyDataFrame.read_rows(rows, infer_schema_length, overrides or None) + + if overrides: + structs = {c: tp for c, tp in overrides.items() if isinstance(tp, Struct)} + pydf = _post_apply_columns(pydf, column_names, structs, schema_overrides) + + return pydf + + +def _pydantic_models_to_pydf( + first_element: Any, + data: Sequence[Any], + schema: SchemaDefinition | None, + schema_overrides: SchemaDict | None, + infer_schema_length: int | None, + **kwargs: Any, +) -> PyDataFrame: + """Initialise DataFrame from pydantic model objects.""" + import pydantic # note: must already be available in the env here + + old_pydantic = parse_version(pydantic.__version__) < parse_version("2.0") + model_fields = list( + first_element.__fields__ if old_pydantic else first_element.model_fields + ) + ( + unpack_nested, + column_names, + schema_overrides, + overrides, + ) = _establish_dataclass_or_model_schema( + first_element, schema, schema_overrides, model_fields + ) + if unpack_nested: + # note: this is an *extremely* slow path, due to the requirement to + # use pydantic's 'dict()' method to properly unpack nested models + dicts = ( + [md.dict() for md in data] + if old_pydantic + else [md.model_dump(mode="python") for md in data] ) - pydf = PyDataFrame.read_rows(rows, infer_schema_length, schema_override or None) + pydf = PyDataFrame.read_dicts(dicts, infer_schema_length) + + elif len(model_fields) > 50: + # 'read_rows' is the faster codepath for models with a lot of fields... + get_values = itemgetter(*model_fields) + rows = [get_values(md.__dict__) for md in data] + pydf = PyDataFrame.read_rows(rows, infer_schema_length, overrides) + else: + # ...and 'read_dicts' is faster otherwise + dicts = [md.__dict__ for md in data] + pydf = PyDataFrame.read_dicts(dicts, infer_schema_length, overrides) - if schema_override: - structs = {c: tp for c, tp in schema_override.items() if isinstance(tp, Struct)} + if overrides: + structs = {c: tp for c, tp in overrides.items() if isinstance(tp, Struct)} pydf = _post_apply_columns(pydf, column_names, structs, schema_overrides) return pydf diff --git a/py-polars/tests/unit/test_constructors.py b/py-polars/tests/unit/test_constructors.py index 2832d5809e0e..512b2eb830e0 100644 --- a/py-polars/tests/unit/test_constructors.py +++ b/py-polars/tests/unit/test_constructors.py @@ -10,7 +10,7 @@ import pandas as pd import pyarrow as pa import pytest -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import BaseModel, Field import polars as pl from polars.dependencies import _ZONEINFO_AVAILABLE, dataclasses, pydantic @@ -267,37 +267,43 @@ class TradeNT(NamedTuple): def test_init_pydantic_2x() -> None: - class PageView(BaseModel): - user_id: str - ts: datetime = Field(alias=["ts", "$date"]) # type: ignore[literal-required, arg-type] - path: str = Field("?", alias=["url", "path"]) # type: ignore[literal-required, arg-type] - referer: str = Field("?", alias="referer") - event: Literal["leave", "enter"] = Field("enter") - time_on_page: int = Field(0, serialization_alias="top") - - data_json = """ - [{ - "user_id": "x", - "ts": {"$date": "2021-01-01T00:00:00.000Z"}, - "url": "/latest/foobar", - "referer": "https://google.com", - "event": "enter", - "top": 123 - }] - """ - adapter: TypeAdapter[Any] = TypeAdapter(List[PageView]) - models = adapter.validate_json(data_json) - - result = pl.DataFrame(models) - - assert result.to_dict(False) == { - "user_id": ["x"], - "ts": [datetime(2021, 1, 1, 0, 0)], - "path": ["?"], - "referer": ["https://google.com"], - "event": ["enter"], - "time_on_page": [0], - } + try: + # don't fail if manually testing with pydantic 1.x + from pydantic import TypeAdapter + + class PageView(BaseModel): + user_id: str + ts: datetime = Field(alias=["ts", "$date"]) # type: ignore[literal-required, arg-type] + path: str = Field("?", alias=["url", "path"]) # type: ignore[literal-required, arg-type] + referer: str = Field("?", alias="referer") + event: Literal["leave", "enter"] = Field("enter") + time_on_page: int = Field(0, serialization_alias="top") + + data_json = """ + [{ + "user_id": "x", + "ts": {"$date": "2021-01-01T00:00:00.000Z"}, + "url": "/latest/foobar", + "referer": "https://google.com", + "event": "enter", + "top": 123 + }] + """ + adapter: TypeAdapter[Any] = TypeAdapter(List[PageView]) + models = adapter.validate_json(data_json) + + result = pl.DataFrame(models) + + assert result.to_dict(False) == { + "user_id": ["x"], + "ts": [datetime(2021, 1, 1, 0, 0)], + "path": ["?"], + "referer": ["https://google.com"], + "event": ["enter"], + "time_on_page": [0], + } + except ImportError: + pass def test_init_structured_objects_unhashable() -> None: