perf(python): faster init from pydantic models with a small number of fields, plus support direct init from SQLModel data (often used with FastAPI)
alexander-beedie committed Sep 23, 2023
1 parent 224eb81 commit 0c99c01
Showing 2 changed files with 144 additions and 72 deletions.
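
As context for the change, here is a minimal sketch of the construction path this commit optimizes; the 'User' model, its fields, and the sample values are illustrative, not taken from the commit:

# minimal sketch of initializing a DataFrame from pydantic models;
# the model/field names here are illustrative, not from the commit
from datetime import date

import polars as pl
from pydantic import BaseModel


class User(BaseModel):
    id: int
    name: str
    signup: date


users = [
    User(id=1, name="alice", signup=date(2023, 1, 1)),
    User(id=2, name="bob", signup=date(2023, 6, 15)),
]

# models with a small number of fields now take the faster 'read_dicts'
# codepath; SQLModel objects are handled the same way, since SQLModel
# classes subclass pydantic's BaseModel
df = pl.DataFrame(users)
print(df.schema)  # {'id': Int64, 'name': Utf8, 'signup': Date}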
146 changes: 106 additions & 40 deletions py-polars/polars/utils/_construction.py
@@ -4,8 +4,9 @@
 import warnings
 from datetime import date, datetime, time, timedelta
 from decimal import Decimal as PyDecimal
-from functools import lru_cache, partial, singledispatch
+from functools import lru_cache, singledispatch
 from itertools import islice, zip_longest
+from operator import itemgetter
 from sys import version_info
 from typing import (
     TYPE_CHECKING,
@@ -64,7 +65,13 @@
 from polars.exceptions import ComputeError, ShapeError, TimeZoneAwareConstructorWarning
 from polars.utils._wrap import wrap_df, wrap_s
 from polars.utils.meta import get_index_type, threadpool_size
-from polars.utils.various import _is_generator, arrlen, find_stacklevel, range_to_series
+from polars.utils.various import (
+    _is_generator,
+    arrlen,
+    find_stacklevel,
+    parse_version,
+    range_to_series,
+)
 
 with contextlib.suppress(ImportError):  # Module not available when building docs
     from polars.polars import PyDataFrame, PySeries
@@ -972,10 +979,10 @@ def _sequence_to_pydf_dispatcher(
         to_pydf = _sequence_of_pandas_to_pydf
 
     elif dataclasses.is_dataclass(first_element):
-        to_pydf = _dataclasses_or_models_to_pydf
+        to_pydf = _dataclasses_to_pydf
 
     elif is_pydantic_model(first_element):
-        to_pydf = partial(_dataclasses_or_models_to_pydf, pydantic_model=True)
+        to_pydf = _pydantic_models_to_pydf
     else:
         to_pydf = _sequence_of_elements_to_pydf
 
@@ -1179,72 +1186,131 @@ def _sequence_of_pandas_to_pydf(
     return PyDataFrame(data_series)
 
 
-def _dataclasses_or_models_to_pydf(
+def _establish_dataclass_or_model_schema(
     first_element: Any,
-    data: Sequence[Any],
     schema: SchemaDefinition | None,
     schema_overrides: SchemaDict | None,
-    infer_schema_length: int | None,
-    **kwargs: Any,
-) -> PyDataFrame:
-    """Initialise DataFrame from python dataclass and/or pydantic model objects."""
-    from dataclasses import asdict, astuple
+    model_fields: list[str] | None,
+) -> tuple[bool, list[str], SchemaDict, SchemaDict]:
+    """Shared utility code for establishing dataclasses/pydantic model cols/schema."""
+    from dataclasses import asdict
 
-    from_model = kwargs.get("pydantic_model")
     unpack_nested = False
     if schema:
         column_names, schema_overrides = _unpack_schema(
             schema, schema_overrides=schema_overrides
         )
-        schema_override = {
-            col: schema_overrides.get(col, Unknown) for col in column_names
-        }
+        overrides = {col: schema_overrides.get(col, Unknown) for col in column_names}
     else:
         column_names = []
-        schema_override = {
+        overrides = {
             col: (py_type_to_dtype(tp, raise_unmatched=False) or Unknown)
             for col, tp in type_hints(first_element.__class__).items()
-            if col not in ("__slots__", "__pydantic_root_model__")
+            if ((col in model_fields) if model_fields else (col != "__slots__"))
         }
     if schema_overrides:
-        schema_override.update(schema_overrides)
-    elif not from_model:
+        overrides.update(schema_overrides)
+    elif not model_fields:
         dc_fields = set(asdict(first_element))
-        schema_overrides = schema_override = {
-            nm: tp for nm, tp in schema_override.items() if nm in dc_fields
+        schema_overrides = overrides = {
+            nm: tp for nm, tp in overrides.items() if nm in dc_fields
         }
     else:
-        schema_overrides = schema_override
+        schema_overrides = overrides
 
-    for col, tp in schema_override.items():
+    for col, tp in overrides.items():
         if tp == Categorical:
-            schema_override[col] = Utf8
+            overrides[col] = Utf8
         elif not unpack_nested and (tp.base_type() in (Unknown, Struct)):
             unpack_nested = contains_nested(
                 getattr(first_element, col, None),
-                is_pydantic_model if from_model else dataclasses.is_dataclass,  # type: ignore[arg-type]
+                is_pydantic_model if model_fields else dataclasses.is_dataclass,  # type: ignore[arg-type]
             )
 
+    if model_fields and len(model_fields) == len(overrides):
+        overrides = dict(zip(model_fields, overrides.values()))
+
+    return unpack_nested, column_names, schema_overrides, overrides
+
+
+def _dataclasses_to_pydf(
+    first_element: Any,
+    data: Sequence[Any],
+    schema: SchemaDefinition | None,
+    schema_overrides: SchemaDict | None,
+    infer_schema_length: int | None,
+    **kwargs: Any,
+) -> PyDataFrame:
+    """Initialise DataFrame from python dataclasses."""
+    from dataclasses import asdict, astuple
+
+    (
+        unpack_nested,
+        column_names,
+        schema_overrides,
+        overrides,
+    ) = _establish_dataclass_or_model_schema(
+        first_element, schema, schema_overrides, model_fields=None
+    )
     if unpack_nested:
-        if from_model:
-            dicts = (
-                [md.model_dump(mode="python") for md in data]
-                if hasattr(first_element, "model_dump")
-                else [md.dict() for md in data]
-            )
-        else:
-            dicts = [asdict(md) for md in data]
+        dicts = [asdict(md) for md in data]
         pydf = PyDataFrame.read_dicts(dicts, infer_schema_length)
     else:
-        rows = (
-            [tuple(md.__dict__.values()) for md in data]
-            if from_model
-            else [astuple(dc) for dc in data]
-        )
-        pydf = PyDataFrame.read_rows(rows, infer_schema_length, schema_override or None)
+        rows = [astuple(dc) for dc in data]
+        pydf = PyDataFrame.read_rows(rows, infer_schema_length, overrides or None)
+
+    if overrides:
+        structs = {c: tp for c, tp in overrides.items() if isinstance(tp, Struct)}
+        pydf = _post_apply_columns(pydf, column_names, structs, schema_overrides)
+
+    return pydf
+
+
+def _pydantic_models_to_pydf(
+    first_element: Any,
+    data: Sequence[Any],
+    schema: SchemaDefinition | None,
+    schema_overrides: SchemaDict | None,
+    infer_schema_length: int | None,
+    **kwargs: Any,
+) -> PyDataFrame:
+    """Initialise DataFrame from pydantic model objects."""
+    import pydantic  # note: must already be available in the env here
+
+    old_pydantic = parse_version(pydantic.__version__) < parse_version("2.0")
+    model_fields = list(
+        first_element.__fields__ if old_pydantic else first_element.model_fields
+    )
+    (
+        unpack_nested,
+        column_names,
+        schema_overrides,
+        overrides,
+    ) = _establish_dataclass_or_model_schema(
+        first_element, schema, schema_overrides, model_fields
+    )
+    if unpack_nested:
+        # note: this is an *extremely* slow path, due to the requirement to
+        # use pydantic's 'dict()' method to properly unpack nested models
+        dicts = (
+            [md.dict() for md in data]
+            if old_pydantic
+            else [md.model_dump(mode="python") for md in data]
+        )
+        pydf = PyDataFrame.read_dicts(dicts, infer_schema_length)
+
+    elif len(model_fields) > 50:
+        # 'read_rows' is the faster codepath for models with a lot of fields...
+        get_values = itemgetter(*model_fields)
+        rows = [get_values(md.__dict__) for md in data]
+        pydf = PyDataFrame.read_rows(rows, infer_schema_length, overrides)
+    else:
+        # ...and 'read_dicts' is faster otherwise
+        dicts = [md.__dict__ for md in data]
+        pydf = PyDataFrame.read_dicts(dicts, infer_schema_length, overrides)
 
-    if schema_override:
-        structs = {c: tp for c, tp in schema_override.items() if isinstance(tp, Struct)}
+    if overrides:
+        structs = {c: tp for c, tp in overrides.items() if isinstance(tp, Struct)}
         pydf = _post_apply_columns(pydf, column_names, structs, schema_overrides)
 
     return pydf
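The fast path for wide models above relies on operator.itemgetter to pull every field value out of a model's __dict__ in a single call. Here is a standalone sketch of that extraction technique, assuming pydantic 2.x; the dynamically built 'Wide' model is illustrative, not part of the commit:

# standalone sketch of the itemgetter row extraction used for models with
# many fields; assumes pydantic 2.x, and the 'Wide' model is illustrative
from operator import itemgetter

from pydantic import create_model

# dynamically build a model with 60 required int fields: f0, f1, ..., f59
Wide = create_model("Wide", **{f"f{i}": (int, ...) for i in range(60)})

model_fields = list(Wide.model_fields)  # field names, in declaration order
get_values = itemgetter(*model_fields)

data = [Wide(**{f"f{i}": n * i for i in range(60)}) for n in range(3)]

# one itemgetter call per model yields the whole row as a tuple, instead of
# a python-level dict lookup per field
rows = [get_values(md.__dict__) for md in data]
assert rows[1][:3] == (0, 1, 2)

Crossing the field-count threshold (more than 50 fields in the diff) is what flips construction from the 'read_dicts' codepath to this 'read_rows' path.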
70 changes: 38 additions & 32 deletions py-polars/tests/unit/test_constructors.py
@@ -10,7 +10,7 @@
 import pandas as pd
 import pyarrow as pa
 import pytest
-from pydantic import BaseModel, Field, TypeAdapter
+from pydantic import BaseModel, Field
 
 import polars as pl
 from polars.dependencies import _ZONEINFO_AVAILABLE, dataclasses, pydantic
@@ -267,37 +267,43 @@ class TradeNT(NamedTuple):
 
 
 def test_init_pydantic_2x() -> None:
-    class PageView(BaseModel):
-        user_id: str
-        ts: datetime = Field(alias=["ts", "$date"])  # type: ignore[literal-required, arg-type]
-        path: str = Field("?", alias=["url", "path"])  # type: ignore[literal-required, arg-type]
-        referer: str = Field("?", alias="referer")
-        event: Literal["leave", "enter"] = Field("enter")
-        time_on_page: int = Field(0, serialization_alias="top")
-
-    data_json = """
-    [{
-        "user_id": "x",
-        "ts": {"$date": "2021-01-01T00:00:00.000Z"},
-        "url": "/latest/foobar",
-        "referer": "https://google.com",
-        "event": "enter",
-        "top": 123
-    }]
-    """
-    adapter: TypeAdapter[Any] = TypeAdapter(List[PageView])
-    models = adapter.validate_json(data_json)
-
-    result = pl.DataFrame(models)
-
-    assert result.to_dict(False) == {
-        "user_id": ["x"],
-        "ts": [datetime(2021, 1, 1, 0, 0)],
-        "path": ["?"],
-        "referer": ["https://google.com"],
-        "event": ["enter"],
-        "time_on_page": [0],
-    }
+    try:
+        # don't fail if manually testing with pydantic 1.x
+        from pydantic import TypeAdapter
+
+        class PageView(BaseModel):
+            user_id: str
+            ts: datetime = Field(alias=["ts", "$date"])  # type: ignore[literal-required, arg-type]
+            path: str = Field("?", alias=["url", "path"])  # type: ignore[literal-required, arg-type]
+            referer: str = Field("?", alias="referer")
+            event: Literal["leave", "enter"] = Field("enter")
+            time_on_page: int = Field(0, serialization_alias="top")
+
+        data_json = """
+        [{
+            "user_id": "x",
+            "ts": {"$date": "2021-01-01T00:00:00.000Z"},
+            "url": "/latest/foobar",
+            "referer": "https://google.com",
+            "event": "enter",
+            "top": 123
+        }]
+        """
+        adapter: TypeAdapter[Any] = TypeAdapter(List[PageView])
+        models = adapter.validate_json(data_json)
+
+        result = pl.DataFrame(models)
+
+        assert result.to_dict(False) == {
+            "user_id": ["x"],
+            "ts": [datetime(2021, 1, 1, 0, 0)],
+            "path": ["?"],
+            "referer": ["https://google.com"],
+            "event": ["enter"],
+            "time_on_page": [0],
+        }
+    except ImportError:
+        pass
 
 
 def test_init_structured_objects_unhashable() -> None:
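The pydantic 1.x/2.x split handled in _pydantic_models_to_pydf (and tolerated by the reworked test above) reduces to a single version probe. A sketch of that compatibility shim, using only the attributes shown in the diff; the helper function itself is hypothetical, not part of the commit:

# hypothetical helper mirroring the version probe in _pydantic_models_to_pydf:
# pydantic 1.x exposes field metadata as '__fields__', 2.x as 'model_fields'
from typing import Any

import pydantic

from polars.utils.various import parse_version


def model_field_names(model: Any) -> list[str]:
    """Return a pydantic model's field names, in declaration order."""
    old_pydantic = parse_version(pydantic.__version__) < parse_version("2.0")
    return list(model.__fields__ if old_pydantic else model.model_fields)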
