Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Pydantic v2 compat #2888

Closed
Closed
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
ee45b97
work so far
thejaminator May 20, 2023
4c13e84
separate
thejaminator May 20, 2023
95f979a
remove field map
thejaminator May 20, 2023
ddf7ade
yay test pass
thejaminator May 20, 2023
9734494
fix fields
thejaminator May 20, 2023
2b2c1a8
fix passing fields
thejaminator May 20, 2023
df87fb7
it worksgit stage .
thejaminator May 20, 2023
017e168
revert
thejaminator May 20, 2023
2103eab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 25, 2023
064db83
remove compat file
thejaminator Jun 25, 2023
274f2b6
fix pydantic v2 update issues
thejaminator Jul 7, 2023
ae295f5
Merge remote-tracking branch 'origin/main' into pydantic-v2-compat
thejaminator Jul 7, 2023
f582d41
add test for #2782
thejaminator Jul 7, 2023
045856a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 7, 2023
590166d
Merge branch 'main' into pydantic-v2-compat
patrick91 Jul 8, 2023
176dcbf
Merge remote-tracking branch 'origin/main' into pydantic-v2-compat
thejaminator Jul 11, 2023
5139f7a
mark pydantic v2 explicitly
thejaminator Jul 11, 2023
0e0d1a5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 11, 2023
8fa12dc
add pytest markers for pydantic_v2
thejaminator Jul 11, 2023
c7a9b8c
try again
thejaminator Jul 11, 2023
6be92be
add ignore pydantic2
thejaminator Jul 12, 2023
fecd02d
add explicit dir
thejaminator Jul 12, 2023
883b8a4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2023
e6fd48b
remove markers
thejaminator Jul 12, 2023
858998c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2023
ae9eff1
add ignore for test
thejaminator Jul 12, 2023
42203c6
add hints
thejaminator Jul 13, 2023
3d97786
fix weird cli tests changes
thejaminator Jul 13, 2023
227736d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions strawberry/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
try:
from . import pydantic

__all__ = ["pydantic"]
except ImportError:
pass
else:
__all__ = ["pydantic"]
try:
from . import pydantic2

# Support for pydantic2 is highly experimental and the interface will change
# We don't recommend using it yet
__all__ = ["pydantic2"]
except ImportError as e:
pass
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

explicitly a separate module

11 changes: 11 additions & 0 deletions strawberry/experimental/pydantic2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .error_type import error_type
from .exceptions import UnregisteredTypeException
from .object_type import input, interface, type

__all__ = [
"error_type",
"UnregisteredTypeException",
"input",
"type",
"interface",
]
113 changes: 113 additions & 0 deletions strawberry/experimental/pydantic2/conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

import copy
import dataclasses
from typing import TYPE_CHECKING, Any, Type, Union, cast

from strawberry.enum import EnumDefinition
from strawberry.type import StrawberryList, StrawberryOptional
from strawberry.union import StrawberryUnion

if TYPE_CHECKING:
from strawberry.type import StrawberryType


def _convert_from_pydantic_to_strawberry_type(
type_: Union[StrawberryType, type], data_from_model=None, extra=None # noqa: ANN001
):
data = data_from_model if data_from_model is not None else extra

if isinstance(type_, StrawberryOptional):
if data is None:
return data
return _convert_from_pydantic_to_strawberry_type(
type_.of_type, data_from_model=data, extra=extra
)
if isinstance(type_, StrawberryUnion):
for option_type in type_.types:
if hasattr(option_type, "_pydantic_type"):
source_type = option_type._pydantic_type
else:
source_type = cast(type, option_type)
if isinstance(data, source_type):
return _convert_from_pydantic_to_strawberry_type(
option_type, data_from_model=data, extra=extra
)
if isinstance(type_, EnumDefinition):
return data
if isinstance(type_, StrawberryList):
items = []
for index, item in enumerate(data):
items.append(
_convert_from_pydantic_to_strawberry_type(
type_.of_type,
data_from_model=item,
extra=extra[index] if extra else None,
)
)

return items

if hasattr(type_, "_type_definition"):
# in the case of an interface, the concrete type may be more specific
# than the type in the field definition
# don't check _strawberry_input_type because inputs can't be interfaces
if hasattr(type(data), "_strawberry_type"):
type_ = type(data)._strawberry_type
if hasattr(type_, "from_pydantic"):
return type_.from_pydantic(data_from_model, extra)
return convert_pydantic_model_to_strawberry_class(
type_, model_instance=data_from_model, extra=extra
)

return data


def convert_pydantic_model_to_strawberry_class(
cls, *, model_instance=None, extra=None # noqa: ANN001
) -> Any:
extra = extra or {}
kwargs = {}

for field_ in cls._type_definition.fields:
field = cast("StrawberryField", field_)
python_name = field.python_name

data_from_extra = extra.get(python_name, None)
data_from_model = (
getattr(model_instance, python_name, None) if model_instance else None
)

# only convert and add fields to kwargs if they are present in the `__init__`
# method of the class
if field.init:
kwargs[python_name] = _convert_from_pydantic_to_strawberry_type(
field.type, data_from_model, extra=data_from_extra
)

return cls(**kwargs)


def convert_strawberry_class_to_pydantic_model(obj: Type) -> Any:
if hasattr(obj, "to_pydantic"):
return obj.to_pydantic()
elif dataclasses.is_dataclass(obj):
result = []
for f in dataclasses.fields(obj):
value = convert_strawberry_class_to_pydantic_model(getattr(obj, f.name))
result.append((f.name, value))
return dict(result)
elif isinstance(obj, (list, tuple)):
# Assume we can create an object of this type by passing in a
# generator (which is not true for namedtuples, not supported).
return type(obj)(convert_strawberry_class_to_pydantic_model(v) for v in obj)
elif isinstance(obj, dict):
return type(obj)(
(
convert_strawberry_class_to_pydantic_model(k),
convert_strawberry_class_to_pydantic_model(v),
)
for k, v in obj.items()
)
else:
return copy.deepcopy(obj)
37 changes: 37 additions & 0 deletions strawberry/experimental/pydantic2/conversion_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar
from typing_extensions import Protocol

from pydantic import BaseModel

if TYPE_CHECKING:
from strawberry.types.types import TypeDefinition


PydanticModel = TypeVar("PydanticModel", bound=BaseModel)


class StrawberryTypeFromPydantic(Protocol[PydanticModel]):
"""This class does not exist in runtime.
It only makes the methods below visible for IDEs"""

def __init__(self, **kwargs):
...

@staticmethod
def from_pydantic(
instance: PydanticModel, extra: Optional[Dict[str, Any]] = None
) -> StrawberryTypeFromPydantic[PydanticModel]:
...

def to_pydantic(self, **kwargs) -> PydanticModel:
...

@property
def _type_definition(self) -> TypeDefinition:
...

@property
def _pydantic_type(self) -> Type[PydanticModel]:
...
149 changes: 149 additions & 0 deletions strawberry/experimental/pydantic2/error_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from __future__ import annotations

import dataclasses
import warnings
from typing import (
TYPE_CHECKING,
Any,
Callable,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)

from pydantic import BaseModel
from pydantic._internal._utils import lenient_issubclass

from strawberry.auto import StrawberryAuto
from strawberry.experimental.pydantic2.utils import (
get_private_fields,
get_strawberry_type_from_model,
normalize_type,
)
from strawberry.object_type import _process_type, _wrap_dataclass
from strawberry.types.type_resolver import _get_fields
from strawberry.utils.typing import get_list_annotation, is_list

from .exceptions import MissingFieldsListError

if TYPE_CHECKING:
from pydantic.fields import ModelField


def get_type_for_field(field: ModelField) -> Union[Any, Type[None], Type[List]]:
type_ = field.outer_type_
type_ = normalize_type(type_)
return field_type_to_type(type_)


def field_type_to_type(type_: Type) -> Union[Any, List[Any], None]:
error_class: Any = str
strawberry_type: Any = error_class

if is_list(type_):
child_type = get_list_annotation(type_)

if is_list(child_type):
strawberry_type = field_type_to_type(child_type)
elif lenient_issubclass(child_type, BaseModel):
strawberry_type = get_strawberry_type_from_model(child_type)
else:
strawberry_type = List[error_class]

strawberry_type = Optional[strawberry_type]
elif lenient_issubclass(type_, BaseModel):
strawberry_type = get_strawberry_type_from_model(type_)
return Optional[strawberry_type]

return Optional[List[strawberry_type]]


def error_type(
model: Type[BaseModel],
*,
fields: Optional[List[str]] = None,
name: Optional[str] = None,
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
all_fields: bool = False,
) -> Callable[..., Type]:
def wrap(cls: Type) -> Type:
model_fields = model.__fields__
fields_set = set(fields) if fields else set()

if fields:
warnings.warn(
"`fields` is deprecated, use `auto` type annotations instead",
DeprecationWarning,
stacklevel=2,
)

existing_fields = getattr(cls, "__annotations__", {})
fields_set = fields_set.union(
{
name
for name, type_ in existing_fields.items()
if isinstance(type_, StrawberryAuto)
}
)

if all_fields:
if fields_set:
warnings.warn(
"Using all_fields overrides any explicitly defined fields "
"in the model, using both is likely a bug",
stacklevel=2,
)
fields_set = set(model_fields.keys())

if not fields_set:
raise MissingFieldsListError(cls)

all_model_fields: List[Tuple[str, Any, dataclasses.Field]] = [
(
name,
get_type_for_field(field),
dataclasses.field(default=None), # type: ignore[arg-type]
)
for name, field in model_fields.items()
if name in fields_set
]

wrapped = _wrap_dataclass(cls)
extra_fields = cast(List[dataclasses.Field], _get_fields(wrapped))
private_fields = get_private_fields(wrapped)

all_model_fields.extend(
(
field.name,
field.type,
field,
)
for field in extra_fields + private_fields
if not isinstance(field.type, StrawberryAuto)
)

cls = dataclasses.make_dataclass(
cls.__name__,
all_model_fields,
bases=cls.__bases__,
)

_process_type(
cls,
name=name,
is_input=False,
is_interface=False,
description=description,
directives=directives,
)

model._strawberry_type = cls # type: ignore[attr-defined]
cls._pydantic_type = model
return cls

return wrap
50 changes: 50 additions & 0 deletions strawberry/experimental/pydantic2/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, List, Type

if TYPE_CHECKING:
from pydantic import BaseModel
from pydantic.typing import NoArgAnyCallable


class MissingFieldsListError(Exception):
def __init__(self, type: Type[BaseModel]):
message = (
f"List of fields to copy from {type} is empty. Add fields with the "
f"`auto` type annotation"
)

super().__init__(message)


class UnsupportedTypeError(Exception):
pass


class UnregisteredTypeException(Exception):
def __init__(self, type: Type[BaseModel]):
message = (
f"Cannot find a Strawberry Type for {type} did you forget to register it?"
)

super().__init__(message)


class BothDefaultAndDefaultFactoryDefinedError(Exception):
def __init__(self, default: Any, default_factory: NoArgAnyCallable):
message = (
f"Not allowed to specify both default and default_factory. "
f"default:{default} default_factory:{default_factory}"
)

super().__init__(message)


class AutoFieldsNotInBaseModelError(Exception):
def __init__(self, fields: List[str], cls_name: str, model: Type[BaseModel]):
message = (
f"{cls_name} defines {fields} with strawberry.auto. "
f"Field(s) not present in {model.__name__} BaseModel."
)

super().__init__(message)
Loading