From ee45b971f98140e50d0690aff224b343b44501b1 Mon Sep 17 00:00:00 2001 From: jameschua Date: Sat, 20 May 2023 17:46:16 +0100 Subject: [PATCH 01/26] work so far --- .../experimental/pydantic/conversion.py | 1 - .../experimental/pydantic/error_type.py | 20 +++++++++++-------- strawberry/experimental/pydantic/fields.py | 6 ++---- strawberry/experimental/pydantic/utils.py | 5 ++--- strawberry/experimental/pydantic/v2_compat.py | 15 ++++++++++++++ 5 files changed, 31 insertions(+), 16 deletions(-) create mode 100644 strawberry/experimental/pydantic/v2_compat.py diff --git a/strawberry/experimental/pydantic/conversion.py b/strawberry/experimental/pydantic/conversion.py index bc0f787948..e138e1e807 100644 --- a/strawberry/experimental/pydantic/conversion.py +++ b/strawberry/experimental/pydantic/conversion.py @@ -9,7 +9,6 @@ from strawberry.union import StrawberryUnion if TYPE_CHECKING: - from strawberry.field import StrawberryField from strawberry.type import StrawberryType diff --git a/strawberry/experimental/pydantic/error_type.py b/strawberry/experimental/pydantic/error_type.py index adcdd5cdf4..5757517acb 100644 --- a/strawberry/experimental/pydantic/error_type.py +++ b/strawberry/experimental/pydantic/error_type.py @@ -16,7 +16,6 @@ ) from pydantic import BaseModel -from pydantic.utils import lenient_issubclass from strawberry.auto import StrawberryAuto from strawberry.experimental.pydantic.utils import ( @@ -24,10 +23,13 @@ get_strawberry_type_from_model, normalize_type, ) +from strawberry.experimental.pydantic.v2_compat import lenient_issubclass from strawberry.object_type import _process_type, _wrap_dataclass from strawberry.types.type_resolver import _get_fields from strawberry.utils.typing import get_list_annotation, is_list + + from .exceptions import MissingFieldsListError if TYPE_CHECKING: @@ -40,6 +42,8 @@ def get_type_for_field(field: ModelField) -> Union[Any, Type[None], Type[List]]: return field_type_to_type(type_) + + def field_type_to_type(type_: Type) -> Union[Any, List[Any], None]: error_class: Any = str strawberry_type: Any = error_class @@ -63,13 +67,13 @@ def field_type_to_type(type_: Type) -> Union[Any, List[Any], None]: def error_type( - model: Type[BaseModel], - *, - fields: Optional[List[str]] = None, - name: Optional[str] = None, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), - all_fields: bool = False, + model: Type[BaseModel], + *, + fields: Optional[List[str]] = None, + name: Optional[str] = None, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, ) -> Callable[..., Type]: def wrap(cls: Type) -> Type: model_fields = model.__fields__ diff --git a/strawberry/experimental/pydantic/fields.py b/strawberry/experimental/pydantic/fields.py index cfa3a6be2c..24d2b60b36 100644 --- a/strawberry/experimental/pydantic/fields.py +++ b/strawberry/experimental/pydantic/fields.py @@ -5,9 +5,8 @@ import pydantic from pydantic import BaseModel -from pydantic.typing import get_args, get_origin, is_new_type, new_type_supertype -from pydantic.utils import lenient_issubclass - +from strawberry.experimental.pydantic.v2_compat import lenient_issubclass, get_args, get_origin, is_new_type, \ + new_type_supertype from strawberry.experimental.pydantic.exceptions import ( UnregisteredTypeException, UnsupportedTypeError, @@ -70,7 +69,6 @@ "RedisDsn": str, } - FIELDS_MAP = { getattr(pydantic, field_name): type for field_name, type in ATTR_TO_TYPE_MAP.items() diff --git a/strawberry/experimental/pydantic/utils.py b/strawberry/experimental/pydantic/utils.py index e863a47e16..0368fd84f9 100644 --- a/strawberry/experimental/pydantic/utils.py +++ b/strawberry/experimental/pydantic/utils.py @@ -14,7 +14,6 @@ cast, ) -from pydantic.utils import smart_deepcopy from strawberry.experimental.pydantic.exceptions import ( AutoFieldsNotInBaseModelError, @@ -76,7 +75,7 @@ def to_tuple(self) -> Tuple[str, Type, dataclasses.Field]: def get_default_factory_for_field( - field: ModelField, + field: ModelField, ) -> Union[NoArgAnyCallable, dataclasses._MISSING_TYPE]: """ Gets the default factory for a pydantic field. @@ -128,7 +127,7 @@ def get_default_factory_for_field( def ensure_all_auto_fields_in_pydantic( - model: Type[BaseModel], auto_fields: Set[str], cls_name: str + model: Type[BaseModel], auto_fields: Set[str], cls_name: str ) -> Union[NoReturn, None]: # Raise error if user defined a strawberry.auto field not present in the model non_existing_fields = list(auto_fields - model.__fields__.keys()) diff --git a/strawberry/experimental/pydantic/v2_compat.py b/strawberry/experimental/pydantic/v2_compat.py new file mode 100644 index 0000000000..7d27fa593d --- /dev/null +++ b/strawberry/experimental/pydantic/v2_compat.py @@ -0,0 +1,15 @@ +import pydantic + +if pydantic.VERSION[0] == '2': + from pydantic._internal._utils import smart_deepcopy + from pydantic._internal._utils import lenient_issubclass + from typing_extensions import get_args, get_origin + from pydantic._internal._typing_extra import is_new_type + def new_type_supertype(type_): + return type_.__supertype__ +else: + from pydantic.utils import smart_deepcopy + from pydantic.utils import lenient_issubclass + from pydantic.typing import get_args, get_origin, is_new_type, new_type_supertype + +__all__ = ["smart_deepcopy", "lenient_issubclass"] From 4c13e84ec10185a1b5bc566a286beb05767db801 Mon Sep 17 00:00:00 2001 From: jameschua Date: Sat, 20 May 2023 17:56:40 +0100 Subject: [PATCH 02/26] separate --- strawberry/experimental/__init__.py | 10 +- strawberry/experimental/pydantic2/__init__.py | 11 + .../experimental/pydantic2/conversion.py | 113 +++ .../pydantic2/conversion_types.py | 37 + .../experimental/pydantic2/error_type.py | 153 ++++ .../experimental/pydantic2/exceptions.py | 50 + strawberry/experimental/pydantic2/fields.py | 137 +++ .../experimental/pydantic2/object_type.py | 346 +++++++ strawberry/experimental/pydantic2/utils.py | 140 +++ .../experimental/pydantic2/v2_compat.py | 15 + tests/experimental/pydantic2/test_basic.py | 861 ++++++++++++++++++ 11 files changed, 1871 insertions(+), 2 deletions(-) create mode 100644 strawberry/experimental/pydantic2/__init__.py create mode 100644 strawberry/experimental/pydantic2/conversion.py create mode 100644 strawberry/experimental/pydantic2/conversion_types.py create mode 100644 strawberry/experimental/pydantic2/error_type.py create mode 100644 strawberry/experimental/pydantic2/exceptions.py create mode 100644 strawberry/experimental/pydantic2/fields.py create mode 100644 strawberry/experimental/pydantic2/object_type.py create mode 100644 strawberry/experimental/pydantic2/utils.py create mode 100644 strawberry/experimental/pydantic2/v2_compat.py create mode 100644 tests/experimental/pydantic2/test_basic.py diff --git a/strawberry/experimental/__init__.py b/strawberry/experimental/__init__.py index 6386ad81d7..0674ca0c11 100644 --- a/strawberry/experimental/__init__.py +++ b/strawberry/experimental/__init__.py @@ -1,6 +1,12 @@ try: from . import pydantic + + __all__ = ["pydantic"] except ImportError: pass -else: - __all__ = ["pydantic"] +try: + from . import pydantic2 + + __all__ = ["pydantic2"] +except ImportError as e: + print(e) diff --git a/strawberry/experimental/pydantic2/__init__.py b/strawberry/experimental/pydantic2/__init__.py new file mode 100644 index 0000000000..10f382650d --- /dev/null +++ b/strawberry/experimental/pydantic2/__init__.py @@ -0,0 +1,11 @@ +from .error_type import error_type +from .exceptions import UnregisteredTypeException +from .object_type import input, interface, type + +__all__ = [ + "error_type", + "UnregisteredTypeException", + "input", + "type", + "interface", +] diff --git a/strawberry/experimental/pydantic2/conversion.py b/strawberry/experimental/pydantic2/conversion.py new file mode 100644 index 0000000000..e138e1e807 --- /dev/null +++ b/strawberry/experimental/pydantic2/conversion.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import copy +import dataclasses +from typing import TYPE_CHECKING, Any, Type, Union, cast + +from strawberry.enum import EnumDefinition +from strawberry.type import StrawberryList, StrawberryOptional +from strawberry.union import StrawberryUnion + +if TYPE_CHECKING: + from strawberry.type import StrawberryType + + +def _convert_from_pydantic_to_strawberry_type( + type_: Union[StrawberryType, type], data_from_model=None, extra=None # noqa: ANN001 +): + data = data_from_model if data_from_model is not None else extra + + if isinstance(type_, StrawberryOptional): + if data is None: + return data + return _convert_from_pydantic_to_strawberry_type( + type_.of_type, data_from_model=data, extra=extra + ) + if isinstance(type_, StrawberryUnion): + for option_type in type_.types: + if hasattr(option_type, "_pydantic_type"): + source_type = option_type._pydantic_type + else: + source_type = cast(type, option_type) + if isinstance(data, source_type): + return _convert_from_pydantic_to_strawberry_type( + option_type, data_from_model=data, extra=extra + ) + if isinstance(type_, EnumDefinition): + return data + if isinstance(type_, StrawberryList): + items = [] + for index, item in enumerate(data): + items.append( + _convert_from_pydantic_to_strawberry_type( + type_.of_type, + data_from_model=item, + extra=extra[index] if extra else None, + ) + ) + + return items + + if hasattr(type_, "_type_definition"): + # in the case of an interface, the concrete type may be more specific + # than the type in the field definition + # don't check _strawberry_input_type because inputs can't be interfaces + if hasattr(type(data), "_strawberry_type"): + type_ = type(data)._strawberry_type + if hasattr(type_, "from_pydantic"): + return type_.from_pydantic(data_from_model, extra) + return convert_pydantic_model_to_strawberry_class( + type_, model_instance=data_from_model, extra=extra + ) + + return data + + +def convert_pydantic_model_to_strawberry_class( + cls, *, model_instance=None, extra=None # noqa: ANN001 +) -> Any: + extra = extra or {} + kwargs = {} + + for field_ in cls._type_definition.fields: + field = cast("StrawberryField", field_) + python_name = field.python_name + + data_from_extra = extra.get(python_name, None) + data_from_model = ( + getattr(model_instance, python_name, None) if model_instance else None + ) + + # only convert and add fields to kwargs if they are present in the `__init__` + # method of the class + if field.init: + kwargs[python_name] = _convert_from_pydantic_to_strawberry_type( + field.type, data_from_model, extra=data_from_extra + ) + + return cls(**kwargs) + + +def convert_strawberry_class_to_pydantic_model(obj: Type) -> Any: + if hasattr(obj, "to_pydantic"): + return obj.to_pydantic() + elif dataclasses.is_dataclass(obj): + result = [] + for f in dataclasses.fields(obj): + value = convert_strawberry_class_to_pydantic_model(getattr(obj, f.name)) + result.append((f.name, value)) + return dict(result) + elif isinstance(obj, (list, tuple)): + # Assume we can create an object of this type by passing in a + # generator (which is not true for namedtuples, not supported). + return type(obj)(convert_strawberry_class_to_pydantic_model(v) for v in obj) + elif isinstance(obj, dict): + return type(obj)( + ( + convert_strawberry_class_to_pydantic_model(k), + convert_strawberry_class_to_pydantic_model(v), + ) + for k, v in obj.items() + ) + else: + return copy.deepcopy(obj) diff --git a/strawberry/experimental/pydantic2/conversion_types.py b/strawberry/experimental/pydantic2/conversion_types.py new file mode 100644 index 0000000000..aca9cdccd9 --- /dev/null +++ b/strawberry/experimental/pydantic2/conversion_types.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar +from typing_extensions import Protocol + +from pydantic import BaseModel + +if TYPE_CHECKING: + from strawberry.types.types import TypeDefinition + + +PydanticModel = TypeVar("PydanticModel", bound=BaseModel) + + +class StrawberryTypeFromPydantic(Protocol[PydanticModel]): + """This class does not exist in runtime. + It only makes the methods below visible for IDEs""" + + def __init__(self, **kwargs): + ... + + @staticmethod + def from_pydantic( + instance: PydanticModel, extra: Optional[Dict[str, Any]] = None + ) -> StrawberryTypeFromPydantic[PydanticModel]: + ... + + def to_pydantic(self, **kwargs) -> PydanticModel: + ... + + @property + def _type_definition(self) -> TypeDefinition: + ... + + @property + def _pydantic_type(self) -> Type[PydanticModel]: + ... diff --git a/strawberry/experimental/pydantic2/error_type.py b/strawberry/experimental/pydantic2/error_type.py new file mode 100644 index 0000000000..271eab8775 --- /dev/null +++ b/strawberry/experimental/pydantic2/error_type.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import dataclasses +import warnings +from typing import ( + TYPE_CHECKING, + Any, + Callable, + List, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) + +from pydantic import BaseModel + +from strawberry.auto import StrawberryAuto +from strawberry.experimental.pydantic2.utils import ( + get_private_fields, + get_strawberry_type_from_model, + normalize_type, +) +from strawberry.experimental.pydantic2.v2_compat import lenient_issubclass +from strawberry.object_type import _process_type, _wrap_dataclass +from strawberry.types.type_resolver import _get_fields +from strawberry.utils.typing import get_list_annotation, is_list + + + +from .exceptions import MissingFieldsListError + +if TYPE_CHECKING: + from pydantic.fields import ModelField + + +def get_type_for_field(field: ModelField) -> Union[Any, Type[None], Type[List]]: + type_ = field.outer_type_ + type_ = normalize_type(type_) + return field_type_to_type(type_) + + + + +def field_type_to_type(type_: Type) -> Union[Any, List[Any], None]: + error_class: Any = str + strawberry_type: Any = error_class + + if is_list(type_): + child_type = get_list_annotation(type_) + + if is_list(child_type): + strawberry_type = field_type_to_type(child_type) + elif lenient_issubclass(child_type, BaseModel): + strawberry_type = get_strawberry_type_from_model(child_type) + else: + strawberry_type = List[error_class] + + strawberry_type = Optional[strawberry_type] + elif lenient_issubclass(type_, BaseModel): + strawberry_type = get_strawberry_type_from_model(type_) + return Optional[strawberry_type] + + return Optional[List[strawberry_type]] + + +def error_type( + model: Type[BaseModel], + *, + fields: Optional[List[str]] = None, + name: Optional[str] = None, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, +) -> Callable[..., Type]: + def wrap(cls: Type) -> Type: + model_fields = model.__fields__ + fields_set = set(fields) if fields else set() + + if fields: + warnings.warn( + "`fields` is deprecated, use `auto` type annotations instead", + DeprecationWarning, + stacklevel=2, + ) + + existing_fields = getattr(cls, "__annotations__", {}) + fields_set = fields_set.union( + { + name + for name, type_ in existing_fields.items() + if isinstance(type_, StrawberryAuto) + } + ) + + if all_fields: + if fields_set: + warnings.warn( + "Using all_fields overrides any explicitly defined fields " + "in the model, using both is likely a bug", + stacklevel=2, + ) + fields_set = set(model_fields.keys()) + + if not fields_set: + raise MissingFieldsListError(cls) + + all_model_fields: List[Tuple[str, Any, dataclasses.Field]] = [ + ( + name, + get_type_for_field(field), + dataclasses.field(default=None), # type: ignore[arg-type] + ) + for name, field in model_fields.items() + if name in fields_set + ] + + wrapped = _wrap_dataclass(cls) + extra_fields = cast(List[dataclasses.Field], _get_fields(wrapped)) + private_fields = get_private_fields(wrapped) + + all_model_fields.extend( + ( + field.name, + field.type, + field, + ) + for field in extra_fields + private_fields + if not isinstance(field.type, StrawberryAuto) + ) + + cls = dataclasses.make_dataclass( + cls.__name__, + all_model_fields, + bases=cls.__bases__, + ) + + _process_type( + cls, + name=name, + is_input=False, + is_interface=False, + description=description, + directives=directives, + ) + + model._strawberry_type = cls # type: ignore[attr-defined] + cls._pydantic_type = model + return cls + + return wrap diff --git a/strawberry/experimental/pydantic2/exceptions.py b/strawberry/experimental/pydantic2/exceptions.py new file mode 100644 index 0000000000..cdb16baaad --- /dev/null +++ b/strawberry/experimental/pydantic2/exceptions.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, List, Type + +if TYPE_CHECKING: + from pydantic import BaseModel + from pydantic.typing import NoArgAnyCallable + + +class MissingFieldsListError(Exception): + def __init__(self, type: Type[BaseModel]): + message = ( + f"List of fields to copy from {type} is empty. Add fields with the " + f"`auto` type annotation" + ) + + super().__init__(message) + + +class UnsupportedTypeError(Exception): + pass + + +class UnregisteredTypeException(Exception): + def __init__(self, type: Type[BaseModel]): + message = ( + f"Cannot find a Strawberry Type for {type} did you forget to register it?" + ) + + super().__init__(message) + + +class BothDefaultAndDefaultFactoryDefinedError(Exception): + def __init__(self, default: Any, default_factory: NoArgAnyCallable): + message = ( + f"Not allowed to specify both default and default_factory. " + f"default:{default} default_factory:{default_factory}" + ) + + super().__init__(message) + + +class AutoFieldsNotInBaseModelError(Exception): + def __init__(self, fields: List[str], cls_name: str, model: Type[BaseModel]): + message = ( + f"{cls_name} defines {fields} with strawberry.auto. " + f"Field(s) not present in {model.__name__} BaseModel." + ) + + super().__init__(message) diff --git a/strawberry/experimental/pydantic2/fields.py b/strawberry/experimental/pydantic2/fields.py new file mode 100644 index 0000000000..fbe519f3c4 --- /dev/null +++ b/strawberry/experimental/pydantic2/fields.py @@ -0,0 +1,137 @@ +import builtins +from decimal import Decimal +from typing import Any, List, Optional, Type +from uuid import UUID + +import pydantic +from pydantic import BaseModel +from strawberry.experimental.pydantic2.v2_compat import lenient_issubclass, get_args, get_origin, is_new_type, \ + new_type_supertype +from strawberry.experimental.pydantic2.exceptions import ( + UnregisteredTypeException, + UnsupportedTypeError, +) +from strawberry.types.types import TypeDefinition + +try: + from typing import GenericAlias as TypingGenericAlias # type: ignore +except ImportError: + import sys + + # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on) + # we do this under a conditional to avoid a mypy :) + if sys.version_info < (3, 9): + TypingGenericAlias = () + else: + raise + +ATTR_TO_TYPE_MAP = { + "NoneBytes": Optional[bytes], + "StrBytes": None, + "NoneStrBytes": None, + "StrictStr": str, + "ConstrainedBytes": bytes, + "conbytes": bytes, + "ConstrainedStr": str, + "constr": str, + "EmailStr": str, + "PyObject": None, + "ConstrainedInt": int, + "conint": int, + "PositiveInt": int, + "NegativeInt": int, + "ConstrainedFloat": float, + "confloat": float, + "PositiveFloat": float, + "NegativeFloat": float, + "ConstrainedDecimal": Decimal, + "condecimal": Decimal, + "UUID1": UUID, + "UUID3": UUID, + "UUID4": UUID, + "UUID5": UUID, + "FilePath": None, + "DirectoryPath": None, + "Json": None, + "JsonWrapper": None, + "SecretStr": str, + "SecretBytes": bytes, + "StrictBool": bool, + "StrictInt": int, + "StrictFloat": float, + "PaymentCardNumber": None, + "ByteSize": None, + "AnyUrl": str, + "AnyHttpUrl": str, + "HttpUrl": str, + "PostgresDsn": str, + "RedisDsn": str, +} + +FIELDS_MAP = { + getattr(pydantic, field_name): type + for field_name, type in ATTR_TO_TYPE_MAP.items() + if hasattr(pydantic, field_name) +} + + +def get_basic_type(type_: Any) -> Type[Any]: + if lenient_issubclass(type_, pydantic.ConstrainedInt): + return int + if lenient_issubclass(type_, pydantic.ConstrainedFloat): + return float + if lenient_issubclass(type_, pydantic.ConstrainedStr): + return str + if lenient_issubclass(type_, pydantic.ConstrainedList): + return List[get_basic_type(type_.item_type)] # type: ignore + + if type_ in FIELDS_MAP: + type_ = FIELDS_MAP.get(type_) + + if type_ is None: + raise UnsupportedTypeError() + + if is_new_type(type_): + return new_type_supertype(type_) + + return type_ + + +def replace_pydantic_types(type_: Any, is_input: bool) -> Any: + if lenient_issubclass(type_, BaseModel): + attr = "_strawberry_input_type" if is_input else "_strawberry_type" + if hasattr(type_, attr): + return getattr(type_, attr) + else: + raise UnregisteredTypeException(type_) + return type_ + + +def replace_types_recursively(type_: Any, is_input: bool) -> Any: + """Runs the conversions recursively into the arguments of generic types if any""" + basic_type = get_basic_type(type_) + replaced_type = replace_pydantic_types(basic_type, is_input) + + origin = get_origin(type_) + if not origin or not hasattr(type_, "__args__"): + return replaced_type + + converted = tuple( + replace_types_recursively(t, is_input=is_input) for t in get_args(replaced_type) + ) + + if isinstance(replaced_type, TypingGenericAlias): + return TypingGenericAlias(origin, converted) + + replaced_type = replaced_type.copy_with(converted) + + if isinstance(replaced_type, TypeDefinition): + # TODO: Not sure if this is necessary. No coverage in tests + # TODO: Unnecessary with StrawberryObject + replaced_type = builtins.type( + replaced_type.name, + (), + {"_type_definition": replaced_type}, + ) + + return replaced_type diff --git a/strawberry/experimental/pydantic2/object_type.py b/strawberry/experimental/pydantic2/object_type.py new file mode 100644 index 0000000000..39c4668143 --- /dev/null +++ b/strawberry/experimental/pydantic2/object_type.py @@ -0,0 +1,346 @@ +from __future__ import annotations + +import dataclasses +import sys +import warnings +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + Type, + cast, +) + +from strawberry.annotation import StrawberryAnnotation +from strawberry.auto import StrawberryAuto +from strawberry.experimental.pydantic2.conversion import ( + convert_pydantic_model_to_strawberry_class, + convert_strawberry_class_to_pydantic_model, +) +from strawberry.experimental.pydantic2.exceptions import MissingFieldsListError +from strawberry.experimental.pydantic2.fields import replace_types_recursively +from strawberry.experimental.pydantic2.utils import ( + DataclassCreationFields, + ensure_all_auto_fields_in_pydantic, + get_default_factory_for_field, + get_private_fields, +) +from strawberry.field import StrawberryField +from strawberry.object_type import _process_type, _wrap_dataclass +from strawberry.types.type_resolver import _get_fields +from strawberry.utils.dataclasses import add_custom_init_fn + +if TYPE_CHECKING: + from graphql import GraphQLResolveInfo + from pydantic.fields import ModelField + + +def get_type_for_field(field: ModelField, is_input: bool): # noqa: ANN201 + outer_type = field.outer_type_ + replaced_type = replace_types_recursively(outer_type, is_input) + + default_defined: bool = ( + field.default_factory is not None or field.default is not None + ) + should_add_optional: bool = not (field.required or default_defined) + if should_add_optional: + return Optional[replaced_type] + else: + return replaced_type + + +def _build_dataclass_creation_fields( + field: ModelField, + is_input: bool, + existing_fields: Dict[str, StrawberryField], + auto_fields_set: Set[str], + use_pydantic_alias: bool, +) -> DataclassCreationFields: + field_type = ( + get_type_for_field(field, is_input) + if field.name in auto_fields_set + else existing_fields[field.name].type + ) + + if ( + field.name in existing_fields + and existing_fields[field.name].base_resolver is not None + ): + # if the user has defined a resolver for this field, always use it + strawberry_field = existing_fields[field.name] + else: + # otherwise we build an appropriate strawberry field that resolves it + existing_field = existing_fields.get(field.name) + graphql_name = None + if existing_field and existing_field.graphql_name: + graphql_name = existing_field.graphql_name + elif field.has_alias and use_pydantic_alias: + graphql_name = field.alias + + strawberry_field = StrawberryField( + python_name=field.name, + graphql_name=graphql_name, + # always unset because we use default_factory instead + default=dataclasses.MISSING, + default_factory=get_default_factory_for_field(field), + type_annotation=StrawberryAnnotation.from_annotation(field_type), + description=field.field_info.description, + deprecation_reason=( + existing_field.deprecation_reason if existing_field else None + ), + permission_classes=( + existing_field.permission_classes if existing_field else [] + ), + directives=existing_field.directives if existing_field else (), + metadata=existing_field.metadata if existing_field else {}, + ) + + return DataclassCreationFields( + name=field.name, + field_type=field_type, + field=strawberry_field, + ) + + +if TYPE_CHECKING: + from strawberry.experimental.pydantic2.conversion_types import ( + PydanticModel, + StrawberryTypeFromPydantic, + ) + + +def type( + model: Type[PydanticModel], + *, + fields: Optional[List[str]] = None, + name: Optional[str] = None, + is_input: bool = False, + is_interface: bool = False, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, + use_pydantic_alias: bool = True, +) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: + def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: + model_fields = model.__fields__ + original_fields_set = set(fields) if fields else set() + + if fields: + warnings.warn( + "`fields` is deprecated, use `auto` type annotations instead", + DeprecationWarning, + stacklevel=2, + ) + + existing_fields = getattr(cls, "__annotations__", {}) + + # these are the fields that matched a field name in the pydantic model + # and should copy their alias from the pydantic model + fields_set = original_fields_set.union( + {name for name, _ in existing_fields.items() if name in model_fields} + ) + # these are the fields that were marked with strawberry.auto and + # should copy their type from the pydantic model + auto_fields_set = original_fields_set.union( + { + name + for name, type_ in existing_fields.items() + if isinstance(type_, StrawberryAuto) + } + ) + + if all_fields: + if fields_set: + warnings.warn( + "Using all_fields overrides any explicitly defined fields " + "in the model, using both is likely a bug", + stacklevel=2, + ) + fields_set = set(model_fields.keys()) + auto_fields_set = set(model_fields.keys()) + + if not fields_set: + raise MissingFieldsListError(cls) + + ensure_all_auto_fields_in_pydantic( + model=model, auto_fields=auto_fields_set, cls_name=cls.__name__ + ) + + wrapped = _wrap_dataclass(cls) + extra_strawberry_fields = _get_fields(wrapped) + extra_fields = cast(List[dataclasses.Field], extra_strawberry_fields) + private_fields = get_private_fields(wrapped) + + extra_fields_dict = {field.name: field for field in extra_strawberry_fields} + + all_model_fields: List[DataclassCreationFields] = [ + _build_dataclass_creation_fields( + field, is_input, extra_fields_dict, auto_fields_set, use_pydantic_alias + ) + for field_name, field in model_fields.items() + if field_name in fields_set + ] + + all_model_fields = [ + DataclassCreationFields( + name=field.name, + field_type=field.type, + field=field, + ) + for field in extra_fields + private_fields + if field.name not in fields_set + ] + all_model_fields + + # Implicitly define `is_type_of` to support interfaces/unions that use + # pydantic objects (not the corresponding strawberry type) + @classmethod # type: ignore + def is_type_of(cls: Type, obj: Any, _info: GraphQLResolveInfo) -> bool: + return isinstance(obj, (cls, model)) + + namespace = {"is_type_of": is_type_of} + # We need to tell the difference between a from_pydantic method that is + # inherited from a base class and one that is defined by the user in the + # decorated class. We want to override the method only if it is + # inherited. To tell the difference, we compare the class name to the + # fully qualified name of the method, which will end in .from_pydantic + has_custom_from_pydantic = hasattr( + cls, "from_pydantic" + ) and cls.from_pydantic.__qualname__.endswith(f"{cls.__name__}.from_pydantic") + has_custom_to_pydantic = hasattr( + cls, "to_pydantic" + ) and cls.to_pydantic.__qualname__.endswith(f"{cls.__name__}.to_pydantic") + + if has_custom_from_pydantic: + namespace["from_pydantic"] = cls.from_pydantic + if has_custom_to_pydantic: + namespace["to_pydantic"] = cls.to_pydantic + + if hasattr(cls, "resolve_reference"): + namespace["resolve_reference"] = cls.resolve_reference + + kwargs: Dict[str, object] = {} + + # Python 3.10.1 introduces the kw_only param to `make_dataclass`. + # If we're on an older version then generate our own custom init function + # Note: Python 3.10.0 added the `kw_only` param to dataclasses, it was + # just missed from the `make_dataclass` function: + # https://github.com/python/cpython/issues/89961 + if sys.version_info >= (3, 10, 1): + kwargs["kw_only"] = dataclasses.MISSING + else: + kwargs["init"] = False + + cls = dataclasses.make_dataclass( + cls.__name__, + [field.to_tuple() for field in all_model_fields], + bases=cls.__bases__, + namespace=namespace, + **kwargs, # type: ignore + ) + + if sys.version_info < (3, 10, 1): + add_custom_init_fn(cls) + + _process_type( + cls, + name=name, + is_input=is_input, + is_interface=is_interface, + description=description, + directives=directives, + ) + + if is_input: + model._strawberry_input_type = cls # type: ignore + else: + model._strawberry_type = cls # type: ignore + cls._pydantic_type = model + + def from_pydantic_default( + instance: PydanticModel, extra: Optional[Dict[str, Any]] = None + ) -> StrawberryTypeFromPydantic[PydanticModel]: + ret = convert_pydantic_model_to_strawberry_class( + cls=cls, model_instance=instance, extra=extra + ) + ret._original_model = instance + return ret + + def to_pydantic_default(self: Any, **kwargs) -> PydanticModel: + instance_kwargs = { + f.name: convert_strawberry_class_to_pydantic_model( + getattr(self, f.name) + ) + for f in dataclasses.fields(self) + } + instance_kwargs.update(kwargs) + return model(**instance_kwargs) + + if not has_custom_from_pydantic: + cls.from_pydantic = staticmethod(from_pydantic_default) + if not has_custom_to_pydantic: + cls.to_pydantic = to_pydantic_default + + return cls + + return wrap + + +def input( + model: Type[PydanticModel], + *, + fields: Optional[List[str]] = None, + name: Optional[str] = None, + is_interface: bool = False, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, + use_pydantic_alias: bool = True, +) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: + """Convenience decorator for creating an input type from a Pydantic model. + Equal to partial(type, is_input=True) + See https://github.com/strawberry-graphql/strawberry/issues/1830 + """ + return type( + model=model, + fields=fields, + name=name, + is_input=True, + is_interface=is_interface, + description=description, + directives=directives, + all_fields=all_fields, + use_pydantic_alias=use_pydantic_alias, + ) + + +def interface( + model: Type[PydanticModel], + *, + fields: Optional[List[str]] = None, + name: Optional[str] = None, + is_input: bool = False, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, + use_pydantic_alias: bool = True, +) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: + """Convenience decorator for creating an interface type from a Pydantic model. + Equal to partial(type, is_interface=True) + See https://github.com/strawberry-graphql/strawberry/issues/1830 + """ + return type( + model=model, + fields=fields, + name=name, + is_input=is_input, + is_interface=True, + description=description, + directives=directives, + all_fields=all_fields, + use_pydantic_alias=use_pydantic_alias, + ) diff --git a/strawberry/experimental/pydantic2/utils.py b/strawberry/experimental/pydantic2/utils.py new file mode 100644 index 0000000000..63020fdbc0 --- /dev/null +++ b/strawberry/experimental/pydantic2/utils.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import dataclasses +from typing import ( + TYPE_CHECKING, + Any, + List, + NamedTuple, + NoReturn, + Set, + Tuple, + Type, + Union, + cast, +) + + +from strawberry.experimental.pydantic2.exceptions import ( + AutoFieldsNotInBaseModelError, + BothDefaultAndDefaultFactoryDefinedError, + UnregisteredTypeException, +) +from strawberry.private import is_private +from strawberry.unset import UNSET +from strawberry.utils.typing import ( + get_list_annotation, + get_optional_annotation, + is_list, + is_optional, +) + +if TYPE_CHECKING: + from pydantic import BaseModel + from pydantic.fields import ModelField + from pydantic.typing import NoArgAnyCallable + + +def normalize_type(type_: Type) -> Any: + if is_list(type_): + return List[normalize_type(get_list_annotation(type_))] # type: ignore + + if is_optional(type_): + return get_optional_annotation(type_) + + return type_ + + +def get_strawberry_type_from_model(type_: Any) -> Any: + if hasattr(type_, "_strawberry_type"): + return type_._strawberry_type + else: + raise UnregisteredTypeException(type_) + + +def get_private_fields(cls: Type) -> List[dataclasses.Field]: + private_fields: List[dataclasses.Field] = [] + + for field in dataclasses.fields(cls): + if is_private(field.type): + private_fields.append(field) + + return private_fields + + +class DataclassCreationFields(NamedTuple): + """Fields required for the fields parameter of make_dataclass""" + + name: str + field_type: Type + field: dataclasses.Field + + def to_tuple(self) -> Tuple[str, Type, dataclasses.Field]: + # fields parameter wants (name, type, Field) + return self.name, self.field_type, self.field + + +def get_default_factory_for_field( + field: ModelField, +) -> Union[NoArgAnyCallable, dataclasses._MISSING_TYPE]: + """ + Gets the default factory for a pydantic field. + + Handles mutable defaults when making the dataclass by + using pydantic's smart_deepcopy + + Returns optionally a NoArgAnyCallable representing a default_factory parameter + """ + # replace dataclasses.MISSING with our own UNSET to make comparisons easier + default_factory = ( + field.default_factory + if field.default_factory is not dataclasses.MISSING + else UNSET + ) + default = field.default if field.default is not dataclasses.MISSING else UNSET + + has_factory = default_factory is not None and default_factory is not UNSET + has_default = default is not None and default is not UNSET + + # defining both default and default_factory is not supported + + if has_factory and has_default: + default_factory = cast("NoArgAnyCallable", default_factory) + + raise BothDefaultAndDefaultFactoryDefinedError( + default=default, default_factory=default_factory + ) + + # if we have a default_factory, we should return it + + if has_factory: + default_factory = cast("NoArgAnyCallable", default_factory) + + return default_factory + + # if we have a default, we should return it + + if has_default: + return lambda: smart_deepcopy(default) + + # if we don't have default or default_factory, but the field is not required, + # we should return a factory that returns None + + if not field.required: + return lambda: None + + return dataclasses.MISSING + + +def ensure_all_auto_fields_in_pydantic( + model: Type[BaseModel], auto_fields: Set[str], cls_name: str +) -> Union[NoReturn, None]: + # Raise error if user defined a strawberry.auto field not present in the model + non_existing_fields = list(auto_fields - model.__fields__.keys()) + + if non_existing_fields: + raise AutoFieldsNotInBaseModelError( + fields=non_existing_fields, cls_name=cls_name, model=model + ) + else: + return None diff --git a/strawberry/experimental/pydantic2/v2_compat.py b/strawberry/experimental/pydantic2/v2_compat.py new file mode 100644 index 0000000000..7d27fa593d --- /dev/null +++ b/strawberry/experimental/pydantic2/v2_compat.py @@ -0,0 +1,15 @@ +import pydantic + +if pydantic.VERSION[0] == '2': + from pydantic._internal._utils import smart_deepcopy + from pydantic._internal._utils import lenient_issubclass + from typing_extensions import get_args, get_origin + from pydantic._internal._typing_extra import is_new_type + def new_type_supertype(type_): + return type_.__supertype__ +else: + from pydantic.utils import smart_deepcopy + from pydantic.utils import lenient_issubclass + from pydantic.typing import get_args, get_origin, is_new_type, new_type_supertype + +__all__ = ["smart_deepcopy", "lenient_issubclass"] diff --git a/tests/experimental/pydantic2/test_basic.py b/tests/experimental/pydantic2/test_basic.py new file mode 100644 index 0000000000..9dfd4c9ef9 --- /dev/null +++ b/tests/experimental/pydantic2/test_basic.py @@ -0,0 +1,861 @@ +import dataclasses +from enum import Enum +from typing import Any, List, Optional, Union + +import pydantic +import pytest + +import strawberry +from strawberry.enum import EnumDefinition +from strawberry.experimental.pydantic2.exceptions import MissingFieldsListError +from strawberry.schema_directive import Location +from strawberry.type import StrawberryList, StrawberryOptional +from strawberry.types.types import TypeDefinition +from strawberry.union import StrawberryUnion + + +def test_basic_type_field_list(): + class User(pydantic.BaseModel): + age: int + password: Optional[str] + + with pytest.deprecated_call(): + @strawberry.experimental.pydantic2.type(User, fields=["age", "password"]) + class UserType: + pass + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name is None + assert field1.type is int + + assert field2.python_name == "password" + assert field2.graphql_name is None + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + +def test_basic_type_all_fields(): + class User(pydantic.BaseModel): + age: int + password: Optional[str] + + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + pass + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name is None + assert field1.type is int + + assert field2.python_name == "password" + assert field2.graphql_name is None + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + +@pytest.mark.filterwarnings("error") +def test_basic_type_all_fields_warn(): + class User(pydantic.BaseModel): + age: int + password: Optional[str] + + with pytest.raises( + UserWarning, + match=("Using all_fields overrides any explicitly defined fields"), + ): + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + age: strawberry.auto + + +def test_basic_type_auto_fields(): + class User(pydantic.BaseModel): + age: int + password: Optional[str] + other: float + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto + password: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name is None + assert field1.type is int + + assert field2.python_name == "password" + assert field2.graphql_name is None + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + +def test_auto_fields_other_sentinel(): + class other_sentinel: + pass + + class User(pydantic.BaseModel): + age: int + password: Optional[str] + other: int + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto + password: strawberry.auto + other: other_sentinel # this should be a private field, not an auto field + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2, field3] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name is None + assert field1.type is int + + assert field2.python_name == "password" + assert field2.graphql_name is None + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + assert field3.python_name == "other" + assert field3.graphql_name is None + assert field3.type is other_sentinel + + +def test_referencing_other_models_fails_when_not_registered(): + class Group(pydantic.BaseModel): + name: str + + class User(pydantic.BaseModel): + age: int + password: Optional[str] + group: Group + + with pytest.raises( + strawberry.experimental.pydantic2.UnregisteredTypeException, + match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), + ): + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto + password: strawberry.auto + group: strawberry.auto + + +def test_referencing_other_input_models_fails_when_not_registered(): + class Group(pydantic.BaseModel): + name: str + + class User(pydantic.BaseModel): + age: int + password: Optional[str] + group: Group + + @strawberry.experimental.pydantic2.type(Group) + class GroupType: + name: strawberry.auto + + with pytest.raises( + strawberry.experimental.pydantic2.UnregisteredTypeException, + match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), + ): + @strawberry.experimental.pydantic2.input(User) + class UserInputType: + age: strawberry.auto + password: strawberry.auto + group: strawberry.auto + + +def test_referencing_other_registered_models(): + class Group(pydantic.BaseModel): + name: str + + class User(pydantic.BaseModel): + age: int + group: Group + + @strawberry.experimental.pydantic2.type(Group) + class GroupType: + name: strawberry.auto + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto + group: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.type is int + + assert field2.python_name == "group" + assert field2.type is GroupType + + +def test_list(): + class User(pydantic.BaseModel): + friend_names: List[str] + + @strawberry.experimental.pydantic2.type(User) + class UserType: + friend_names: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field] = definition.fields + + assert field.python_name == "friend_names" + assert isinstance(field.type, StrawberryList) + assert field.type.of_type is str + + +def test_list_of_types(): + class Friend(pydantic.BaseModel): + name: str + + class User(pydantic.BaseModel): + friends: Optional[List[Optional[Friend]]] + + @strawberry.experimental.pydantic2.type(Friend) + class FriendType: + name: strawberry.auto + + @strawberry.experimental.pydantic2.type(User) + class UserType: + friends: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field] = definition.fields + + assert field.python_name == "friends" + assert isinstance(field.type, StrawberryOptional) + assert isinstance(field.type.of_type, StrawberryList) + assert isinstance(field.type.of_type.of_type, StrawberryOptional) + assert field.type.of_type.of_type.of_type is FriendType + + +def test_basic_type_without_fields_throws_an_error(): + class User(pydantic.BaseModel): + age: int + password: Optional[str] + + with pytest.raises(MissingFieldsListError): + @strawberry.experimental.pydantic2.type(User) + class UserType: + pass + + +def test_type_with_fields_coming_from_strawberry_and_pydantic(): + class User(pydantic.BaseModel): + age: int + password: Optional[str] + + @strawberry.experimental.pydantic2.type(User) + class UserType: + name: str + age: strawberry.auto + password: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2, field3] = definition.fields + + assert field1.python_name == "name" + assert field1.type is str + + assert field2.python_name == "age" + assert field2.type is int + + assert field3.python_name == "password" + assert isinstance(field3.type, StrawberryOptional) + assert field3.type.of_type is str + + +def test_default_and_default_factory(): + class User1(pydantic.BaseModel): + friend: Optional[str] = "friend_value" + + @strawberry.experimental.pydantic2.type(User1) + class UserType1: + friend: strawberry.auto + + assert UserType1().friend == "friend_value" + assert UserType1().to_pydantic().friend == "friend_value" + + class User2(pydantic.BaseModel): + friend: Optional[str] = None + + @strawberry.experimental.pydantic2.type(User2) + class UserType2: + friend: strawberry.auto + + assert UserType2().friend is None + assert UserType2().to_pydantic().friend is None + + # Test instantiation using default_factory + + class User3(pydantic.BaseModel): + friend: Optional[str] = pydantic.Field(default_factory=lambda: "friend_value") + + @strawberry.experimental.pydantic2.type(User3) + class UserType3: + friend: strawberry.auto + + assert UserType3().friend == "friend_value" + assert UserType3().to_pydantic().friend == "friend_value" + + class User4(pydantic.BaseModel): + friend: Optional[str] = pydantic.Field(default_factory=lambda: None) + + @strawberry.experimental.pydantic2.type(User4) + class UserType4: + friend: strawberry.auto + + assert UserType4().friend is None + assert UserType4().to_pydantic().friend is None + + +def test_type_with_fields_mutable_default(): + empty_list = [] + + class User(pydantic.BaseModel): + groups: List[str] + friends: List[str] = empty_list + + @strawberry.experimental.pydantic2.type(User) + class UserType: + groups: strawberry.auto + friends: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [groups_field, friends_field] = definition.fields + + assert groups_field.default is dataclasses.MISSING + assert groups_field.default_factory is dataclasses.MISSING + assert friends_field.default is dataclasses.MISSING + + # check that we really made a copy + assert friends_field.default_factory() is not empty_list + assert UserType(groups=["groups"]).friends is not empty_list + UserType(groups=["groups"]).friends.append("joe") + assert empty_list == [] + + +@pytest.mark.xfail( + reason=( + "passing default values when extending types from pydantic is not" + "supported. https://github.com/strawberry-graphql/strawberry/issues/829" + ) +) +def test_type_with_fields_coming_from_strawberry_and_pydantic_with_default(): + class User(pydantic.BaseModel): + age: int + password: Optional[str] + + @strawberry.experimental.pydantic2.type(User) + class UserType: + name: str = "Michael" + age: strawberry.auto + password: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2, field3] = definition.fields + + assert field1.python_name == "age" + assert field1.type is int + + assert field2.python_name == "password" + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + assert field3.python_name == "name" + assert field3.type is str + assert field3.default == "Michael" + + +def test_type_with_nested_fields_coming_from_strawberry_and_pydantic(): + @strawberry.type + class Name: + first_name: str + last_name: str + + class User(pydantic.BaseModel): + age: int + password: Optional[str] + + @strawberry.experimental.pydantic2.type(User) + class UserType: + name: Name + age: strawberry.auto + password: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2, field3] = definition.fields + + assert field1.python_name == "name" + assert field1.type is Name + + assert field2.python_name == "age" + assert field2.type is int + + assert field3.python_name == "password" + assert isinstance(field3.type, StrawberryOptional) + assert field3.type.of_type is str + + +def test_type_with_aliased_pydantic_field(): + class UserModel(pydantic.BaseModel): + age_: int = pydantic.Field(..., alias="age") + password: Optional[str] + + @strawberry.experimental.pydantic2.type(UserModel) + class User: + age_: strawberry.auto + password: strawberry.auto + + definition: TypeDefinition = User._type_definition + assert definition.name == "User" + + [field1, field2] = definition.fields + + assert field1.python_name == "age_" + assert field1.type is int + assert field1.graphql_name == "age" + + assert field2.python_name == "password" + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + +def test_union(): + class BranchA(pydantic.BaseModel): + field_a: str + + class BranchB(pydantic.BaseModel): + field_b: int + + class User(pydantic.BaseModel): + age: int + union_field: Union[BranchA, BranchB] + + @strawberry.experimental.pydantic2.type(BranchA) + class BranchAType: + field_a: strawberry.auto + + @strawberry.experimental.pydantic2.type(BranchB) + class BranchBType: + field_b: strawberry.auto + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto + union_field: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.type is int + + assert field2.python_name == "union_field" + assert isinstance(field2.type, StrawberryUnion) + assert field2.type.types[0] is BranchAType + assert field2.type.types[1] is BranchBType + + +def test_enum(): + @strawberry.enum + class UserKind(Enum): + user = 0 + admin = 1 + + class User(pydantic.BaseModel): + age: int + kind: UserKind + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto + kind: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.type is int + + assert field2.python_name == "kind" + assert isinstance(field2.type, EnumDefinition) + assert field2.type.wrapped_cls is UserKind + + +def test_interface(): + class Base(pydantic.BaseModel): + base_field: str + + class BranchA(Base): + field_a: str + + class BranchB(Base): + field_b: int + + class User(pydantic.BaseModel): + age: int + interface_field: Base + + @strawberry.experimental.pydantic2.interface(Base) + class BaseType: + base_field: strawberry.auto + + @strawberry.experimental.pydantic2.type(BranchA) + class BranchAType(BaseType): + field_a: strawberry.auto + + @strawberry.experimental.pydantic2.type(BranchB) + class BranchBType(BaseType): + field_b: strawberry.auto + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto + interface_field: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.type is int + + assert field2.python_name == "interface_field" + assert field2.type is BaseType + + +def test_both_output_and_input_type(): + class Work(pydantic.BaseModel): + time: float + + class User(pydantic.BaseModel): + name: str + work: Optional[Work] + + class Group(pydantic.BaseModel): + users: List[User] + + # Test both definition orders + @strawberry.experimental.pydantic2.input(Work) + class WorkInput: + time: strawberry.auto + + @strawberry.experimental.pydantic2.type(Work) + class WorkOutput: + time: strawberry.auto + + @strawberry.experimental.pydantic2.type(User) + class UserOutput: + name: strawberry.auto + work: strawberry.auto + + @strawberry.experimental.pydantic2.input(User) + class UserInput: + name: strawberry.auto + work: strawberry.auto + + @strawberry.experimental.pydantic2.input(Group) + class GroupInput: + users: strawberry.auto + + @strawberry.experimental.pydantic2.type(Group) + class GroupOutput: + users: strawberry.auto + + @strawberry.type + class Query: + groups: List[GroupOutput] + + @strawberry.type + class Mutation: + @strawberry.mutation + def updateGroup(group: GroupInput) -> GroupOutput: + pass + + # This triggers the exception from #1504 + schema = strawberry.Schema(query=Query, mutation=Mutation) + expected_schema = """ +input GroupInput { + users: [UserInput!]! +} + +type GroupOutput { + users: [UserOutput!]! +} + +type Mutation { + updateGroup(group: GroupInput!): GroupOutput! +} + +type Query { + groups: [GroupOutput!]! +} + +input UserInput { + name: String! + work: WorkInput = null +} + +type UserOutput { + name: String! + work: WorkOutput +} + +input WorkInput { + time: Float! +} + +type WorkOutput { + time: Float! +}""" + assert schema.as_str().strip() == expected_schema.strip() + + assert Group._strawberry_type == GroupOutput + assert Group._strawberry_input_type == GroupInput + assert User._strawberry_type == UserOutput + assert User._strawberry_input_type == UserInput + assert Work._strawberry_type == WorkOutput + assert Work._strawberry_input_type == WorkInput + + +def test_single_field_changed_type(): + class User(pydantic.BaseModel): + age: int + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: str + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name is None + assert field1.type is str + + +def test_type_with_aliased_pydantic_field_changed_type(): + class UserModel(pydantic.BaseModel): + age_: int = pydantic.Field(..., alias="age") + password: Optional[str] + + @strawberry.experimental.pydantic2.type(UserModel) + class User: + age_: str + password: strawberry.auto + + definition: TypeDefinition = User._type_definition + assert definition.name == "User" + + [field1, field2] = definition.fields + + assert field1.python_name == "age_" + assert field1.type is str + assert field1.graphql_name == "age" + + assert field2.python_name == "password" + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + +def test_deprecated_fields(): + class User(pydantic.BaseModel): + age: int + password: Optional[str] + other: float + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto = strawberry.field(deprecation_reason="Because") + password: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name is None + assert field1.type is int + assert field1.deprecation_reason == "Because" + + assert field2.python_name == "password" + assert field2.graphql_name is None + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + +def test_permission_classes(): + class IsAuthenticated(strawberry.BasePermission): + message = "User is not authenticated" + + def has_permission( + self, source: Any, info: strawberry.types.Info, **kwargs + ) -> bool: + return False + + class User(pydantic.BaseModel): + age: int + password: Optional[str] + other: float + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto = strawberry.field(permission_classes=[IsAuthenticated]) + password: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name is None + assert field1.type is int + assert field1.permission_classes == [IsAuthenticated] + + assert field2.python_name == "password" + assert field2.graphql_name is None + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + +def test_field_directives(): + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class Sensitive: + reason: str + + class User(pydantic.BaseModel): + age: int + password: Optional[str] + other: float + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto = strawberry.field(directives=[Sensitive(reason="GDPR")]) + password: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name is None + assert field1.type is int + assert field1.directives == [Sensitive(reason="GDPR")] + + assert field2.python_name == "password" + assert field2.graphql_name is None + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + +def test_alias_fields(): + class User(pydantic.BaseModel): + age: int + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto = strawberry.field(name="ageAlias") + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + field1 = definition.fields[0] + + assert field1.python_name == "age" + assert field1.graphql_name == "ageAlias" + assert field1.type is int + + +def test_alias_fields_with_use_pydantic_alias(): + class User(pydantic.BaseModel): + age: int + state: str = pydantic.Field(alias="statePydantic") + country: str = pydantic.Field(alias="countryPydantic") + + @strawberry.experimental.pydantic2.type(User, use_pydantic_alias=True) + class UserType: + age: strawberry.auto = strawberry.field(name="ageAlias") + state: strawberry.auto = strawberry.field(name="state") + country: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2, field3] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name == "ageAlias" + + assert field2.python_name == "state" + assert field2.graphql_name == "state" + + assert field3.python_name == "country" + assert field3.graphql_name == "countryPydantic" + + +def test_field_metadata(): + class User(pydantic.BaseModel): + private: bool + public: bool + + @strawberry.experimental.pydantic2.type(User) + class UserType: + private: strawberry.auto = strawberry.field(metadata={"admin_only": True}) + public: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "private" + assert field1.metadata["admin_only"] + + assert field2.python_name == "public" + assert not field2.metadata From 95f979a3d88b1768df0c2593aa600a41011181d6 Mon Sep 17 00:00:00 2001 From: jameschua Date: Sat, 20 May 2023 18:11:22 +0100 Subject: [PATCH 03/26] remove field map --- .../experimental/pydantic/error_type.py | 17 ++--- strawberry/experimental/pydantic/fields.py | 9 ++- strawberry/experimental/pydantic/utils.py | 4 +- strawberry/experimental/pydantic/v2_compat.py | 4 +- .../experimental/pydantic2/error_type.py | 17 ++--- strawberry/experimental/pydantic2/fields.py | 53 +++----------- .../experimental/pydantic2/object_type.py | 69 +++++++++---------- strawberry/experimental/pydantic2/utils.py | 6 +- .../experimental/pydantic2/v2_compat.py | 4 +- tests/experimental/pydantic2/test_basic.py | 47 ++++--------- 10 files changed, 86 insertions(+), 144 deletions(-) diff --git a/strawberry/experimental/pydantic/error_type.py b/strawberry/experimental/pydantic/error_type.py index 5757517acb..f6a2d380ec 100644 --- a/strawberry/experimental/pydantic/error_type.py +++ b/strawberry/experimental/pydantic/error_type.py @@ -29,7 +29,6 @@ from strawberry.utils.typing import get_list_annotation, is_list - from .exceptions import MissingFieldsListError if TYPE_CHECKING: @@ -42,8 +41,6 @@ def get_type_for_field(field: ModelField) -> Union[Any, Type[None], Type[List]]: return field_type_to_type(type_) - - def field_type_to_type(type_: Type) -> Union[Any, List[Any], None]: error_class: Any = str strawberry_type: Any = error_class @@ -67,13 +64,13 @@ def field_type_to_type(type_: Type) -> Union[Any, List[Any], None]: def error_type( - model: Type[BaseModel], - *, - fields: Optional[List[str]] = None, - name: Optional[str] = None, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), - all_fields: bool = False, + model: Type[BaseModel], + *, + fields: Optional[List[str]] = None, + name: Optional[str] = None, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, ) -> Callable[..., Type]: def wrap(cls: Type) -> Type: model_fields = model.__fields__ diff --git a/strawberry/experimental/pydantic/fields.py b/strawberry/experimental/pydantic/fields.py index 24d2b60b36..ebcfe6dd34 100644 --- a/strawberry/experimental/pydantic/fields.py +++ b/strawberry/experimental/pydantic/fields.py @@ -5,8 +5,13 @@ import pydantic from pydantic import BaseModel -from strawberry.experimental.pydantic.v2_compat import lenient_issubclass, get_args, get_origin, is_new_type, \ - new_type_supertype +from strawberry.experimental.pydantic.v2_compat import ( + lenient_issubclass, + get_args, + get_origin, + is_new_type, + new_type_supertype, +) from strawberry.experimental.pydantic.exceptions import ( UnregisteredTypeException, UnsupportedTypeError, diff --git a/strawberry/experimental/pydantic/utils.py b/strawberry/experimental/pydantic/utils.py index 0368fd84f9..00a966f2ed 100644 --- a/strawberry/experimental/pydantic/utils.py +++ b/strawberry/experimental/pydantic/utils.py @@ -75,7 +75,7 @@ def to_tuple(self) -> Tuple[str, Type, dataclasses.Field]: def get_default_factory_for_field( - field: ModelField, + field: ModelField, ) -> Union[NoArgAnyCallable, dataclasses._MISSING_TYPE]: """ Gets the default factory for a pydantic field. @@ -127,7 +127,7 @@ def get_default_factory_for_field( def ensure_all_auto_fields_in_pydantic( - model: Type[BaseModel], auto_fields: Set[str], cls_name: str + model: Type[BaseModel], auto_fields: Set[str], cls_name: str ) -> Union[NoReturn, None]: # Raise error if user defined a strawberry.auto field not present in the model non_existing_fields = list(auto_fields - model.__fields__.keys()) diff --git a/strawberry/experimental/pydantic/v2_compat.py b/strawberry/experimental/pydantic/v2_compat.py index 7d27fa593d..f965136d1c 100644 --- a/strawberry/experimental/pydantic/v2_compat.py +++ b/strawberry/experimental/pydantic/v2_compat.py @@ -1,12 +1,14 @@ import pydantic -if pydantic.VERSION[0] == '2': +if pydantic.VERSION[0] == "2": from pydantic._internal._utils import smart_deepcopy from pydantic._internal._utils import lenient_issubclass from typing_extensions import get_args, get_origin from pydantic._internal._typing_extra import is_new_type + def new_type_supertype(type_): return type_.__supertype__ + else: from pydantic.utils import smart_deepcopy from pydantic.utils import lenient_issubclass diff --git a/strawberry/experimental/pydantic2/error_type.py b/strawberry/experimental/pydantic2/error_type.py index 271eab8775..6933df5e9c 100644 --- a/strawberry/experimental/pydantic2/error_type.py +++ b/strawberry/experimental/pydantic2/error_type.py @@ -29,7 +29,6 @@ from strawberry.utils.typing import get_list_annotation, is_list - from .exceptions import MissingFieldsListError if TYPE_CHECKING: @@ -42,8 +41,6 @@ def get_type_for_field(field: ModelField) -> Union[Any, Type[None], Type[List]]: return field_type_to_type(type_) - - def field_type_to_type(type_: Type) -> Union[Any, List[Any], None]: error_class: Any = str strawberry_type: Any = error_class @@ -67,13 +64,13 @@ def field_type_to_type(type_: Type) -> Union[Any, List[Any], None]: def error_type( - model: Type[BaseModel], - *, - fields: Optional[List[str]] = None, - name: Optional[str] = None, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), - all_fields: bool = False, + model: Type[BaseModel], + *, + fields: Optional[List[str]] = None, + name: Optional[str] = None, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, ) -> Callable[..., Type]: def wrap(cls: Type) -> Type: model_fields = model.__fields__ diff --git a/strawberry/experimental/pydantic2/fields.py b/strawberry/experimental/pydantic2/fields.py index fbe519f3c4..51180cac87 100644 --- a/strawberry/experimental/pydantic2/fields.py +++ b/strawberry/experimental/pydantic2/fields.py @@ -5,8 +5,13 @@ import pydantic from pydantic import BaseModel -from strawberry.experimental.pydantic2.v2_compat import lenient_issubclass, get_args, get_origin, is_new_type, \ - new_type_supertype +from strawberry.experimental.pydantic2.v2_compat import ( + lenient_issubclass, + get_args, + get_origin, + is_new_type, + new_type_supertype, +) from strawberry.experimental.pydantic2.exceptions import ( UnregisteredTypeException, UnsupportedTypeError, @@ -25,48 +30,8 @@ else: raise -ATTR_TO_TYPE_MAP = { - "NoneBytes": Optional[bytes], - "StrBytes": None, - "NoneStrBytes": None, - "StrictStr": str, - "ConstrainedBytes": bytes, - "conbytes": bytes, - "ConstrainedStr": str, - "constr": str, - "EmailStr": str, - "PyObject": None, - "ConstrainedInt": int, - "conint": int, - "PositiveInt": int, - "NegativeInt": int, - "ConstrainedFloat": float, - "confloat": float, - "PositiveFloat": float, - "NegativeFloat": float, - "ConstrainedDecimal": Decimal, - "condecimal": Decimal, - "UUID1": UUID, - "UUID3": UUID, - "UUID4": UUID, - "UUID5": UUID, - "FilePath": None, - "DirectoryPath": None, - "Json": None, - "JsonWrapper": None, - "SecretStr": str, - "SecretBytes": bytes, - "StrictBool": bool, - "StrictInt": int, - "StrictFloat": float, - "PaymentCardNumber": None, - "ByteSize": None, - "AnyUrl": str, - "AnyHttpUrl": str, - "HttpUrl": str, - "PostgresDsn": str, - "RedisDsn": str, -} +# NOTE: To investigate the annotated types +ATTR_TO_TYPE_MAP = {} FIELDS_MAP = { getattr(pydantic, field_name): type diff --git a/strawberry/experimental/pydantic2/object_type.py b/strawberry/experimental/pydantic2/object_type.py index 39c4668143..9a3e67c348 100644 --- a/strawberry/experimental/pydantic2/object_type.py +++ b/strawberry/experimental/pydantic2/object_type.py @@ -37,10 +37,10 @@ if TYPE_CHECKING: from graphql import GraphQLResolveInfo - from pydantic.fields import ModelField + from pydantic.fields import FieldInfo, FieldInfo -def get_type_for_field(field: ModelField, is_input: bool): # noqa: ANN201 +def get_type_for_field(field: FieldInfo, is_input: bool): # noqa: ANN201 outer_type = field.outer_type_ replaced_type = replace_types_recursively(outer_type, is_input) @@ -55,7 +55,8 @@ def get_type_for_field(field: ModelField, is_input: bool): # noqa: ANN201 def _build_dataclass_creation_fields( - field: ModelField, + field_name: str, + field: FieldInfo, is_input: bool, existing_fields: Dict[str, StrawberryField], auto_fields_set: Set[str], @@ -63,33 +64,33 @@ def _build_dataclass_creation_fields( ) -> DataclassCreationFields: field_type = ( get_type_for_field(field, is_input) - if field.name in auto_fields_set - else existing_fields[field.name].type + if field_name in auto_fields_set + else existing_fields[field_name].type ) if ( - field.name in existing_fields - and existing_fields[field.name].base_resolver is not None + field_name in existing_fields + and existing_fields[field_name].base_resolver is not None ): # if the user has defined a resolver for this field, always use it - strawberry_field = existing_fields[field.name] + strawberry_field = existing_fields[field_name] else: # otherwise we build an appropriate strawberry field that resolves it - existing_field = existing_fields.get(field.name) + existing_field = existing_fields.get(field_name) graphql_name = None if existing_field and existing_field.graphql_name: graphql_name = existing_field.graphql_name - elif field.has_alias and use_pydantic_alias: + elif field.alias and use_pydantic_alias: graphql_name = field.alias strawberry_field = StrawberryField( - python_name=field.name, + python_name=field_name, graphql_name=graphql_name, # always unset because we use default_factory instead default=dataclasses.MISSING, default_factory=get_default_factory_for_field(field), type_annotation=StrawberryAnnotation.from_annotation(field_type), - description=field.field_info.description, + description=field.description, deprecation_reason=( existing_field.deprecation_reason if existing_field else None ), @@ -101,7 +102,7 @@ def _build_dataclass_creation_fields( ) return DataclassCreationFields( - name=field.name, + name=field_name, field_type=field_type, field=strawberry_field, ) @@ -117,7 +118,6 @@ def _build_dataclass_creation_fields( def type( model: Type[PydanticModel], *, - fields: Optional[List[str]] = None, name: Optional[str] = None, is_input: bool = False, is_interface: bool = False, @@ -127,32 +127,22 @@ def type( use_pydantic_alias: bool = True, ) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: - model_fields = model.__fields__ - original_fields_set = set(fields) if fields else set() - - if fields: - warnings.warn( - "`fields` is deprecated, use `auto` type annotations instead", - DeprecationWarning, - stacklevel=2, - ) + model_fields: Dict[str, FieldInfo] = model.model_fields existing_fields = getattr(cls, "__annotations__", {}) # these are the fields that matched a field name in the pydantic model # and should copy their alias from the pydantic model - fields_set = original_fields_set.union( - {name for name, _ in existing_fields.items() if name in model_fields} - ) + fields_set = { + name for name, _ in existing_fields.items() if name in model_fields + } # these are the fields that were marked with strawberry.auto and # should copy their type from the pydantic model - auto_fields_set = original_fields_set.union( - { - name - for name, type_ in existing_fields.items() - if isinstance(type_, StrawberryAuto) - } - ) + auto_fields_set = { + name + for name, type_ in existing_fields.items() + if isinstance(type_, StrawberryAuto) + } if all_fields: if fields_set: @@ -176,11 +166,16 @@ def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: extra_fields = cast(List[dataclasses.Field], extra_strawberry_fields) private_fields = get_private_fields(wrapped) - extra_fields_dict = {field.name: field for field in extra_strawberry_fields} + extra_fields_dict = {field_name: field for field in extra_strawberry_fields} all_model_fields: List[DataclassCreationFields] = [ _build_dataclass_creation_fields( - field, is_input, extra_fields_dict, auto_fields_set, use_pydantic_alias + field_name=field_name, + field=field, + is_input=is_input, + existing_fields=extra_fields_dict, + auto_fields_set=auto_fields_set, + use_pydantic_alias=use_pydantic_alias, ) for field_name, field in model_fields.items() if field_name in fields_set @@ -188,12 +183,12 @@ def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: all_model_fields = [ DataclassCreationFields( - name=field.name, + name=field_name, field_type=field.type, field=field, ) for field in extra_fields + private_fields - if field.name not in fields_set + if field_name not in fields_set ] + all_model_fields # Implicitly define `is_type_of` to support interfaces/unions that use diff --git a/strawberry/experimental/pydantic2/utils.py b/strawberry/experimental/pydantic2/utils.py index 63020fdbc0..0e6942b3fb 100644 --- a/strawberry/experimental/pydantic2/utils.py +++ b/strawberry/experimental/pydantic2/utils.py @@ -75,7 +75,7 @@ def to_tuple(self) -> Tuple[str, Type, dataclasses.Field]: def get_default_factory_for_field( - field: ModelField, + field: ModelField, ) -> Union[NoArgAnyCallable, dataclasses._MISSING_TYPE]: """ Gets the default factory for a pydantic field. @@ -127,10 +127,10 @@ def get_default_factory_for_field( def ensure_all_auto_fields_in_pydantic( - model: Type[BaseModel], auto_fields: Set[str], cls_name: str + model: Type[BaseModel], auto_fields: Set[str], cls_name: str ) -> Union[NoReturn, None]: # Raise error if user defined a strawberry.auto field not present in the model - non_existing_fields = list(auto_fields - model.__fields__.keys()) + non_existing_fields = list(auto_fields - model.model_fields.keys()) if non_existing_fields: raise AutoFieldsNotInBaseModelError( diff --git a/strawberry/experimental/pydantic2/v2_compat.py b/strawberry/experimental/pydantic2/v2_compat.py index 7d27fa593d..f965136d1c 100644 --- a/strawberry/experimental/pydantic2/v2_compat.py +++ b/strawberry/experimental/pydantic2/v2_compat.py @@ -1,12 +1,14 @@ import pydantic -if pydantic.VERSION[0] == '2': +if pydantic.VERSION[0] == "2": from pydantic._internal._utils import smart_deepcopy from pydantic._internal._utils import lenient_issubclass from typing_extensions import get_args, get_origin from pydantic._internal._typing_extra import is_new_type + def new_type_supertype(type_): return type_.__supertype__ + else: from pydantic.utils import smart_deepcopy from pydantic.utils import lenient_issubclass diff --git a/tests/experimental/pydantic2/test_basic.py b/tests/experimental/pydantic2/test_basic.py index 9dfd4c9ef9..c829e924cf 100644 --- a/tests/experimental/pydantic2/test_basic.py +++ b/tests/experimental/pydantic2/test_basic.py @@ -14,31 +14,6 @@ from strawberry.union import StrawberryUnion -def test_basic_type_field_list(): - class User(pydantic.BaseModel): - age: int - password: Optional[str] - - with pytest.deprecated_call(): - @strawberry.experimental.pydantic2.type(User, fields=["age", "password"]) - class UserType: - pass - - definition: TypeDefinition = UserType._type_definition - assert definition.name == "UserType" - - [field1, field2] = definition.fields - - assert field1.python_name == "age" - assert field1.graphql_name is None - assert field1.type is int - - assert field2.python_name == "password" - assert field2.graphql_name is None - assert isinstance(field2.type, StrawberryOptional) - assert field2.type.of_type is str - - def test_basic_type_all_fields(): class User(pydantic.BaseModel): age: int @@ -70,9 +45,10 @@ class User(pydantic.BaseModel): password: Optional[str] with pytest.raises( - UserWarning, - match=("Using all_fields overrides any explicitly defined fields"), + UserWarning, + match=("Using all_fields overrides any explicitly defined fields"), ): + @strawberry.experimental.pydantic2.type(User, all_fields=True) class UserType: age: strawberry.auto @@ -148,9 +124,10 @@ class User(pydantic.BaseModel): group: Group with pytest.raises( - strawberry.experimental.pydantic2.UnregisteredTypeException, - match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), + strawberry.experimental.pydantic2.UnregisteredTypeException, + match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), ): + @strawberry.experimental.pydantic2.type(User) class UserType: age: strawberry.auto @@ -172,9 +149,10 @@ class GroupType: name: strawberry.auto with pytest.raises( - strawberry.experimental.pydantic2.UnregisteredTypeException, - match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), + strawberry.experimental.pydantic2.UnregisteredTypeException, + match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), ): + @strawberry.experimental.pydantic2.input(User) class UserInputType: age: strawberry.auto @@ -262,6 +240,7 @@ class User(pydantic.BaseModel): password: Optional[str] with pytest.raises(MissingFieldsListError): + @strawberry.experimental.pydantic2.type(User) class UserType: pass @@ -368,8 +347,8 @@ class UserType: @pytest.mark.xfail( reason=( - "passing default values when extending types from pydantic is not" - "supported. https://github.com/strawberry-graphql/strawberry/issues/829" + "passing default values when extending types from pydantic is not" + "supported. https://github.com/strawberry-graphql/strawberry/issues/829" ) ) def test_type_with_fields_coming_from_strawberry_and_pydantic_with_default(): @@ -733,7 +712,7 @@ class IsAuthenticated(strawberry.BasePermission): message = "User is not authenticated" def has_permission( - self, source: Any, info: strawberry.types.Info, **kwargs + self, source: Any, info: strawberry.types.Info, **kwargs ) -> bool: return False From ddf7aded3500dd49f4770597d314dce13273afdd Mon Sep 17 00:00:00 2001 From: jameschua Date: Sat, 20 May 2023 18:21:52 +0100 Subject: [PATCH 04/26] yay test pass --- strawberry/experimental/pydantic2/fields.py | 16 +-- .../experimental/pydantic2/object_type.py | 100 ++++++++++-------- strawberry/experimental/pydantic2/utils.py | 12 ++- tests/experimental/pydantic2/test_basic.py | 22 ++-- 4 files changed, 79 insertions(+), 71 deletions(-) diff --git a/strawberry/experimental/pydantic2/fields.py b/strawberry/experimental/pydantic2/fields.py index 51180cac87..d35f5f6dc3 100644 --- a/strawberry/experimental/pydantic2/fields.py +++ b/strawberry/experimental/pydantic2/fields.py @@ -41,14 +41,14 @@ def get_basic_type(type_: Any) -> Type[Any]: - if lenient_issubclass(type_, pydantic.ConstrainedInt): - return int - if lenient_issubclass(type_, pydantic.ConstrainedFloat): - return float - if lenient_issubclass(type_, pydantic.ConstrainedStr): - return str - if lenient_issubclass(type_, pydantic.ConstrainedList): - return List[get_basic_type(type_.item_type)] # type: ignore + # if lenient_issubclass(type_, pydantic.ConstrainedInt): + # return int + # if lenient_issubclass(type_, pydantic.ConstrainedFloat): + # return float + # if lenient_issubclass(type_, pydantic.ConstrainedStr): + # return str + # if lenient_issubclass(type_, pydantic.ConstrainedList): + # return List[get_basic_type(type_.item_type)] # type: ignore if type_ in FIELDS_MAP: type_ = FIELDS_MAP.get(type_) diff --git a/strawberry/experimental/pydantic2/object_type.py b/strawberry/experimental/pydantic2/object_type.py index 9a3e67c348..28b7ccfe59 100644 --- a/strawberry/experimental/pydantic2/object_type.py +++ b/strawberry/experimental/pydantic2/object_type.py @@ -16,6 +16,8 @@ cast, ) +from pydantic._internal._fields import Undefined + from strawberry.annotation import StrawberryAnnotation from strawberry.auto import StrawberryAuto from strawberry.experimental.pydantic2.conversion import ( @@ -40,14 +42,18 @@ from pydantic.fields import FieldInfo, FieldInfo +def is_required(field: FieldInfo) -> bool: + return field.default is Undefined and field.default_factory is None + + def get_type_for_field(field: FieldInfo, is_input: bool): # noqa: ANN201 - outer_type = field.outer_type_ + outer_type = field.annotation replaced_type = replace_types_recursively(outer_type, is_input) default_defined: bool = ( - field.default_factory is not None or field.default is not None + field.default_factory is not None or field.default is not None ) - should_add_optional: bool = not (field.required or default_defined) + should_add_optional: bool = not (is_required(field) or default_defined) if should_add_optional: return Optional[replaced_type] else: @@ -55,12 +61,12 @@ def get_type_for_field(field: FieldInfo, is_input: bool): # noqa: ANN201 def _build_dataclass_creation_fields( - field_name: str, - field: FieldInfo, - is_input: bool, - existing_fields: Dict[str, StrawberryField], - auto_fields_set: Set[str], - use_pydantic_alias: bool, + field_name: str, + field: FieldInfo, + is_input: bool, + existing_fields: Dict[str, StrawberryField], + auto_fields_set: Set[str], + use_pydantic_alias: bool, ) -> DataclassCreationFields: field_type = ( get_type_for_field(field, is_input) @@ -69,8 +75,8 @@ def _build_dataclass_creation_fields( ) if ( - field_name in existing_fields - and existing_fields[field_name].base_resolver is not None + field_name in existing_fields + and existing_fields[field_name].base_resolver is not None ): # if the user has defined a resolver for this field, always use it strawberry_field = existing_fields[field_name] @@ -116,15 +122,15 @@ def _build_dataclass_creation_fields( def type( - model: Type[PydanticModel], - *, - name: Optional[str] = None, - is_input: bool = False, - is_interface: bool = False, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), - all_fields: bool = False, - use_pydantic_alias: bool = True, + model: Type[PydanticModel], + *, + name: Optional[str] = None, + is_input: bool = False, + is_interface: bool = False, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, + use_pydantic_alias: bool = True, ) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: model_fields: Dict[str, FieldInfo] = model.model_fields @@ -182,14 +188,14 @@ def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: ] all_model_fields = [ - DataclassCreationFields( - name=field_name, - field_type=field.type, - field=field, - ) - for field in extra_fields + private_fields - if field_name not in fields_set - ] + all_model_fields + DataclassCreationFields( + name=field_name, + field_type=field.type, + field=field, + ) + for field in extra_fields + private_fields + if field_name not in fields_set + ] + all_model_fields # Implicitly define `is_type_of` to support interfaces/unions that use # pydantic objects (not the corresponding strawberry type) @@ -257,7 +263,7 @@ def is_type_of(cls: Type, obj: Any, _info: GraphQLResolveInfo) -> bool: cls._pydantic_type = model def from_pydantic_default( - instance: PydanticModel, extra: Optional[Dict[str, Any]] = None + instance: PydanticModel, extra: Optional[Dict[str, Any]] = None ) -> StrawberryTypeFromPydantic[PydanticModel]: ret = convert_pydantic_model_to_strawberry_class( cls=cls, model_instance=instance, extra=extra @@ -286,15 +292,15 @@ def to_pydantic_default(self: Any, **kwargs) -> PydanticModel: def input( - model: Type[PydanticModel], - *, - fields: Optional[List[str]] = None, - name: Optional[str] = None, - is_interface: bool = False, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), - all_fields: bool = False, - use_pydantic_alias: bool = True, + model: Type[PydanticModel], + *, + fields: Optional[List[str]] = None, + name: Optional[str] = None, + is_interface: bool = False, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, + use_pydantic_alias: bool = True, ) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: """Convenience decorator for creating an input type from a Pydantic model. Equal to partial(type, is_input=True) @@ -314,15 +320,15 @@ def input( def interface( - model: Type[PydanticModel], - *, - fields: Optional[List[str]] = None, - name: Optional[str] = None, - is_input: bool = False, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), - all_fields: bool = False, - use_pydantic_alias: bool = True, + model: Type[PydanticModel], + *, + fields: Optional[List[str]] = None, + name: Optional[str] = None, + is_input: bool = False, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, + use_pydantic_alias: bool = True, ) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: """Convenience decorator for creating an interface type from a Pydantic model. Equal to partial(type, is_interface=True) diff --git a/strawberry/experimental/pydantic2/utils.py b/strawberry/experimental/pydantic2/utils.py index 0e6942b3fb..cb125c9644 100644 --- a/strawberry/experimental/pydantic2/utils.py +++ b/strawberry/experimental/pydantic2/utils.py @@ -14,6 +14,8 @@ cast, ) +from pydantic._internal._fields import Undefined +from pydantic._internal._utils import smart_deepcopy from strawberry.experimental.pydantic2.exceptions import ( AutoFieldsNotInBaseModelError, @@ -74,8 +76,12 @@ def to_tuple(self) -> Tuple[str, Type, dataclasses.Field]: return self.name, self.field_type, self.field +def is_required(field: ModelField) -> bool: + return field.default is Undefined and field.default_factory is None + + def get_default_factory_for_field( - field: ModelField, + field: ModelField, ) -> Union[NoArgAnyCallable, dataclasses._MISSING_TYPE]: """ Gets the default factory for a pydantic field. @@ -120,14 +126,14 @@ def get_default_factory_for_field( # if we don't have default or default_factory, but the field is not required, # we should return a factory that returns None - if not field.required: + if not is_required(field): return lambda: None return dataclasses.MISSING def ensure_all_auto_fields_in_pydantic( - model: Type[BaseModel], auto_fields: Set[str], cls_name: str + model: Type[BaseModel], auto_fields: Set[str], cls_name: str ) -> Union[NoReturn, None]: # Raise error if user defined a strawberry.auto field not present in the model non_existing_fields = list(auto_fields - model.model_fields.keys()) diff --git a/tests/experimental/pydantic2/test_basic.py b/tests/experimental/pydantic2/test_basic.py index c829e924cf..dd177148e9 100644 --- a/tests/experimental/pydantic2/test_basic.py +++ b/tests/experimental/pydantic2/test_basic.py @@ -45,10 +45,9 @@ class User(pydantic.BaseModel): password: Optional[str] with pytest.raises( - UserWarning, - match=("Using all_fields overrides any explicitly defined fields"), + UserWarning, + match=("Using all_fields overrides any explicitly defined fields"), ): - @strawberry.experimental.pydantic2.type(User, all_fields=True) class UserType: age: strawberry.auto @@ -124,10 +123,9 @@ class User(pydantic.BaseModel): group: Group with pytest.raises( - strawberry.experimental.pydantic2.UnregisteredTypeException, - match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), + strawberry.experimental.pydantic2.UnregisteredTypeException, + match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), ): - @strawberry.experimental.pydantic2.type(User) class UserType: age: strawberry.auto @@ -149,10 +147,9 @@ class GroupType: name: strawberry.auto with pytest.raises( - strawberry.experimental.pydantic2.UnregisteredTypeException, - match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), + strawberry.experimental.pydantic2.UnregisteredTypeException, + match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), ): - @strawberry.experimental.pydantic2.input(User) class UserInputType: age: strawberry.auto @@ -240,7 +237,6 @@ class User(pydantic.BaseModel): password: Optional[str] with pytest.raises(MissingFieldsListError): - @strawberry.experimental.pydantic2.type(User) class UserType: pass @@ -347,8 +343,8 @@ class UserType: @pytest.mark.xfail( reason=( - "passing default values when extending types from pydantic is not" - "supported. https://github.com/strawberry-graphql/strawberry/issues/829" + "passing default values when extending types from pydantic is not" + "supported. https://github.com/strawberry-graphql/strawberry/issues/829" ) ) def test_type_with_fields_coming_from_strawberry_and_pydantic_with_default(): @@ -712,7 +708,7 @@ class IsAuthenticated(strawberry.BasePermission): message = "User is not authenticated" def has_permission( - self, source: Any, info: strawberry.types.Info, **kwargs + self, source: Any, info: strawberry.types.Info, **kwargs ) -> bool: return False From 9734494cf17af07e90e25dede70de3384e5ee779 Mon Sep 17 00:00:00 2001 From: jameschua Date: Sat, 20 May 2023 18:25:21 +0100 Subject: [PATCH 05/26] fix fields --- .../experimental/pydantic2/object_type.py | 92 +++++++++---------- strawberry/experimental/pydantic2/utils.py | 4 +- tests/experimental/pydantic2/test_basic.py | 22 +++-- 3 files changed, 61 insertions(+), 57 deletions(-) diff --git a/strawberry/experimental/pydantic2/object_type.py b/strawberry/experimental/pydantic2/object_type.py index 28b7ccfe59..8d832b83ef 100644 --- a/strawberry/experimental/pydantic2/object_type.py +++ b/strawberry/experimental/pydantic2/object_type.py @@ -51,7 +51,7 @@ def get_type_for_field(field: FieldInfo, is_input: bool): # noqa: ANN201 replaced_type = replace_types_recursively(outer_type, is_input) default_defined: bool = ( - field.default_factory is not None or field.default is not None + field.default_factory is not None or field.default is not None ) should_add_optional: bool = not (is_required(field) or default_defined) if should_add_optional: @@ -61,12 +61,12 @@ def get_type_for_field(field: FieldInfo, is_input: bool): # noqa: ANN201 def _build_dataclass_creation_fields( - field_name: str, - field: FieldInfo, - is_input: bool, - existing_fields: Dict[str, StrawberryField], - auto_fields_set: Set[str], - use_pydantic_alias: bool, + field_name: str, + field: FieldInfo, + is_input: bool, + existing_fields: Dict[str, StrawberryField], + auto_fields_set: Set[str], + use_pydantic_alias: bool, ) -> DataclassCreationFields: field_type = ( get_type_for_field(field, is_input) @@ -75,8 +75,8 @@ def _build_dataclass_creation_fields( ) if ( - field_name in existing_fields - and existing_fields[field_name].base_resolver is not None + field_name in existing_fields + and existing_fields[field_name].base_resolver is not None ): # if the user has defined a resolver for this field, always use it strawberry_field = existing_fields[field_name] @@ -122,15 +122,15 @@ def _build_dataclass_creation_fields( def type( - model: Type[PydanticModel], - *, - name: Optional[str] = None, - is_input: bool = False, - is_interface: bool = False, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), - all_fields: bool = False, - use_pydantic_alias: bool = True, + model: Type[PydanticModel], + *, + name: Optional[str] = None, + is_input: bool = False, + is_interface: bool = False, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, + use_pydantic_alias: bool = True, ) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: model_fields: Dict[str, FieldInfo] = model.model_fields @@ -172,7 +172,7 @@ def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: extra_fields = cast(List[dataclasses.Field], extra_strawberry_fields) private_fields = get_private_fields(wrapped) - extra_fields_dict = {field_name: field for field in extra_strawberry_fields} + extra_fields_dict = {field.name: field for field in extra_strawberry_fields} all_model_fields: List[DataclassCreationFields] = [ _build_dataclass_creation_fields( @@ -188,14 +188,14 @@ def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: ] all_model_fields = [ - DataclassCreationFields( - name=field_name, - field_type=field.type, - field=field, - ) - for field in extra_fields + private_fields - if field_name not in fields_set - ] + all_model_fields + DataclassCreationFields( + name=field.name, + field_type=field.type, + field=field, + ) + for field in extra_fields + private_fields + if field.name not in fields_set + ] + all_model_fields # Implicitly define `is_type_of` to support interfaces/unions that use # pydantic objects (not the corresponding strawberry type) @@ -263,7 +263,7 @@ def is_type_of(cls: Type, obj: Any, _info: GraphQLResolveInfo) -> bool: cls._pydantic_type = model def from_pydantic_default( - instance: PydanticModel, extra: Optional[Dict[str, Any]] = None + instance: PydanticModel, extra: Optional[Dict[str, Any]] = None ) -> StrawberryTypeFromPydantic[PydanticModel]: ret = convert_pydantic_model_to_strawberry_class( cls=cls, model_instance=instance, extra=extra @@ -292,15 +292,15 @@ def to_pydantic_default(self: Any, **kwargs) -> PydanticModel: def input( - model: Type[PydanticModel], - *, - fields: Optional[List[str]] = None, - name: Optional[str] = None, - is_interface: bool = False, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), - all_fields: bool = False, - use_pydantic_alias: bool = True, + model: Type[PydanticModel], + *, + fields: Optional[List[str]] = None, + name: Optional[str] = None, + is_interface: bool = False, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, + use_pydantic_alias: bool = True, ) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: """Convenience decorator for creating an input type from a Pydantic model. Equal to partial(type, is_input=True) @@ -320,15 +320,15 @@ def input( def interface( - model: Type[PydanticModel], - *, - fields: Optional[List[str]] = None, - name: Optional[str] = None, - is_input: bool = False, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), - all_fields: bool = False, - use_pydantic_alias: bool = True, + model: Type[PydanticModel], + *, + fields: Optional[List[str]] = None, + name: Optional[str] = None, + is_input: bool = False, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, + use_pydantic_alias: bool = True, ) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: """Convenience decorator for creating an interface type from a Pydantic model. Equal to partial(type, is_interface=True) diff --git a/strawberry/experimental/pydantic2/utils.py b/strawberry/experimental/pydantic2/utils.py index cb125c9644..894aedf5ad 100644 --- a/strawberry/experimental/pydantic2/utils.py +++ b/strawberry/experimental/pydantic2/utils.py @@ -81,7 +81,7 @@ def is_required(field: ModelField) -> bool: def get_default_factory_for_field( - field: ModelField, + field: ModelField, ) -> Union[NoArgAnyCallable, dataclasses._MISSING_TYPE]: """ Gets the default factory for a pydantic field. @@ -133,7 +133,7 @@ def get_default_factory_for_field( def ensure_all_auto_fields_in_pydantic( - model: Type[BaseModel], auto_fields: Set[str], cls_name: str + model: Type[BaseModel], auto_fields: Set[str], cls_name: str ) -> Union[NoReturn, None]: # Raise error if user defined a strawberry.auto field not present in the model non_existing_fields = list(auto_fields - model.model_fields.keys()) diff --git a/tests/experimental/pydantic2/test_basic.py b/tests/experimental/pydantic2/test_basic.py index dd177148e9..c829e924cf 100644 --- a/tests/experimental/pydantic2/test_basic.py +++ b/tests/experimental/pydantic2/test_basic.py @@ -45,9 +45,10 @@ class User(pydantic.BaseModel): password: Optional[str] with pytest.raises( - UserWarning, - match=("Using all_fields overrides any explicitly defined fields"), + UserWarning, + match=("Using all_fields overrides any explicitly defined fields"), ): + @strawberry.experimental.pydantic2.type(User, all_fields=True) class UserType: age: strawberry.auto @@ -123,9 +124,10 @@ class User(pydantic.BaseModel): group: Group with pytest.raises( - strawberry.experimental.pydantic2.UnregisteredTypeException, - match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), + strawberry.experimental.pydantic2.UnregisteredTypeException, + match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), ): + @strawberry.experimental.pydantic2.type(User) class UserType: age: strawberry.auto @@ -147,9 +149,10 @@ class GroupType: name: strawberry.auto with pytest.raises( - strawberry.experimental.pydantic2.UnregisteredTypeException, - match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), + strawberry.experimental.pydantic2.UnregisteredTypeException, + match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), ): + @strawberry.experimental.pydantic2.input(User) class UserInputType: age: strawberry.auto @@ -237,6 +240,7 @@ class User(pydantic.BaseModel): password: Optional[str] with pytest.raises(MissingFieldsListError): + @strawberry.experimental.pydantic2.type(User) class UserType: pass @@ -343,8 +347,8 @@ class UserType: @pytest.mark.xfail( reason=( - "passing default values when extending types from pydantic is not" - "supported. https://github.com/strawberry-graphql/strawberry/issues/829" + "passing default values when extending types from pydantic is not" + "supported. https://github.com/strawberry-graphql/strawberry/issues/829" ) ) def test_type_with_fields_coming_from_strawberry_and_pydantic_with_default(): @@ -708,7 +712,7 @@ class IsAuthenticated(strawberry.BasePermission): message = "User is not authenticated" def has_permission( - self, source: Any, info: strawberry.types.Info, **kwargs + self, source: Any, info: strawberry.types.Info, **kwargs ) -> bool: return False From 2b2c1a8cfe310b677ad5d1cc879f750da2e77e5b Mon Sep 17 00:00:00 2001 From: jameschua Date: Sat, 20 May 2023 18:27:39 +0100 Subject: [PATCH 06/26] fix passing fields --- strawberry/experimental/pydantic2/object_type.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/strawberry/experimental/pydantic2/object_type.py b/strawberry/experimental/pydantic2/object_type.py index 8d832b83ef..dd4c928478 100644 --- a/strawberry/experimental/pydantic2/object_type.py +++ b/strawberry/experimental/pydantic2/object_type.py @@ -294,7 +294,6 @@ def to_pydantic_default(self: Any, **kwargs) -> PydanticModel: def input( model: Type[PydanticModel], *, - fields: Optional[List[str]] = None, name: Optional[str] = None, is_interface: bool = False, description: Optional[str] = None, @@ -308,7 +307,6 @@ def input( """ return type( model=model, - fields=fields, name=name, is_input=True, is_interface=is_interface, @@ -322,7 +320,6 @@ def input( def interface( model: Type[PydanticModel], *, - fields: Optional[List[str]] = None, name: Optional[str] = None, is_input: bool = False, description: Optional[str] = None, @@ -336,7 +333,6 @@ def interface( """ return type( model=model, - fields=fields, name=name, is_input=is_input, is_interface=True, From df87fb708030b41fea2db1db98695df15a161ccb Mon Sep 17 00:00:00 2001 From: jameschua Date: Sat, 20 May 2023 18:44:25 +0100 Subject: [PATCH 07/26] it worksgit stage . --- strawberry/experimental/pydantic2/utils.py | 12 +- tests/experimental/pydantic2/__init__.py | 0 .../experimental/pydantic2/schema/__init__.py | 0 .../pydantic2/schema/test_basic.py | 429 ++++++++++++++++++ tests/experimental/pydantic2/test_basic.py | 2 +- 5 files changed, 435 insertions(+), 8 deletions(-) create mode 100644 tests/experimental/pydantic2/__init__.py create mode 100644 tests/experimental/pydantic2/schema/__init__.py create mode 100644 tests/experimental/pydantic2/schema/test_basic.py diff --git a/strawberry/experimental/pydantic2/utils.py b/strawberry/experimental/pydantic2/utils.py index 894aedf5ad..c2c17327c6 100644 --- a/strawberry/experimental/pydantic2/utils.py +++ b/strawberry/experimental/pydantic2/utils.py @@ -14,7 +14,7 @@ cast, ) -from pydantic._internal._fields import Undefined +from pydantic._internal._fields import Undefined, _UndefinedType from pydantic._internal._utils import smart_deepcopy from strawberry.experimental.pydantic2.exceptions import ( @@ -93,14 +93,12 @@ def get_default_factory_for_field( """ # replace dataclasses.MISSING with our own UNSET to make comparisons easier default_factory = ( - field.default_factory - if field.default_factory is not dataclasses.MISSING - else UNSET + field.default_factory if field.default_factory is not None else UNSET ) - default = field.default if field.default is not dataclasses.MISSING else UNSET + default = field.default if not isinstance(field.default, _UndefinedType) else UNSET - has_factory = default_factory is not None and default_factory is not UNSET - has_default = default is not None and default is not UNSET + has_factory = default_factory is not UNSET + has_default = default is not UNSET # defining both default and default_factory is not supported diff --git a/tests/experimental/pydantic2/__init__.py b/tests/experimental/pydantic2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/experimental/pydantic2/schema/__init__.py b/tests/experimental/pydantic2/schema/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/experimental/pydantic2/schema/test_basic.py b/tests/experimental/pydantic2/schema/test_basic.py new file mode 100644 index 0000000000..e72acbe6a4 --- /dev/null +++ b/tests/experimental/pydantic2/schema/test_basic.py @@ -0,0 +1,429 @@ +import textwrap +from enum import Enum +from typing import List, Optional, Union + +import pydantic + +import strawberry + + +def test_all_fields(): + class UserModel(pydantic.BaseModel): + age: int + password: Optional[str] + + @strawberry.experimental.pydantic2.type(UserModel, all_fields=True) + class User: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1, password="ABC") + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + type Query { + user: User! + } + + type User { + age: Int! + password: String + } + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + query = "{ user { age } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["age"] == 1 + + +def test_auto_fields(): + class UserModel(pydantic.BaseModel): + age: int + password: Optional[str] + other: float + + @strawberry.experimental.pydantic2.type(UserModel) + class User: + age: strawberry.auto + password: strawberry.auto + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1, password="ABC") + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + type Query { + user: User! + } + + type User { + age: Int! + password: String + } + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + query = "{ user { age } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["age"] == 1 + + +def test_basic_alias_type(): + class UserModel(pydantic.BaseModel): + age_: int = pydantic.Field(..., alias="age") + password: Optional[str] + + @strawberry.experimental.pydantic2.type(UserModel, all_fields=True) + class User: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1, password="ABC") + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + type Query { + user: User! + } + + type User { + age: Int! + password: String + } + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + +def test_basic_type_with_list(): + class UserModel(pydantic.BaseModel): + age: int + friend_names: List[str] + + @strawberry.experimental.pydantic2.type(UserModel, all_fields=True) + class User: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1, friend_names=["A", "B"]) + + schema = strawberry.Schema(query=Query) + + query = "{ user { friendNames } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["friendNames"] == ["A", "B"] + + +def test_basic_type_with_nested_model(): + class Hobby(pydantic.BaseModel): + name: str + + @strawberry.experimental.pydantic2.type(Hobby, all_fields=True) + class HobbyType: + pass + + class User(pydantic.BaseModel): + hobby: Hobby + + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> UserType: + return UserType(hobby=HobbyType(name="Skii")) + + schema = strawberry.Schema(query=Query) + + query = "{ user { hobby { name } } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["hobby"]["name"] == "Skii" + + +def test_basic_type_with_list_of_nested_model(): + class Hobby(pydantic.BaseModel): + name: str + + @strawberry.experimental.pydantic2.type(Hobby, all_fields=True) + class HobbyType: + pass + + class User(pydantic.BaseModel): + hobbies: List[Hobby] + + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> UserType: + return UserType( + hobbies=[ + HobbyType(name="Skii"), + HobbyType(name="Cooking"), + ] + ) + + schema = strawberry.Schema(query=Query) + + query = "{ user { hobbies { name } } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["hobbies"] == [ + {"name": "Skii"}, + {"name": "Cooking"}, + ] + + +def test_basic_type_with_extended_fields(): + class UserModel(pydantic.BaseModel): + age: int + + @strawberry.experimental.pydantic2.type(UserModel, all_fields=True) + class User: + name: str + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(name="Marco", age=100) + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + type Query { + user: User! + } + + type User { + name: String! + age: Int! + } + """ + assert str(schema) == textwrap.dedent(expected_schema).strip() + + query = "{ user { name age } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["name"] == "Marco" + assert result.data["user"]["age"] == 100 + + +def test_type_with_custom_resolver(): + class UserModel(pydantic.BaseModel): + age: int + + def get_age_in_months(root): + return root.age * 12 + + @strawberry.experimental.pydantic2.type(UserModel, all_fields=True) + class User: + age_in_months: int = strawberry.field(resolver=get_age_in_months) + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=20) + + schema = strawberry.Schema(query=Query) + + query = "{ user { age ageInMonths } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["age"] == 20 + assert result.data["user"]["ageInMonths"] == 240 + + +def test_basic_type_with_union(): + class BranchA(pydantic.BaseModel): + field_a: str + + class BranchB(pydantic.BaseModel): + field_b: int + + class User(pydantic.BaseModel): + union_field: Union[BranchA, BranchB] + + @strawberry.experimental.pydantic2.type(BranchA, all_fields=True) + class BranchAType: + pass + + @strawberry.experimental.pydantic2.type(BranchB, all_fields=True) + class BranchBType: + pass + + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> UserType: + return UserType(union_field=BranchBType(field_b=10)) + + schema = strawberry.Schema(query=Query) + + query = "{ user { unionField { ... on BranchBType { fieldB } } } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["unionField"]["fieldB"] == 10 + + +def test_basic_type_with_union_pydantic_types(): + class BranchA(pydantic.BaseModel): + field_a: str + + class BranchB(pydantic.BaseModel): + field_b: int + + class User(pydantic.BaseModel): + union_field: Union[BranchA, BranchB] + + @strawberry.experimental.pydantic2.type(BranchA, all_fields=True) + class BranchAType: + pass + + @strawberry.experimental.pydantic2.type(BranchB, all_fields=True) + class BranchBType: + pass + + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> UserType: + # note that BranchB is a pydantic type, not a strawberry type + return UserType(union_field=BranchB(field_b=10)) + + schema = strawberry.Schema(query=Query) + + query = "{ user { unionField { ... on BranchBType { fieldB } } } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["unionField"]["fieldB"] == 10 + + +def test_basic_type_with_enum(): + @strawberry.enum + class UserKind(Enum): + user = 0 + admin = 1 + + class User(pydantic.BaseModel): + age: int + kind: UserKind + + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> UserType: + return UserType(age=10, kind=UserKind.admin) + + schema = strawberry.Schema(query=Query) + + query = "{ user { kind } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["kind"] == "admin" + + +def test_basic_type_with_interface(): + class Base(pydantic.BaseModel): + base_field: str + + class BranchA(Base): + field_a: str + + class BranchB(Base): + field_b: int + + class User(pydantic.BaseModel): + interface_field: Base + + @strawberry.experimental.pydantic2.interface(Base, all_fields=True) + class BaseType: + pass + + @strawberry.experimental.pydantic2.type(BranchA, all_fields=True) + class BranchAType(BaseType): + pass + + @strawberry.experimental.pydantic2.type(BranchB, all_fields=True) + class BranchBType(BaseType): + pass + + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> UserType: + return UserType(interface_field=BranchBType(base_field="abc", field_b=10)) + + schema = strawberry.Schema(query=Query, types=[BranchAType, BranchBType]) + + query = "{ user { interfaceField { baseField, ... on BranchBType { fieldB } } } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["interfaceField"]["baseField"] == "abc" + assert result.data["user"]["interfaceField"]["fieldB"] == 10 diff --git a/tests/experimental/pydantic2/test_basic.py b/tests/experimental/pydantic2/test_basic.py index c829e924cf..ac42495296 100644 --- a/tests/experimental/pydantic2/test_basic.py +++ b/tests/experimental/pydantic2/test_basic.py @@ -613,7 +613,7 @@ def updateGroup(group: GroupInput) -> GroupOutput: input UserInput { name: String! - work: WorkInput = null + work: WorkInput } type UserOutput { From 017e1689a5674ce81684bc34676f0bd0ebb01964 Mon Sep 17 00:00:00 2001 From: jameschua Date: Sat, 20 May 2023 18:46:45 +0100 Subject: [PATCH 08/26] revert --- strawberry/experimental/pydantic/conversion.py | 1 + strawberry/experimental/pydantic/error_type.py | 3 +-- strawberry/experimental/pydantic/fields.py | 11 ++++------- strawberry/experimental/pydantic/utils.py | 1 + 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/strawberry/experimental/pydantic/conversion.py b/strawberry/experimental/pydantic/conversion.py index e138e1e807..bc0f787948 100644 --- a/strawberry/experimental/pydantic/conversion.py +++ b/strawberry/experimental/pydantic/conversion.py @@ -9,6 +9,7 @@ from strawberry.union import StrawberryUnion if TYPE_CHECKING: + from strawberry.field import StrawberryField from strawberry.type import StrawberryType diff --git a/strawberry/experimental/pydantic/error_type.py b/strawberry/experimental/pydantic/error_type.py index f6a2d380ec..adcdd5cdf4 100644 --- a/strawberry/experimental/pydantic/error_type.py +++ b/strawberry/experimental/pydantic/error_type.py @@ -16,6 +16,7 @@ ) from pydantic import BaseModel +from pydantic.utils import lenient_issubclass from strawberry.auto import StrawberryAuto from strawberry.experimental.pydantic.utils import ( @@ -23,12 +24,10 @@ get_strawberry_type_from_model, normalize_type, ) -from strawberry.experimental.pydantic.v2_compat import lenient_issubclass from strawberry.object_type import _process_type, _wrap_dataclass from strawberry.types.type_resolver import _get_fields from strawberry.utils.typing import get_list_annotation, is_list - from .exceptions import MissingFieldsListError if TYPE_CHECKING: diff --git a/strawberry/experimental/pydantic/fields.py b/strawberry/experimental/pydantic/fields.py index ebcfe6dd34..cfa3a6be2c 100644 --- a/strawberry/experimental/pydantic/fields.py +++ b/strawberry/experimental/pydantic/fields.py @@ -5,13 +5,9 @@ import pydantic from pydantic import BaseModel -from strawberry.experimental.pydantic.v2_compat import ( - lenient_issubclass, - get_args, - get_origin, - is_new_type, - new_type_supertype, -) +from pydantic.typing import get_args, get_origin, is_new_type, new_type_supertype +from pydantic.utils import lenient_issubclass + from strawberry.experimental.pydantic.exceptions import ( UnregisteredTypeException, UnsupportedTypeError, @@ -74,6 +70,7 @@ "RedisDsn": str, } + FIELDS_MAP = { getattr(pydantic, field_name): type for field_name, type in ATTR_TO_TYPE_MAP.items() diff --git a/strawberry/experimental/pydantic/utils.py b/strawberry/experimental/pydantic/utils.py index 00a966f2ed..e863a47e16 100644 --- a/strawberry/experimental/pydantic/utils.py +++ b/strawberry/experimental/pydantic/utils.py @@ -14,6 +14,7 @@ cast, ) +from pydantic.utils import smart_deepcopy from strawberry.experimental.pydantic.exceptions import ( AutoFieldsNotInBaseModelError, From 2103eabf49b02eb4e990dbd952871eb3449dbc9d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Jun 2023 15:08:51 +0000 Subject: [PATCH 09/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry/experimental/pydantic/v2_compat.py | 10 +++------- .../experimental/pydantic2/error_type.py | 1 - strawberry/experimental/pydantic2/fields.py | 19 +++++++------------ .../experimental/pydantic2/object_type.py | 2 +- .../experimental/pydantic2/v2_compat.py | 10 +++------- 5 files changed, 14 insertions(+), 28 deletions(-) diff --git a/strawberry/experimental/pydantic/v2_compat.py b/strawberry/experimental/pydantic/v2_compat.py index f965136d1c..f1b70b210c 100644 --- a/strawberry/experimental/pydantic/v2_compat.py +++ b/strawberry/experimental/pydantic/v2_compat.py @@ -1,17 +1,13 @@ import pydantic if pydantic.VERSION[0] == "2": - from pydantic._internal._utils import smart_deepcopy - from pydantic._internal._utils import lenient_issubclass - from typing_extensions import get_args, get_origin - from pydantic._internal._typing_extra import is_new_type + + from pydantic._internal._utils import lenient_issubclass, smart_deepcopy def new_type_supertype(type_): return type_.__supertype__ else: - from pydantic.utils import smart_deepcopy - from pydantic.utils import lenient_issubclass - from pydantic.typing import get_args, get_origin, is_new_type, new_type_supertype + from pydantic.utils import lenient_issubclass, smart_deepcopy __all__ = ["smart_deepcopy", "lenient_issubclass"] diff --git a/strawberry/experimental/pydantic2/error_type.py b/strawberry/experimental/pydantic2/error_type.py index 6933df5e9c..495d47df74 100644 --- a/strawberry/experimental/pydantic2/error_type.py +++ b/strawberry/experimental/pydantic2/error_type.py @@ -28,7 +28,6 @@ from strawberry.types.type_resolver import _get_fields from strawberry.utils.typing import get_list_annotation, is_list - from .exceptions import MissingFieldsListError if TYPE_CHECKING: diff --git a/strawberry/experimental/pydantic2/fields.py b/strawberry/experimental/pydantic2/fields.py index d35f5f6dc3..fb86983d90 100644 --- a/strawberry/experimental/pydantic2/fields.py +++ b/strawberry/experimental/pydantic2/fields.py @@ -1,21 +1,20 @@ import builtins -from decimal import Decimal -from typing import Any, List, Optional, Type -from uuid import UUID +from typing import Any, Type import pydantic from pydantic import BaseModel + +from strawberry.experimental.pydantic2.exceptions import ( + UnregisteredTypeException, + UnsupportedTypeError, +) from strawberry.experimental.pydantic2.v2_compat import ( - lenient_issubclass, get_args, get_origin, is_new_type, + lenient_issubclass, new_type_supertype, ) -from strawberry.experimental.pydantic2.exceptions import ( - UnregisteredTypeException, - UnsupportedTypeError, -) from strawberry.types.types import TypeDefinition try: @@ -42,13 +41,9 @@ def get_basic_type(type_: Any) -> Type[Any]: # if lenient_issubclass(type_, pydantic.ConstrainedInt): - # return int # if lenient_issubclass(type_, pydantic.ConstrainedFloat): - # return float # if lenient_issubclass(type_, pydantic.ConstrainedStr): - # return str # if lenient_issubclass(type_, pydantic.ConstrainedList): - # return List[get_basic_type(type_.item_type)] # type: ignore if type_ in FIELDS_MAP: type_ = FIELDS_MAP.get(type_) diff --git a/strawberry/experimental/pydantic2/object_type.py b/strawberry/experimental/pydantic2/object_type.py index dd4c928478..4fe2c132ca 100644 --- a/strawberry/experimental/pydantic2/object_type.py +++ b/strawberry/experimental/pydantic2/object_type.py @@ -39,7 +39,7 @@ if TYPE_CHECKING: from graphql import GraphQLResolveInfo - from pydantic.fields import FieldInfo, FieldInfo + from pydantic.fields import FieldInfo def is_required(field: FieldInfo) -> bool: diff --git a/strawberry/experimental/pydantic2/v2_compat.py b/strawberry/experimental/pydantic2/v2_compat.py index f965136d1c..f1b70b210c 100644 --- a/strawberry/experimental/pydantic2/v2_compat.py +++ b/strawberry/experimental/pydantic2/v2_compat.py @@ -1,17 +1,13 @@ import pydantic if pydantic.VERSION[0] == "2": - from pydantic._internal._utils import smart_deepcopy - from pydantic._internal._utils import lenient_issubclass - from typing_extensions import get_args, get_origin - from pydantic._internal._typing_extra import is_new_type + + from pydantic._internal._utils import lenient_issubclass, smart_deepcopy def new_type_supertype(type_): return type_.__supertype__ else: - from pydantic.utils import smart_deepcopy - from pydantic.utils import lenient_issubclass - from pydantic.typing import get_args, get_origin, is_new_type, new_type_supertype + from pydantic.utils import lenient_issubclass, smart_deepcopy __all__ = ["smart_deepcopy", "lenient_issubclass"] From 064db83a6dcc9904452dd83e41fb9e6b314e53f7 Mon Sep 17 00:00:00 2001 From: jameschua Date: Sun, 25 Jun 2023 23:15:22 +0800 Subject: [PATCH 10/26] remove compat file --- strawberry/experimental/pydantic/v2_compat.py | 13 ------------- strawberry/experimental/pydantic2/error_type.py | 2 +- strawberry/experimental/pydantic2/fields.py | 14 +++++++------- strawberry/experimental/pydantic2/v2_compat.py | 1 - 4 files changed, 8 insertions(+), 22 deletions(-) delete mode 100644 strawberry/experimental/pydantic/v2_compat.py diff --git a/strawberry/experimental/pydantic/v2_compat.py b/strawberry/experimental/pydantic/v2_compat.py deleted file mode 100644 index f1b70b210c..0000000000 --- a/strawberry/experimental/pydantic/v2_compat.py +++ /dev/null @@ -1,13 +0,0 @@ -import pydantic - -if pydantic.VERSION[0] == "2": - - from pydantic._internal._utils import lenient_issubclass, smart_deepcopy - - def new_type_supertype(type_): - return type_.__supertype__ - -else: - from pydantic.utils import lenient_issubclass, smart_deepcopy - -__all__ = ["smart_deepcopy", "lenient_issubclass"] diff --git a/strawberry/experimental/pydantic2/error_type.py b/strawberry/experimental/pydantic2/error_type.py index 495d47df74..8ec9209efc 100644 --- a/strawberry/experimental/pydantic2/error_type.py +++ b/strawberry/experimental/pydantic2/error_type.py @@ -16,6 +16,7 @@ ) from pydantic import BaseModel +from pydantic._internal._utils import lenient_issubclass from strawberry.auto import StrawberryAuto from strawberry.experimental.pydantic2.utils import ( @@ -23,7 +24,6 @@ get_strawberry_type_from_model, normalize_type, ) -from strawberry.experimental.pydantic2.v2_compat import lenient_issubclass from strawberry.object_type import _process_type, _wrap_dataclass from strawberry.types.type_resolver import _get_fields from strawberry.utils.typing import get_list_annotation, is_list diff --git a/strawberry/experimental/pydantic2/fields.py b/strawberry/experimental/pydantic2/fields.py index fb86983d90..ba5a9f0d56 100644 --- a/strawberry/experimental/pydantic2/fields.py +++ b/strawberry/experimental/pydantic2/fields.py @@ -3,18 +3,14 @@ import pydantic from pydantic import BaseModel +from pydantic._internal._typing_extra import is_new_type +from pydantic._internal._utils import lenient_issubclass +from typing_extensions import get_args, get_origin from strawberry.experimental.pydantic2.exceptions import ( UnregisteredTypeException, UnsupportedTypeError, ) -from strawberry.experimental.pydantic2.v2_compat import ( - get_args, - get_origin, - is_new_type, - lenient_issubclass, - new_type_supertype, -) from strawberry.types.types import TypeDefinition try: @@ -39,6 +35,10 @@ } +def new_type_supertype(type_): + return type_.__supertype__ + + def get_basic_type(type_: Any) -> Type[Any]: # if lenient_issubclass(type_, pydantic.ConstrainedInt): # if lenient_issubclass(type_, pydantic.ConstrainedFloat): diff --git a/strawberry/experimental/pydantic2/v2_compat.py b/strawberry/experimental/pydantic2/v2_compat.py index f1b70b210c..4109331c9e 100644 --- a/strawberry/experimental/pydantic2/v2_compat.py +++ b/strawberry/experimental/pydantic2/v2_compat.py @@ -1,7 +1,6 @@ import pydantic if pydantic.VERSION[0] == "2": - from pydantic._internal._utils import lenient_issubclass, smart_deepcopy def new_type_supertype(type_): From 274f2b6c13a201e7acfb4377c0afc805414e2962 Mon Sep 17 00:00:00 2001 From: jameschua Date: Fri, 7 Jul 2023 17:10:32 +0800 Subject: [PATCH 11/26] fix pydantic v2 update issues --- strawberry/experimental/__init__.py | 6 +++--- strawberry/experimental/pydantic2/object_type.py | 4 ++-- strawberry/experimental/pydantic2/utils.py | 10 +++++----- tests/experimental/pydantic2/test_basic.py | 6 +++--- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/strawberry/experimental/__init__.py b/strawberry/experimental/__init__.py index 0674ca0c11..985e43d40b 100644 --- a/strawberry/experimental/__init__.py +++ b/strawberry/experimental/__init__.py @@ -1,12 +1,12 @@ try: from . import pydantic - __all__ = ["pydantic"] except ImportError: pass try: from . import pydantic2 - + # Support for pydantic2 is highly experimental and the interface will change + # We don't recommend using it yet __all__ = ["pydantic2"] except ImportError as e: - print(e) + pass diff --git a/strawberry/experimental/pydantic2/object_type.py b/strawberry/experimental/pydantic2/object_type.py index 4fe2c132ca..82a02ac819 100644 --- a/strawberry/experimental/pydantic2/object_type.py +++ b/strawberry/experimental/pydantic2/object_type.py @@ -16,7 +16,7 @@ cast, ) -from pydantic._internal._fields import Undefined +from pydantic._internal._fields import PydanticUndefined from strawberry.annotation import StrawberryAnnotation from strawberry.auto import StrawberryAuto @@ -43,7 +43,7 @@ def is_required(field: FieldInfo) -> bool: - return field.default is Undefined and field.default_factory is None + return field.default is PydanticUndefined and field.default_factory is None def get_type_for_field(field: FieldInfo, is_input: bool): # noqa: ANN201 diff --git a/strawberry/experimental/pydantic2/utils.py b/strawberry/experimental/pydantic2/utils.py index c2c17327c6..efa1d9c6d0 100644 --- a/strawberry/experimental/pydantic2/utils.py +++ b/strawberry/experimental/pydantic2/utils.py @@ -14,7 +14,7 @@ cast, ) -from pydantic._internal._fields import Undefined, _UndefinedType +from pydantic_core import PydanticUndefinedType from pydantic._internal._utils import smart_deepcopy from strawberry.experimental.pydantic2.exceptions import ( @@ -77,11 +77,11 @@ def to_tuple(self) -> Tuple[str, Type, dataclasses.Field]: def is_required(field: ModelField) -> bool: - return field.default is Undefined and field.default_factory is None + return field.default is PydanticUndefinedType and field.default_factory is None def get_default_factory_for_field( - field: ModelField, + field: ModelField, ) -> Union[NoArgAnyCallable, dataclasses._MISSING_TYPE]: """ Gets the default factory for a pydantic field. @@ -95,7 +95,7 @@ def get_default_factory_for_field( default_factory = ( field.default_factory if field.default_factory is not None else UNSET ) - default = field.default if not isinstance(field.default, _UndefinedType) else UNSET + default = field.default if not isinstance(field.default, PydanticUndefinedType) else UNSET has_factory = default_factory is not UNSET has_default = default is not UNSET @@ -131,7 +131,7 @@ def get_default_factory_for_field( def ensure_all_auto_fields_in_pydantic( - model: Type[BaseModel], auto_fields: Set[str], cls_name: str + model: Type[BaseModel], auto_fields: Set[str], cls_name: str ) -> Union[NoReturn, None]: # Raise error if user defined a strawberry.auto field not present in the model non_existing_fields = list(auto_fields - model.model_fields.keys()) diff --git a/tests/experimental/pydantic2/test_basic.py b/tests/experimental/pydantic2/test_basic.py index ac42495296..520a984f62 100644 --- a/tests/experimental/pydantic2/test_basic.py +++ b/tests/experimental/pydantic2/test_basic.py @@ -335,7 +335,6 @@ class UserType: [groups_field, friends_field] = definition.fields assert groups_field.default is dataclasses.MISSING - assert groups_field.default_factory is dataclasses.MISSING assert friends_field.default is dataclasses.MISSING # check that we really made a copy @@ -627,8 +626,9 @@ def updateGroup(group: GroupInput) -> GroupOutput: type WorkOutput { time: Float! -}""" - assert schema.as_str().strip() == expected_schema.strip() +}""".strip() + result_schema = schema.as_str().strip() + assert result_schema == expected_schema assert Group._strawberry_type == GroupOutput assert Group._strawberry_input_type == GroupInput From f582d41a11dd7364337731016506cb442c417c34 Mon Sep 17 00:00:00 2001 From: jameschua Date: Fri, 7 Jul 2023 17:28:28 +0800 Subject: [PATCH 12/26] add test for #2782 --- strawberry/experimental/__init__.py | 2 + .../experimental/pydantic2/object_type.py | 12 ++--- strawberry/experimental/pydantic2/utils.py | 16 +++--- tests/experimental/pydantic2/test_basic.py | 49 +++++++++++++++++++ 4 files changed, 62 insertions(+), 17 deletions(-) diff --git a/strawberry/experimental/__init__.py b/strawberry/experimental/__init__.py index 985e43d40b..4c160f8284 100644 --- a/strawberry/experimental/__init__.py +++ b/strawberry/experimental/__init__.py @@ -1,10 +1,12 @@ try: from . import pydantic + __all__ = ["pydantic"] except ImportError: pass try: from . import pydantic2 + # Support for pydantic2 is highly experimental and the interface will change # We don't recommend using it yet __all__ = ["pydantic2"] diff --git a/strawberry/experimental/pydantic2/object_type.py b/strawberry/experimental/pydantic2/object_type.py index 82a02ac819..7e4a55077e 100644 --- a/strawberry/experimental/pydantic2/object_type.py +++ b/strawberry/experimental/pydantic2/object_type.py @@ -49,15 +49,9 @@ def is_required(field: FieldInfo) -> bool: def get_type_for_field(field: FieldInfo, is_input: bool): # noqa: ANN201 outer_type = field.annotation replaced_type = replace_types_recursively(outer_type, is_input) - - default_defined: bool = ( - field.default_factory is not None or field.default is not None - ) - should_add_optional: bool = not (is_required(field) or default_defined) - if should_add_optional: - return Optional[replaced_type] - else: - return replaced_type + # Note that unlike pydantic v1, pydantic v2 does not add a default of None when + # the field is Optional[something] + return replaced_type def _build_dataclass_creation_fields( diff --git a/strawberry/experimental/pydantic2/utils.py b/strawberry/experimental/pydantic2/utils.py index efa1d9c6d0..f6fa03b557 100644 --- a/strawberry/experimental/pydantic2/utils.py +++ b/strawberry/experimental/pydantic2/utils.py @@ -81,7 +81,7 @@ def is_required(field: ModelField) -> bool: def get_default_factory_for_field( - field: ModelField, + field: ModelField, ) -> Union[NoArgAnyCallable, dataclasses._MISSING_TYPE]: """ Gets the default factory for a pydantic field. @@ -95,7 +95,9 @@ def get_default_factory_for_field( default_factory = ( field.default_factory if field.default_factory is not None else UNSET ) - default = field.default if not isinstance(field.default, PydanticUndefinedType) else UNSET + default = ( + field.default if not isinstance(field.default, PydanticUndefinedType) else UNSET + ) has_factory = default_factory is not UNSET has_default = default is not UNSET @@ -121,17 +123,15 @@ def get_default_factory_for_field( if has_default: return lambda: smart_deepcopy(default) - # if we don't have default or default_factory, but the field is not required, - # we should return a factory that returns None - - if not is_required(field): - return lambda: None + # Note that unlike pydantic v1, pydantic v2 does not add a default of None when + # the field is Optional[something] + # so there is no need to handle that case here return dataclasses.MISSING def ensure_all_auto_fields_in_pydantic( - model: Type[BaseModel], auto_fields: Set[str], cls_name: str + model: Type[BaseModel], auto_fields: Set[str], cls_name: str ) -> Union[NoReturn, None]: # Raise error if user defined a strawberry.auto field not present in the model non_existing_fields = list(auto_fields - model.model_fields.keys()) diff --git a/tests/experimental/pydantic2/test_basic.py b/tests/experimental/pydantic2/test_basic.py index 520a984f62..c7057e6d1d 100644 --- a/tests/experimental/pydantic2/test_basic.py +++ b/tests/experimental/pydantic2/test_basic.py @@ -317,6 +317,55 @@ class UserType4: assert UserType4().to_pydantic().friend is None +def test_optional_and_default(): + class UserModel(pydantic.BaseModel): + age: int + name: str = pydantic.Field("Michael", description="The user name") + password: Optional[str] = pydantic.Field(default="ABC") + passwordtwo: Optional[str] = None + some_list: Optional[List[str]] = pydantic.Field(default_factory=list) + check: Optional[bool] = False + + @strawberry.experimental.pydantic2.type(UserModel, all_fields=True) + class User: + pass + + definition: TypeDefinition = User._type_definition + assert definition.name == "User" + + [ + age_field, + name_field, + password_field, + passwordtwo_field, + some_list_field, + check_field, + ] = definition.fields + + assert age_field.python_name == "age" + assert age_field.type is int + + assert name_field.python_name == "name" + assert name_field.type is str + + assert password_field.python_name == "password" + assert isinstance(password_field.type, StrawberryOptional) + assert password_field.type.of_type is str + + assert passwordtwo_field.python_name == "passwordtwo" + assert isinstance(passwordtwo_field.type, StrawberryOptional) + assert passwordtwo_field.type.of_type is str + + assert some_list_field.python_name == "some_list" + assert isinstance(some_list_field.type, StrawberryOptional) + assert isinstance(some_list_field.type.of_type, StrawberryList) + assert some_list_field.type.of_type.of_type is str + + assert check_field.python_name == "check" + assert isinstance(check_field.type, StrawberryOptional) + assert check_field.type.of_type is bool + + def test_type_with_fields_mutable_default(): empty_list = [] From 045856a0a1947b567cf5c2d7e28c314f2c19902a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Jul 2023 09:29:37 +0000 Subject: [PATCH 13/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry/experimental/pydantic2/fields.py | 2 +- strawberry/experimental/pydantic2/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/strawberry/experimental/pydantic2/fields.py b/strawberry/experimental/pydantic2/fields.py index ba5a9f0d56..6b91470e10 100644 --- a/strawberry/experimental/pydantic2/fields.py +++ b/strawberry/experimental/pydantic2/fields.py @@ -1,11 +1,11 @@ import builtins from typing import Any, Type +from typing_extensions import get_args, get_origin import pydantic from pydantic import BaseModel from pydantic._internal._typing_extra import is_new_type from pydantic._internal._utils import lenient_issubclass -from typing_extensions import get_args, get_origin from strawberry.experimental.pydantic2.exceptions import ( UnregisteredTypeException, diff --git a/strawberry/experimental/pydantic2/utils.py b/strawberry/experimental/pydantic2/utils.py index f6fa03b557..3061e55d50 100644 --- a/strawberry/experimental/pydantic2/utils.py +++ b/strawberry/experimental/pydantic2/utils.py @@ -14,8 +14,8 @@ cast, ) -from pydantic_core import PydanticUndefinedType from pydantic._internal._utils import smart_deepcopy +from pydantic_core import PydanticUndefinedType from strawberry.experimental.pydantic2.exceptions import ( AutoFieldsNotInBaseModelError, From 5139f7a967ff7ebb32b49cf5017d7862e78c868e Mon Sep 17 00:00:00 2001 From: jameschua Date: Tue, 11 Jul 2023 10:20:29 -0700 Subject: [PATCH 14/26] mark pydantic v2 explicitly --- noxfile.py | 25 ++++++++++++++++--- .../pydantic/schema/test_basic.py | 3 +++ .../pydantic/schema/test_defaults.py | 3 +++ .../pydantic/schema/test_federation.py | 3 +++ .../pydantic/schema/test_forward_reference.py | 2 ++ .../pydantic/schema/test_mutation.py | 2 ++ tests/experimental/pydantic/test_basic.py | 1 + .../experimental/pydantic/test_conversion.py | 1 + .../experimental/pydantic/test_error_type.py | 1 + tests/experimental/pydantic/test_fields.py | 2 +- .../pydantic2/schema/test_basic.py | 2 ++ tests/experimental/pydantic2/test_basic.py | 1 + tests/schema/test_pydantic.py | 2 +- 13 files changed, 43 insertions(+), 5 deletions(-) diff --git a/noxfile.py b/noxfile.py index 6630690efe..e42ac15a83 100644 --- a/noxfile.py +++ b/noxfile.py @@ -90,10 +90,29 @@ def tests_litestar(session: Session) -> None: ) -@session(python=["3.11"], name="Pydantic tests", tags=["tests"]) -# TODO: add pydantic 2.0 here :) +@session(python=["3.11"], name="Pydantic v1 tests", tags=["tests"]) @nox.parametrize("pydantic", ["1.10"]) -def test_pydantic(session: Session, pydantic: str) -> None: +def test_pydantic_v1(session: Session, pydantic: str) -> None: + session.run_always("poetry", "install", external=True) + + session._session.install(f"pydantic~={pydantic}") # type: ignore + + session.run( + "pytest", + "--cov=strawberry", + "--cov-append", + "--cov-report=xml", + "-n", + "auto", + "--showlocals", + "-vv", + "-m", + "pydantic", + ) + +@session(python=["3.11"], name="Pydantic v2 tests", tags=["tests"]) +@nox.parametrize("pydantic", ["2.0"]) +def test_pydantic_v2(session: Session, pydantic: str) -> None: session.run_always("poetry", "install", external=True) session._session.install(f"pydantic~={pydantic}") # type: ignore diff --git a/tests/experimental/pydantic/schema/test_basic.py b/tests/experimental/pydantic/schema/test_basic.py index 06a95a8fbd..c1e974ec2a 100644 --- a/tests/experimental/pydantic/schema/test_basic.py +++ b/tests/experimental/pydantic/schema/test_basic.py @@ -3,9 +3,12 @@ from typing import List, Optional, Union import pydantic +import pytest import strawberry +pytestmark = pytest.mark.pydantic_v1 + def test_basic_type_field_list(): class UserModel(pydantic.BaseModel): diff --git a/tests/experimental/pydantic/schema/test_defaults.py b/tests/experimental/pydantic/schema/test_defaults.py index 0917e85ec9..d491617a12 100644 --- a/tests/experimental/pydantic/schema/test_defaults.py +++ b/tests/experimental/pydantic/schema/test_defaults.py @@ -2,10 +2,13 @@ from typing import Optional import pydantic +import pytest import strawberry from strawberry.printer import print_schema +pytestmark = pytest.mark.pydantic_v1 + def test_field_type_default(): class User(pydantic.BaseModel): diff --git a/tests/experimental/pydantic/schema/test_federation.py b/tests/experimental/pydantic/schema/test_federation.py index 47bd56c2f9..03c7b7050e 100644 --- a/tests/experimental/pydantic/schema/test_federation.py +++ b/tests/experimental/pydantic/schema/test_federation.py @@ -1,10 +1,13 @@ import typing +import pytest from pydantic import BaseModel import strawberry from strawberry.federation.schema_directives import Key +pytestmark = pytest.mark.pydantic_v1 + def test_fetch_entities_pydantic(): class ProductInDb(BaseModel): diff --git a/tests/experimental/pydantic/schema/test_forward_reference.py b/tests/experimental/pydantic/schema/test_forward_reference.py index ebc94d4b37..74a111e3dc 100644 --- a/tests/experimental/pydantic/schema/test_forward_reference.py +++ b/tests/experimental/pydantic/schema/test_forward_reference.py @@ -4,9 +4,11 @@ from typing import Optional import pydantic +import pytest import strawberry +pytestmark = pytest.mark.pydantic_v1 def test_auto_fields(): global User diff --git a/tests/experimental/pydantic/schema/test_mutation.py b/tests/experimental/pydantic/schema/test_mutation.py index 5ae48f65a3..8b75e4d54c 100644 --- a/tests/experimental/pydantic/schema/test_mutation.py +++ b/tests/experimental/pydantic/schema/test_mutation.py @@ -1,9 +1,11 @@ from typing import Dict, List, Union import pydantic +import pytest import strawberry +pytestmark = pytest.mark.pydantic_v1 def test_mutation(): class User(pydantic.BaseModel): diff --git a/tests/experimental/pydantic/test_basic.py b/tests/experimental/pydantic/test_basic.py index f5b7280dec..442a037901 100644 --- a/tests/experimental/pydantic/test_basic.py +++ b/tests/experimental/pydantic/test_basic.py @@ -13,6 +13,7 @@ from strawberry.types.types import StrawberryObjectDefinition from strawberry.union import StrawberryUnion +pytestmark = pytest.mark.pydantic_v1 def test_basic_type_field_list(): class User(pydantic.BaseModel): diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index 300520d4ef..d77eb54b58 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -19,6 +19,7 @@ from strawberry.type import StrawberryList, StrawberryOptional from strawberry.types.types import StrawberryObjectDefinition +pytestmark = pytest.mark.pydantic_v1 def test_can_use_type_standalone(): class User(BaseModel): diff --git a/tests/experimental/pydantic/test_error_type.py b/tests/experimental/pydantic/test_error_type.py index 969c61b87a..3d460c7f40 100644 --- a/tests/experimental/pydantic/test_error_type.py +++ b/tests/experimental/pydantic/test_error_type.py @@ -8,6 +8,7 @@ from strawberry.type import StrawberryList, StrawberryOptional from strawberry.types.types import StrawberryObjectDefinition +pytestmark = pytest.mark.pydantic_v1 def test_basic_error_type_fields(): class UserModel(pydantic.BaseModel): diff --git a/tests/experimental/pydantic/test_fields.py b/tests/experimental/pydantic/test_fields.py index d553318766..9162307032 100644 --- a/tests/experimental/pydantic/test_fields.py +++ b/tests/experimental/pydantic/test_fields.py @@ -10,7 +10,7 @@ from strawberry.type import StrawberryOptional from strawberry.types.types import StrawberryObjectDefinition - +pytestmark = pytest.mark.pydantic_v1 @pytest.mark.parametrize( ("pydantic_type", "field_type"), [ diff --git a/tests/experimental/pydantic2/schema/test_basic.py b/tests/experimental/pydantic2/schema/test_basic.py index e72acbe6a4..e02e6d2ced 100644 --- a/tests/experimental/pydantic2/schema/test_basic.py +++ b/tests/experimental/pydantic2/schema/test_basic.py @@ -3,9 +3,11 @@ from typing import List, Optional, Union import pydantic +import pytest import strawberry +pytestmark = pytest.mark.pydantic_v2 def test_all_fields(): class UserModel(pydantic.BaseModel): diff --git a/tests/experimental/pydantic2/test_basic.py b/tests/experimental/pydantic2/test_basic.py index c7057e6d1d..0f7c9d7f29 100644 --- a/tests/experimental/pydantic2/test_basic.py +++ b/tests/experimental/pydantic2/test_basic.py @@ -13,6 +13,7 @@ from strawberry.types.types import TypeDefinition from strawberry.union import StrawberryUnion +pytestmark = pytest.mark.pydantic_v2 def test_basic_type_all_fields(): class User(pydantic.BaseModel): diff --git a/tests/schema/test_pydantic.py b/tests/schema/test_pydantic.py index 793862429b..f101bee694 100644 --- a/tests/schema/test_pydantic.py +++ b/tests/schema/test_pydantic.py @@ -3,7 +3,7 @@ import strawberry -pytestmark = pytest.mark.pydantic +pytestmark = pytest.mark.pydantic_v1 def test_use_alias_as_gql_name(): From 0e0d1a5224effd83bb54fa03b9ccbcd82c5b6903 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Jul 2023 17:21:39 +0000 Subject: [PATCH 15/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- noxfile.py | 1 + tests/experimental/pydantic/schema/test_forward_reference.py | 1 + tests/experimental/pydantic/schema/test_mutation.py | 1 + tests/experimental/pydantic/test_basic.py | 1 + tests/experimental/pydantic/test_conversion.py | 1 + tests/experimental/pydantic/test_error_type.py | 1 + tests/experimental/pydantic/test_fields.py | 2 ++ tests/experimental/pydantic2/schema/test_basic.py | 1 + tests/experimental/pydantic2/test_basic.py | 1 + 9 files changed, 10 insertions(+) diff --git a/noxfile.py b/noxfile.py index e42ac15a83..bb00360142 100644 --- a/noxfile.py +++ b/noxfile.py @@ -110,6 +110,7 @@ def test_pydantic_v1(session: Session, pydantic: str) -> None: "pydantic", ) + @session(python=["3.11"], name="Pydantic v2 tests", tags=["tests"]) @nox.parametrize("pydantic", ["2.0"]) def test_pydantic_v2(session: Session, pydantic: str) -> None: diff --git a/tests/experimental/pydantic/schema/test_forward_reference.py b/tests/experimental/pydantic/schema/test_forward_reference.py index 74a111e3dc..60c6c76eb6 100644 --- a/tests/experimental/pydantic/schema/test_forward_reference.py +++ b/tests/experimental/pydantic/schema/test_forward_reference.py @@ -10,6 +10,7 @@ pytestmark = pytest.mark.pydantic_v1 + def test_auto_fields(): global User diff --git a/tests/experimental/pydantic/schema/test_mutation.py b/tests/experimental/pydantic/schema/test_mutation.py index 8b75e4d54c..f0bf6b235f 100644 --- a/tests/experimental/pydantic/schema/test_mutation.py +++ b/tests/experimental/pydantic/schema/test_mutation.py @@ -7,6 +7,7 @@ pytestmark = pytest.mark.pydantic_v1 + def test_mutation(): class User(pydantic.BaseModel): name: pydantic.constr(min_length=2) diff --git a/tests/experimental/pydantic/test_basic.py b/tests/experimental/pydantic/test_basic.py index 442a037901..881c7a9352 100644 --- a/tests/experimental/pydantic/test_basic.py +++ b/tests/experimental/pydantic/test_basic.py @@ -15,6 +15,7 @@ pytestmark = pytest.mark.pydantic_v1 + def test_basic_type_field_list(): class User(pydantic.BaseModel): age: int diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index d77eb54b58..2a195600f6 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -21,6 +21,7 @@ pytestmark = pytest.mark.pydantic_v1 + def test_can_use_type_standalone(): class User(BaseModel): age: int diff --git a/tests/experimental/pydantic/test_error_type.py b/tests/experimental/pydantic/test_error_type.py index 3d460c7f40..d36bf1617e 100644 --- a/tests/experimental/pydantic/test_error_type.py +++ b/tests/experimental/pydantic/test_error_type.py @@ -10,6 +10,7 @@ pytestmark = pytest.mark.pydantic_v1 + def test_basic_error_type_fields(): class UserModel(pydantic.BaseModel): name: str diff --git a/tests/experimental/pydantic/test_fields.py b/tests/experimental/pydantic/test_fields.py index 9162307032..3ec8ba5a17 100644 --- a/tests/experimental/pydantic/test_fields.py +++ b/tests/experimental/pydantic/test_fields.py @@ -11,6 +11,8 @@ from strawberry.types.types import StrawberryObjectDefinition pytestmark = pytest.mark.pydantic_v1 + + @pytest.mark.parametrize( ("pydantic_type", "field_type"), [ diff --git a/tests/experimental/pydantic2/schema/test_basic.py b/tests/experimental/pydantic2/schema/test_basic.py index e02e6d2ced..1b4373108e 100644 --- a/tests/experimental/pydantic2/schema/test_basic.py +++ b/tests/experimental/pydantic2/schema/test_basic.py @@ -9,6 +9,7 @@ pytestmark = pytest.mark.pydantic_v2 + def test_all_fields(): class UserModel(pydantic.BaseModel): age: int diff --git a/tests/experimental/pydantic2/test_basic.py b/tests/experimental/pydantic2/test_basic.py index 0f7c9d7f29..7b311b94f0 100644 --- a/tests/experimental/pydantic2/test_basic.py +++ b/tests/experimental/pydantic2/test_basic.py @@ -15,6 +15,7 @@ pytestmark = pytest.mark.pydantic_v2 + def test_basic_type_all_fields(): class User(pydantic.BaseModel): age: int From 8fa12dc29eba6456ad3dc4d88ae8fb52fc357ce8 Mon Sep 17 00:00:00 2001 From: jameschua Date: Tue, 11 Jul 2023 10:35:35 -0700 Subject: [PATCH 16/26] add pytest markers for pydantic_v2 --- noxfile.py | 11 ++++++++--- pyproject.toml | 3 ++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/noxfile.py b/noxfile.py index bb00360142..df1d770b95 100644 --- a/noxfile.py +++ b/noxfile.py @@ -24,7 +24,10 @@ def tests(session: Session) -> None: "-m", "not starlite", "-m", - "not pydantic", + "not pydantic_v1", + "-m", + "not pydantic_v2", + "-m", "--ignore=tests/mypy", "--ignore=tests/pyright", ) @@ -93,6 +96,7 @@ def tests_litestar(session: Session) -> None: @session(python=["3.11"], name="Pydantic v1 tests", tags=["tests"]) @nox.parametrize("pydantic", ["1.10"]) def test_pydantic_v1(session: Session, pydantic: str) -> None: + # pydantic_v1 has different tests files than pydantic_v2 session.run_always("poetry", "install", external=True) session._session.install(f"pydantic~={pydantic}") # type: ignore @@ -107,13 +111,14 @@ def test_pydantic_v1(session: Session, pydantic: str) -> None: "--showlocals", "-vv", "-m", - "pydantic", + "pydantic_v1", ) @session(python=["3.11"], name="Pydantic v2 tests", tags=["tests"]) @nox.parametrize("pydantic", ["2.0"]) def test_pydantic_v2(session: Session, pydantic: str) -> None: + # pydantic_v1 has different tests files than pydantic_v2 session.run_always("poetry", "install", external=True) session._session.install(f"pydantic~={pydantic}") # type: ignore @@ -128,7 +133,7 @@ def test_pydantic_v2(session: Session, pydantic: str) -> None: "--showlocals", "-vv", "-m", - "pydantic", + "pydantic_v2", ) diff --git a/pyproject.toml b/pyproject.toml index 31237a158d..3ee78775e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -159,7 +159,8 @@ markers = [ "chalice", "flask", "starlite", - "pydantic", + "pydantic_v1", + "pydantic_v2", "flaky", "relay", ] From c7a9b8cb53f7bf6f7f8f230492036d577303100e Mon Sep 17 00:00:00 2001 From: jameschua Date: Tue, 11 Jul 2023 11:29:40 -0700 Subject: [PATCH 17/26] try again --- noxfile.py | 1 - 1 file changed, 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index df1d770b95..cf8deb4055 100644 --- a/noxfile.py +++ b/noxfile.py @@ -27,7 +27,6 @@ def tests(session: Session) -> None: "not pydantic_v1", "-m", "not pydantic_v2", - "-m", "--ignore=tests/mypy", "--ignore=tests/pyright", ) From 6be92be367a8c80f9cf89be78cfca05adcb828be Mon Sep 17 00:00:00 2001 From: jameschua Date: Wed, 12 Jul 2023 11:57:48 -0700 Subject: [PATCH 18/26] add ignore pydantic2 --- noxfile.py | 2 ++ tests/experimental/pydantic2/schema/test_basic.py | 1 - tests/experimental/pydantic2/test_basic.py | 7 ++++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/noxfile.py b/noxfile.py index cf8deb4055..0f10f470e8 100644 --- a/noxfile.py +++ b/noxfile.py @@ -133,6 +133,8 @@ def test_pydantic_v2(session: Session, pydantic: str) -> None: "-vv", "-m", "pydantic_v2", + "--ignore", + "tests/experimental/pydantic2", ) diff --git a/tests/experimental/pydantic2/schema/test_basic.py b/tests/experimental/pydantic2/schema/test_basic.py index 1b4373108e..e02e6d2ced 100644 --- a/tests/experimental/pydantic2/schema/test_basic.py +++ b/tests/experimental/pydantic2/schema/test_basic.py @@ -9,7 +9,6 @@ pytestmark = pytest.mark.pydantic_v2 - def test_all_fields(): class UserModel(pydantic.BaseModel): age: int diff --git a/tests/experimental/pydantic2/test_basic.py b/tests/experimental/pydantic2/test_basic.py index 7b311b94f0..685e56598a 100644 --- a/tests/experimental/pydantic2/test_basic.py +++ b/tests/experimental/pydantic2/test_basic.py @@ -7,7 +7,7 @@ import strawberry from strawberry.enum import EnumDefinition -from strawberry.experimental.pydantic2.exceptions import MissingFieldsListError + from strawberry.schema_directive import Location from strawberry.type import StrawberryList, StrawberryOptional from strawberry.types.types import TypeDefinition @@ -15,6 +15,11 @@ pytestmark = pytest.mark.pydantic_v2 +if pydantic.__version__ >= "2.0.0": + # pydantic v2 imports need to be here to avoid import errors when running + # noxfile tests with pydantic v1 + # otherwise you need to add explicit directory exclusions for this folder + from strawberry.experimental.pydantic2.exceptions import MissingFieldsListError def test_basic_type_all_fields(): class User(pydantic.BaseModel): From fecd02d8e2c18744f1062c13513a397c54361e38 Mon Sep 17 00:00:00 2001 From: jameschua Date: Wed, 12 Jul 2023 12:07:23 -0700 Subject: [PATCH 19/26] add explicit dir --- noxfile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/noxfile.py b/noxfile.py index 0f10f470e8..53f1d7ce98 100644 --- a/noxfile.py +++ b/noxfile.py @@ -102,6 +102,7 @@ def test_pydantic_v1(session: Session, pydantic: str) -> None: session.run( "pytest", + "tests/experimental/pydantic", "--cov=strawberry", "--cov-append", "--cov-report=xml", @@ -124,6 +125,7 @@ def test_pydantic_v2(session: Session, pydantic: str) -> None: session.run( "pytest", + "tests/experimental/pydantic2", "--cov=strawberry", "--cov-append", "--cov-report=xml", @@ -133,8 +135,6 @@ def test_pydantic_v2(session: Session, pydantic: str) -> None: "-vv", "-m", "pydantic_v2", - "--ignore", - "tests/experimental/pydantic2", ) From 883b8a4f424003d848597d05e0254825a66cde64 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Jul 2023 19:07:54 +0000 Subject: [PATCH 20/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/experimental/pydantic2/schema/test_basic.py | 1 + tests/experimental/pydantic2/test_basic.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/experimental/pydantic2/schema/test_basic.py b/tests/experimental/pydantic2/schema/test_basic.py index e02e6d2ced..1b4373108e 100644 --- a/tests/experimental/pydantic2/schema/test_basic.py +++ b/tests/experimental/pydantic2/schema/test_basic.py @@ -9,6 +9,7 @@ pytestmark = pytest.mark.pydantic_v2 + def test_all_fields(): class UserModel(pydantic.BaseModel): age: int diff --git a/tests/experimental/pydantic2/test_basic.py b/tests/experimental/pydantic2/test_basic.py index 685e56598a..72766335f7 100644 --- a/tests/experimental/pydantic2/test_basic.py +++ b/tests/experimental/pydantic2/test_basic.py @@ -7,7 +7,6 @@ import strawberry from strawberry.enum import EnumDefinition - from strawberry.schema_directive import Location from strawberry.type import StrawberryList, StrawberryOptional from strawberry.types.types import TypeDefinition @@ -21,6 +20,7 @@ # otherwise you need to add explicit directory exclusions for this folder from strawberry.experimental.pydantic2.exceptions import MissingFieldsListError + def test_basic_type_all_fields(): class User(pydantic.BaseModel): age: int From e6fd48b5a39a67fe8320c418e0d2c5779c671013 Mon Sep 17 00:00:00 2001 From: jameschua Date: Wed, 12 Jul 2023 12:29:36 -0700 Subject: [PATCH 21/26] remove markers --- noxfile.py | 9 --------- pyproject.toml | 2 -- tests/experimental/pydantic/schema/test_basic.py | 3 --- tests/experimental/pydantic/schema/test_defaults.py | 3 --- tests/experimental/pydantic/schema/test_federation.py | 3 --- .../pydantic/schema/test_forward_reference.py | 3 --- tests/experimental/pydantic/schema/test_mutation.py | 3 --- tests/experimental/pydantic/test_basic.py | 2 -- tests/experimental/pydantic/test_conversion.py | 2 -- tests/experimental/pydantic/test_error_type.py | 2 -- tests/experimental/pydantic/test_fields.py | 2 -- tests/experimental/pydantic2/schema/test_basic.py | 2 -- tests/experimental/pydantic2/test_basic.py | 1 - tests/schema/test_pydantic.py | 3 --- 14 files changed, 40 deletions(-) diff --git a/noxfile.py b/noxfile.py index 53f1d7ce98..fa03a7bdb2 100644 --- a/noxfile.py +++ b/noxfile.py @@ -24,9 +24,6 @@ def tests(session: Session) -> None: "-m", "not starlite", "-m", - "not pydantic_v1", - "-m", - "not pydantic_v2", "--ignore=tests/mypy", "--ignore=tests/pyright", ) @@ -95,7 +92,6 @@ def tests_litestar(session: Session) -> None: @session(python=["3.11"], name="Pydantic v1 tests", tags=["tests"]) @nox.parametrize("pydantic", ["1.10"]) def test_pydantic_v1(session: Session, pydantic: str) -> None: - # pydantic_v1 has different tests files than pydantic_v2 session.run_always("poetry", "install", external=True) session._session.install(f"pydantic~={pydantic}") # type: ignore @@ -110,15 +106,12 @@ def test_pydantic_v1(session: Session, pydantic: str) -> None: "auto", "--showlocals", "-vv", - "-m", - "pydantic_v1", ) @session(python=["3.11"], name="Pydantic v2 tests", tags=["tests"]) @nox.parametrize("pydantic", ["2.0"]) def test_pydantic_v2(session: Session, pydantic: str) -> None: - # pydantic_v1 has different tests files than pydantic_v2 session.run_always("poetry", "install", external=True) session._session.install(f"pydantic~={pydantic}") # type: ignore @@ -133,8 +126,6 @@ def test_pydantic_v2(session: Session, pydantic: str) -> None: "auto", "--showlocals", "-vv", - "-m", - "pydantic_v2", ) diff --git a/pyproject.toml b/pyproject.toml index 3ee78775e8..216e68aee3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -159,8 +159,6 @@ markers = [ "chalice", "flask", "starlite", - "pydantic_v1", - "pydantic_v2", "flaky", "relay", ] diff --git a/tests/experimental/pydantic/schema/test_basic.py b/tests/experimental/pydantic/schema/test_basic.py index c1e974ec2a..06a95a8fbd 100644 --- a/tests/experimental/pydantic/schema/test_basic.py +++ b/tests/experimental/pydantic/schema/test_basic.py @@ -3,12 +3,9 @@ from typing import List, Optional, Union import pydantic -import pytest import strawberry -pytestmark = pytest.mark.pydantic_v1 - def test_basic_type_field_list(): class UserModel(pydantic.BaseModel): diff --git a/tests/experimental/pydantic/schema/test_defaults.py b/tests/experimental/pydantic/schema/test_defaults.py index d491617a12..0917e85ec9 100644 --- a/tests/experimental/pydantic/schema/test_defaults.py +++ b/tests/experimental/pydantic/schema/test_defaults.py @@ -2,13 +2,10 @@ from typing import Optional import pydantic -import pytest import strawberry from strawberry.printer import print_schema -pytestmark = pytest.mark.pydantic_v1 - def test_field_type_default(): class User(pydantic.BaseModel): diff --git a/tests/experimental/pydantic/schema/test_federation.py b/tests/experimental/pydantic/schema/test_federation.py index 03c7b7050e..47bd56c2f9 100644 --- a/tests/experimental/pydantic/schema/test_federation.py +++ b/tests/experimental/pydantic/schema/test_federation.py @@ -1,13 +1,10 @@ import typing -import pytest from pydantic import BaseModel import strawberry from strawberry.federation.schema_directives import Key -pytestmark = pytest.mark.pydantic_v1 - def test_fetch_entities_pydantic(): class ProductInDb(BaseModel): diff --git a/tests/experimental/pydantic/schema/test_forward_reference.py b/tests/experimental/pydantic/schema/test_forward_reference.py index 60c6c76eb6..ebc94d4b37 100644 --- a/tests/experimental/pydantic/schema/test_forward_reference.py +++ b/tests/experimental/pydantic/schema/test_forward_reference.py @@ -4,12 +4,9 @@ from typing import Optional import pydantic -import pytest import strawberry -pytestmark = pytest.mark.pydantic_v1 - def test_auto_fields(): global User diff --git a/tests/experimental/pydantic/schema/test_mutation.py b/tests/experimental/pydantic/schema/test_mutation.py index f0bf6b235f..5ae48f65a3 100644 --- a/tests/experimental/pydantic/schema/test_mutation.py +++ b/tests/experimental/pydantic/schema/test_mutation.py @@ -1,12 +1,9 @@ from typing import Dict, List, Union import pydantic -import pytest import strawberry -pytestmark = pytest.mark.pydantic_v1 - def test_mutation(): class User(pydantic.BaseModel): diff --git a/tests/experimental/pydantic/test_basic.py b/tests/experimental/pydantic/test_basic.py index 881c7a9352..f5b7280dec 100644 --- a/tests/experimental/pydantic/test_basic.py +++ b/tests/experimental/pydantic/test_basic.py @@ -13,8 +13,6 @@ from strawberry.types.types import StrawberryObjectDefinition from strawberry.union import StrawberryUnion -pytestmark = pytest.mark.pydantic_v1 - def test_basic_type_field_list(): class User(pydantic.BaseModel): diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index 2a195600f6..300520d4ef 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -19,8 +19,6 @@ from strawberry.type import StrawberryList, StrawberryOptional from strawberry.types.types import StrawberryObjectDefinition -pytestmark = pytest.mark.pydantic_v1 - def test_can_use_type_standalone(): class User(BaseModel): diff --git a/tests/experimental/pydantic/test_error_type.py b/tests/experimental/pydantic/test_error_type.py index d36bf1617e..969c61b87a 100644 --- a/tests/experimental/pydantic/test_error_type.py +++ b/tests/experimental/pydantic/test_error_type.py @@ -8,8 +8,6 @@ from strawberry.type import StrawberryList, StrawberryOptional from strawberry.types.types import StrawberryObjectDefinition -pytestmark = pytest.mark.pydantic_v1 - def test_basic_error_type_fields(): class UserModel(pydantic.BaseModel): diff --git a/tests/experimental/pydantic/test_fields.py b/tests/experimental/pydantic/test_fields.py index 3ec8ba5a17..d553318766 100644 --- a/tests/experimental/pydantic/test_fields.py +++ b/tests/experimental/pydantic/test_fields.py @@ -10,8 +10,6 @@ from strawberry.type import StrawberryOptional from strawberry.types.types import StrawberryObjectDefinition -pytestmark = pytest.mark.pydantic_v1 - @pytest.mark.parametrize( ("pydantic_type", "field_type"), diff --git a/tests/experimental/pydantic2/schema/test_basic.py b/tests/experimental/pydantic2/schema/test_basic.py index 1b4373108e..1cbcb9bfb5 100644 --- a/tests/experimental/pydantic2/schema/test_basic.py +++ b/tests/experimental/pydantic2/schema/test_basic.py @@ -3,11 +3,9 @@ from typing import List, Optional, Union import pydantic -import pytest import strawberry -pytestmark = pytest.mark.pydantic_v2 def test_all_fields(): diff --git a/tests/experimental/pydantic2/test_basic.py b/tests/experimental/pydantic2/test_basic.py index 72766335f7..383ea4acd6 100644 --- a/tests/experimental/pydantic2/test_basic.py +++ b/tests/experimental/pydantic2/test_basic.py @@ -12,7 +12,6 @@ from strawberry.types.types import TypeDefinition from strawberry.union import StrawberryUnion -pytestmark = pytest.mark.pydantic_v2 if pydantic.__version__ >= "2.0.0": # pydantic v2 imports need to be here to avoid import errors when running diff --git a/tests/schema/test_pydantic.py b/tests/schema/test_pydantic.py index f101bee694..f33a6f3a2d 100644 --- a/tests/schema/test_pydantic.py +++ b/tests/schema/test_pydantic.py @@ -3,9 +3,6 @@ import strawberry -pytestmark = pytest.mark.pydantic_v1 - - def test_use_alias_as_gql_name(): class UserModel(BaseModel): age_: int = Field(..., alias="age_alias") From 858998c4d902216698d4dc99c9c92997911654ea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Jul 2023 19:30:02 +0000 Subject: [PATCH 22/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/experimental/pydantic2/schema/test_basic.py | 1 - tests/experimental/pydantic2/test_basic.py | 1 - tests/schema/test_pydantic.py | 2 +- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/experimental/pydantic2/schema/test_basic.py b/tests/experimental/pydantic2/schema/test_basic.py index 1cbcb9bfb5..e72acbe6a4 100644 --- a/tests/experimental/pydantic2/schema/test_basic.py +++ b/tests/experimental/pydantic2/schema/test_basic.py @@ -7,7 +7,6 @@ import strawberry - def test_all_fields(): class UserModel(pydantic.BaseModel): age: int diff --git a/tests/experimental/pydantic2/test_basic.py b/tests/experimental/pydantic2/test_basic.py index 383ea4acd6..fbb6fd45e2 100644 --- a/tests/experimental/pydantic2/test_basic.py +++ b/tests/experimental/pydantic2/test_basic.py @@ -12,7 +12,6 @@ from strawberry.types.types import TypeDefinition from strawberry.union import StrawberryUnion - if pydantic.__version__ >= "2.0.0": # pydantic v2 imports need to be here to avoid import errors when running # noxfile tests with pydantic v1 diff --git a/tests/schema/test_pydantic.py b/tests/schema/test_pydantic.py index f33a6f3a2d..f4dc07551a 100644 --- a/tests/schema/test_pydantic.py +++ b/tests/schema/test_pydantic.py @@ -1,8 +1,8 @@ -import pytest from pydantic import BaseModel, Field import strawberry + def test_use_alias_as_gql_name(): class UserModel(BaseModel): age_: int = Field(..., alias="age_alias") From ae9eff1facb99db15451c70cb7452706e390af77 Mon Sep 17 00:00:00 2001 From: jameschua Date: Wed, 12 Jul 2023 13:09:26 -0700 Subject: [PATCH 23/26] add ignore for test --- noxfile.py | 3 ++- strawberry/experimental/pydantic2/v2_compat.py | 12 ------------ 2 files changed, 2 insertions(+), 13 deletions(-) delete mode 100644 strawberry/experimental/pydantic2/v2_compat.py diff --git a/noxfile.py b/noxfile.py index fa03a7bdb2..a467d4c30c 100644 --- a/noxfile.py +++ b/noxfile.py @@ -23,9 +23,10 @@ def tests(session: Session) -> None: "not django", "-m", "not starlite", - "-m", "--ignore=tests/mypy", "--ignore=tests/pyright", + "--ignore=tests/experimental/pydantic", + "--ignore=tests/experimental/pydantic2", ) diff --git a/strawberry/experimental/pydantic2/v2_compat.py b/strawberry/experimental/pydantic2/v2_compat.py deleted file mode 100644 index 4109331c9e..0000000000 --- a/strawberry/experimental/pydantic2/v2_compat.py +++ /dev/null @@ -1,12 +0,0 @@ -import pydantic - -if pydantic.VERSION[0] == "2": - from pydantic._internal._utils import lenient_issubclass, smart_deepcopy - - def new_type_supertype(type_): - return type_.__supertype__ - -else: - from pydantic.utils import lenient_issubclass, smart_deepcopy - -__all__ = ["smart_deepcopy", "lenient_issubclass"] From 42203c66430dc5b885404212ba611effa065d69f Mon Sep 17 00:00:00 2001 From: jameschua Date: Wed, 12 Jul 2023 18:27:26 -0700 Subject: [PATCH 24/26] add hints --- strawberry/experimental/pydantic2/conversion.py | 2 +- strawberry/experimental/pydantic2/conversion_types.py | 4 ++-- strawberry/experimental/pydantic2/fields.py | 4 ++-- strawberry/experimental/pydantic2/object_type.py | 2 +- tests/cli/snapshots/unions.py | 3 ++- tests/cli/snapshots/unions_py38.py | 3 ++- tests/cli/snapshots/unions_typing_extension.py | 3 ++- tests/experimental/pydantic/test_conversion.py | 4 ++-- 8 files changed, 14 insertions(+), 11 deletions(-) diff --git a/strawberry/experimental/pydantic2/conversion.py b/strawberry/experimental/pydantic2/conversion.py index e138e1e807..3127307b3e 100644 --- a/strawberry/experimental/pydantic2/conversion.py +++ b/strawberry/experimental/pydantic2/conversion.py @@ -9,6 +9,7 @@ from strawberry.union import StrawberryUnion if TYPE_CHECKING: + from strawberry.field import StrawberryField from strawberry.type import StrawberryType @@ -62,7 +63,6 @@ def _convert_from_pydantic_to_strawberry_type( return data - def convert_pydantic_model_to_strawberry_class( cls, *, model_instance=None, extra=None # noqa: ANN001 ) -> Any: diff --git a/strawberry/experimental/pydantic2/conversion_types.py b/strawberry/experimental/pydantic2/conversion_types.py index aca9cdccd9..69ca6ae054 100644 --- a/strawberry/experimental/pydantic2/conversion_types.py +++ b/strawberry/experimental/pydantic2/conversion_types.py @@ -16,7 +16,7 @@ class StrawberryTypeFromPydantic(Protocol[PydanticModel]): """This class does not exist in runtime. It only makes the methods below visible for IDEs""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): ... @staticmethod @@ -25,7 +25,7 @@ def from_pydantic( ) -> StrawberryTypeFromPydantic[PydanticModel]: ... - def to_pydantic(self, **kwargs) -> PydanticModel: + def to_pydantic(self, **kwargs: Any) -> PydanticModel: ... @property diff --git a/strawberry/experimental/pydantic2/fields.py b/strawberry/experimental/pydantic2/fields.py index 6b91470e10..00bea43c61 100644 --- a/strawberry/experimental/pydantic2/fields.py +++ b/strawberry/experimental/pydantic2/fields.py @@ -1,5 +1,5 @@ import builtins -from typing import Any, Type +from typing import Any, NewType, Type from typing_extensions import get_args, get_origin import pydantic @@ -35,7 +35,7 @@ } -def new_type_supertype(type_): +def new_type_supertype(type_: NewType) -> Type[Any]: return type_.__supertype__ diff --git a/strawberry/experimental/pydantic2/object_type.py b/strawberry/experimental/pydantic2/object_type.py index 7e4a55077e..5b829d18b7 100644 --- a/strawberry/experimental/pydantic2/object_type.py +++ b/strawberry/experimental/pydantic2/object_type.py @@ -265,7 +265,7 @@ def from_pydantic_default( ret._original_model = instance return ret - def to_pydantic_default(self: Any, **kwargs) -> PydanticModel: + def to_pydantic_default(self: Any, **kwargs: Any) -> PydanticModel: instance_kwargs = { f.name: convert_strawberry_class_to_pydantic_model( getattr(self, f.name) diff --git a/tests/cli/snapshots/unions.py b/tests/cli/snapshots/unions.py index 2f03e0bb4b..d725469ce5 100644 --- a/tests/cli/snapshots/unions.py +++ b/tests/cli/snapshots/unions.py @@ -1,6 +1,7 @@ -import strawberry from typing import Annotated +import strawberry + # create a few types and then a union type diff --git a/tests/cli/snapshots/unions_py38.py b/tests/cli/snapshots/unions_py38.py index 4fd0143c69..b7ead3da65 100644 --- a/tests/cli/snapshots/unions_py38.py +++ b/tests/cli/snapshots/unions_py38.py @@ -1,6 +1,7 @@ -import strawberry from typing import Annotated, Union +import strawberry + # create a few types and then a union type diff --git a/tests/cli/snapshots/unions_typing_extension.py b/tests/cli/snapshots/unions_typing_extension.py index 11c5f01cfb..088fab1d24 100644 --- a/tests/cli/snapshots/unions_typing_extension.py +++ b/tests/cli/snapshots/unions_typing_extension.py @@ -1,6 +1,7 @@ -import strawberry from typing_extensions import Annotated +import strawberry + # create a few types and then a union type diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index 300520d4ef..a5e5efaaa7 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -984,7 +984,7 @@ class UserType: password: strawberry.auto @staticmethod - def from_pydantic(instance: User, extra: Dict[str, Any] = None) -> "UserType": + def from_pydantic(instance: User, extra: Optional[Dict[str, Any]] = None) -> "UserType": return UserType( age=str(instance.age), password=base64.b64encode(instance.password.encode()).decode() @@ -1024,7 +1024,7 @@ class UserType: password: strawberry.auto @staticmethod - def from_pydantic(instance: User, extra: Dict[str, Any] = None) -> "UserType": + def from_pydantic(instance: User, extra: Optional[Dict[str, Any]] = None) -> "UserType": return UserType( age=str(instance.age), password=base64.b64encode(instance.password.encode()).decode() From 3d97786ba5ae3a44032894c73712630e440b6883 Mon Sep 17 00:00:00 2001 From: jameschua Date: Wed, 12 Jul 2023 20:06:20 -0700 Subject: [PATCH 25/26] fix weird cli tests changes --- tests/cli/snapshots/unions.py | 3 +-- tests/cli/snapshots/unions_py38.py | 3 +-- tests/cli/snapshots/unions_typing_extension.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/cli/snapshots/unions.py b/tests/cli/snapshots/unions.py index d725469ce5..2f03e0bb4b 100644 --- a/tests/cli/snapshots/unions.py +++ b/tests/cli/snapshots/unions.py @@ -1,6 +1,5 @@ -from typing import Annotated - import strawberry +from typing import Annotated # create a few types and then a union type diff --git a/tests/cli/snapshots/unions_py38.py b/tests/cli/snapshots/unions_py38.py index b7ead3da65..4fd0143c69 100644 --- a/tests/cli/snapshots/unions_py38.py +++ b/tests/cli/snapshots/unions_py38.py @@ -1,6 +1,5 @@ -from typing import Annotated, Union - import strawberry +from typing import Annotated, Union # create a few types and then a union type diff --git a/tests/cli/snapshots/unions_typing_extension.py b/tests/cli/snapshots/unions_typing_extension.py index 088fab1d24..11c5f01cfb 100644 --- a/tests/cli/snapshots/unions_typing_extension.py +++ b/tests/cli/snapshots/unions_typing_extension.py @@ -1,6 +1,5 @@ -from typing_extensions import Annotated - import strawberry +from typing_extensions import Annotated # create a few types and then a union type From 227736d93f42e9b91a8db34364b7f31e43373c92 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Jul 2023 03:07:01 +0000 Subject: [PATCH 26/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry/experimental/pydantic2/conversion.py | 1 + tests/experimental/pydantic/test_conversion.py | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/strawberry/experimental/pydantic2/conversion.py b/strawberry/experimental/pydantic2/conversion.py index 3127307b3e..bc0f787948 100644 --- a/strawberry/experimental/pydantic2/conversion.py +++ b/strawberry/experimental/pydantic2/conversion.py @@ -63,6 +63,7 @@ def _convert_from_pydantic_to_strawberry_type( return data + def convert_pydantic_model_to_strawberry_class( cls, *, model_instance=None, extra=None # noqa: ANN001 ) -> Any: diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index a5e5efaaa7..5303be0562 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -984,7 +984,9 @@ class UserType: password: strawberry.auto @staticmethod - def from_pydantic(instance: User, extra: Optional[Dict[str, Any]] = None) -> "UserType": + def from_pydantic( + instance: User, extra: Optional[Dict[str, Any]] = None + ) -> "UserType": return UserType( age=str(instance.age), password=base64.b64encode(instance.password.encode()).decode() @@ -1024,7 +1026,9 @@ class UserType: password: strawberry.auto @staticmethod - def from_pydantic(instance: User, extra: Optional[Dict[str, Any]] = None) -> "UserType": + def from_pydantic( + instance: User, extra: Optional[Dict[str, Any]] = None + ) -> "UserType": return UserType( age=str(instance.age), password=base64.b64encode(instance.password.encode()).decode()