diff --git a/noxfile.py b/noxfile.py index 6630690efe..a467d4c30c 100644 --- a/noxfile.py +++ b/noxfile.py @@ -23,10 +23,10 @@ def tests(session: Session) -> None: "not django", "-m", "not starlite", - "-m", - "not pydantic", "--ignore=tests/mypy", "--ignore=tests/pyright", + "--ignore=tests/experimental/pydantic", + "--ignore=tests/experimental/pydantic2", ) @@ -90,16 +90,36 @@ def tests_litestar(session: Session) -> None: ) -@session(python=["3.11"], name="Pydantic tests", tags=["tests"]) -# TODO: add pydantic 2.0 here :) +@session(python=["3.11"], name="Pydantic v1 tests", tags=["tests"]) @nox.parametrize("pydantic", ["1.10"]) -def test_pydantic(session: Session, pydantic: str) -> None: +def test_pydantic_v1(session: Session, pydantic: str) -> None: session.run_always("poetry", "install", external=True) session._session.install(f"pydantic~={pydantic}") # type: ignore session.run( "pytest", + "tests/experimental/pydantic", + "--cov=strawberry", + "--cov-append", + "--cov-report=xml", + "-n", + "auto", + "--showlocals", + "-vv", + ) + + +@session(python=["3.11"], name="Pydantic v2 tests", tags=["tests"]) +@nox.parametrize("pydantic", ["2.0"]) +def test_pydantic_v2(session: Session, pydantic: str) -> None: + session.run_always("poetry", "install", external=True) + + session._session.install(f"pydantic~={pydantic}") # type: ignore + + session.run( + "pytest", + "tests/experimental/pydantic2", "--cov=strawberry", "--cov-append", "--cov-report=xml", @@ -107,8 +127,6 @@ def test_pydantic(session: Session, pydantic: str) -> None: "auto", "--showlocals", "-vv", - "-m", - "pydantic", ) diff --git a/pyproject.toml b/pyproject.toml index 31237a158d..216e68aee3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -159,7 +159,6 @@ markers = [ "chalice", "flask", "starlite", - "pydantic", "flaky", "relay", ] diff --git a/strawberry/experimental/__init__.py b/strawberry/experimental/__init__.py index 6386ad81d7..4c160f8284 100644 --- a/strawberry/experimental/__init__.py +++ b/strawberry/experimental/__init__.py @@ -1,6 +1,14 @@ try: from . import pydantic + + __all__ = ["pydantic"] except ImportError: pass -else: - __all__ = ["pydantic"] +try: + from . import pydantic2 + + # Support for pydantic2 is highly experimental and the interface will change + # We don't recommend using it yet + __all__ = ["pydantic2"] +except ImportError as e: + pass diff --git a/strawberry/experimental/pydantic2/__init__.py b/strawberry/experimental/pydantic2/__init__.py new file mode 100644 index 0000000000..10f382650d --- /dev/null +++ b/strawberry/experimental/pydantic2/__init__.py @@ -0,0 +1,11 @@ +from .error_type import error_type +from .exceptions import UnregisteredTypeException +from .object_type import input, interface, type + +__all__ = [ + "error_type", + "UnregisteredTypeException", + "input", + "type", + "interface", +] diff --git a/strawberry/experimental/pydantic2/conversion.py b/strawberry/experimental/pydantic2/conversion.py new file mode 100644 index 0000000000..bc0f787948 --- /dev/null +++ b/strawberry/experimental/pydantic2/conversion.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import copy +import dataclasses +from typing import TYPE_CHECKING, Any, Type, Union, cast + +from strawberry.enum import EnumDefinition +from strawberry.type import StrawberryList, StrawberryOptional +from strawberry.union import StrawberryUnion + +if TYPE_CHECKING: + from strawberry.field import StrawberryField + from strawberry.type import StrawberryType + + +def _convert_from_pydantic_to_strawberry_type( + type_: Union[StrawberryType, type], data_from_model=None, extra=None # noqa: ANN001 +): + data = data_from_model if data_from_model is not None else extra + + if isinstance(type_, StrawberryOptional): + if data is None: + return data + return _convert_from_pydantic_to_strawberry_type( + type_.of_type, data_from_model=data, extra=extra + ) + if isinstance(type_, StrawberryUnion): + for option_type in type_.types: + if hasattr(option_type, "_pydantic_type"): + source_type = option_type._pydantic_type + else: + source_type = cast(type, option_type) + if isinstance(data, source_type): + return _convert_from_pydantic_to_strawberry_type( + option_type, data_from_model=data, extra=extra + ) + if isinstance(type_, EnumDefinition): + return data + if isinstance(type_, StrawberryList): + items = [] + for index, item in enumerate(data): + items.append( + _convert_from_pydantic_to_strawberry_type( + type_.of_type, + data_from_model=item, + extra=extra[index] if extra else None, + ) + ) + + return items + + if hasattr(type_, "_type_definition"): + # in the case of an interface, the concrete type may be more specific + # than the type in the field definition + # don't check _strawberry_input_type because inputs can't be interfaces + if hasattr(type(data), "_strawberry_type"): + type_ = type(data)._strawberry_type + if hasattr(type_, "from_pydantic"): + return type_.from_pydantic(data_from_model, extra) + return convert_pydantic_model_to_strawberry_class( + type_, model_instance=data_from_model, extra=extra + ) + + return data + + +def convert_pydantic_model_to_strawberry_class( + cls, *, model_instance=None, extra=None # noqa: ANN001 +) -> Any: + extra = extra or {} + kwargs = {} + + for field_ in cls._type_definition.fields: + field = cast("StrawberryField", field_) + python_name = field.python_name + + data_from_extra = extra.get(python_name, None) + data_from_model = ( + getattr(model_instance, python_name, None) if model_instance else None + ) + + # only convert and add fields to kwargs if they are present in the `__init__` + # method of the class + if field.init: + kwargs[python_name] = _convert_from_pydantic_to_strawberry_type( + field.type, data_from_model, extra=data_from_extra + ) + + return cls(**kwargs) + + +def convert_strawberry_class_to_pydantic_model(obj: Type) -> Any: + if hasattr(obj, "to_pydantic"): + return obj.to_pydantic() + elif dataclasses.is_dataclass(obj): + result = [] + for f in dataclasses.fields(obj): + value = convert_strawberry_class_to_pydantic_model(getattr(obj, f.name)) + result.append((f.name, value)) + return dict(result) + elif isinstance(obj, (list, tuple)): + # Assume we can create an object of this type by passing in a + # generator (which is not true for namedtuples, not supported). + return type(obj)(convert_strawberry_class_to_pydantic_model(v) for v in obj) + elif isinstance(obj, dict): + return type(obj)( + ( + convert_strawberry_class_to_pydantic_model(k), + convert_strawberry_class_to_pydantic_model(v), + ) + for k, v in obj.items() + ) + else: + return copy.deepcopy(obj) diff --git a/strawberry/experimental/pydantic2/conversion_types.py b/strawberry/experimental/pydantic2/conversion_types.py new file mode 100644 index 0000000000..69ca6ae054 --- /dev/null +++ b/strawberry/experimental/pydantic2/conversion_types.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar +from typing_extensions import Protocol + +from pydantic import BaseModel + +if TYPE_CHECKING: + from strawberry.types.types import TypeDefinition + + +PydanticModel = TypeVar("PydanticModel", bound=BaseModel) + + +class StrawberryTypeFromPydantic(Protocol[PydanticModel]): + """This class does not exist in runtime. + It only makes the methods below visible for IDEs""" + + def __init__(self, **kwargs: Any): + ... + + @staticmethod + def from_pydantic( + instance: PydanticModel, extra: Optional[Dict[str, Any]] = None + ) -> StrawberryTypeFromPydantic[PydanticModel]: + ... + + def to_pydantic(self, **kwargs: Any) -> PydanticModel: + ... + + @property + def _type_definition(self) -> TypeDefinition: + ... + + @property + def _pydantic_type(self) -> Type[PydanticModel]: + ... diff --git a/strawberry/experimental/pydantic2/error_type.py b/strawberry/experimental/pydantic2/error_type.py new file mode 100644 index 0000000000..8ec9209efc --- /dev/null +++ b/strawberry/experimental/pydantic2/error_type.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import dataclasses +import warnings +from typing import ( + TYPE_CHECKING, + Any, + Callable, + List, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) + +from pydantic import BaseModel +from pydantic._internal._utils import lenient_issubclass + +from strawberry.auto import StrawberryAuto +from strawberry.experimental.pydantic2.utils import ( + get_private_fields, + get_strawberry_type_from_model, + normalize_type, +) +from strawberry.object_type import _process_type, _wrap_dataclass +from strawberry.types.type_resolver import _get_fields +from strawberry.utils.typing import get_list_annotation, is_list + +from .exceptions import MissingFieldsListError + +if TYPE_CHECKING: + from pydantic.fields import ModelField + + +def get_type_for_field(field: ModelField) -> Union[Any, Type[None], Type[List]]: + type_ = field.outer_type_ + type_ = normalize_type(type_) + return field_type_to_type(type_) + + +def field_type_to_type(type_: Type) -> Union[Any, List[Any], None]: + error_class: Any = str + strawberry_type: Any = error_class + + if is_list(type_): + child_type = get_list_annotation(type_) + + if is_list(child_type): + strawberry_type = field_type_to_type(child_type) + elif lenient_issubclass(child_type, BaseModel): + strawberry_type = get_strawberry_type_from_model(child_type) + else: + strawberry_type = List[error_class] + + strawberry_type = Optional[strawberry_type] + elif lenient_issubclass(type_, BaseModel): + strawberry_type = get_strawberry_type_from_model(type_) + return Optional[strawberry_type] + + return Optional[List[strawberry_type]] + + +def error_type( + model: Type[BaseModel], + *, + fields: Optional[List[str]] = None, + name: Optional[str] = None, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, +) -> Callable[..., Type]: + def wrap(cls: Type) -> Type: + model_fields = model.__fields__ + fields_set = set(fields) if fields else set() + + if fields: + warnings.warn( + "`fields` is deprecated, use `auto` type annotations instead", + DeprecationWarning, + stacklevel=2, + ) + + existing_fields = getattr(cls, "__annotations__", {}) + fields_set = fields_set.union( + { + name + for name, type_ in existing_fields.items() + if isinstance(type_, StrawberryAuto) + } + ) + + if all_fields: + if fields_set: + warnings.warn( + "Using all_fields overrides any explicitly defined fields " + "in the model, using both is likely a bug", + stacklevel=2, + ) + fields_set = set(model_fields.keys()) + + if not fields_set: + raise MissingFieldsListError(cls) + + all_model_fields: List[Tuple[str, Any, dataclasses.Field]] = [ + ( + name, + get_type_for_field(field), + dataclasses.field(default=None), # type: ignore[arg-type] + ) + for name, field in model_fields.items() + if name in fields_set + ] + + wrapped = _wrap_dataclass(cls) + extra_fields = cast(List[dataclasses.Field], _get_fields(wrapped)) + private_fields = get_private_fields(wrapped) + + all_model_fields.extend( + ( + field.name, + field.type, + field, + ) + for field in extra_fields + private_fields + if not isinstance(field.type, StrawberryAuto) + ) + + cls = dataclasses.make_dataclass( + cls.__name__, + all_model_fields, + bases=cls.__bases__, + ) + + _process_type( + cls, + name=name, + is_input=False, + is_interface=False, + description=description, + directives=directives, + ) + + model._strawberry_type = cls # type: ignore[attr-defined] + cls._pydantic_type = model + return cls + + return wrap diff --git a/strawberry/experimental/pydantic2/exceptions.py b/strawberry/experimental/pydantic2/exceptions.py new file mode 100644 index 0000000000..cdb16baaad --- /dev/null +++ b/strawberry/experimental/pydantic2/exceptions.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, List, Type + +if TYPE_CHECKING: + from pydantic import BaseModel + from pydantic.typing import NoArgAnyCallable + + +class MissingFieldsListError(Exception): + def __init__(self, type: Type[BaseModel]): + message = ( + f"List of fields to copy from {type} is empty. Add fields with the " + f"`auto` type annotation" + ) + + super().__init__(message) + + +class UnsupportedTypeError(Exception): + pass + + +class UnregisteredTypeException(Exception): + def __init__(self, type: Type[BaseModel]): + message = ( + f"Cannot find a Strawberry Type for {type} did you forget to register it?" + ) + + super().__init__(message) + + +class BothDefaultAndDefaultFactoryDefinedError(Exception): + def __init__(self, default: Any, default_factory: NoArgAnyCallable): + message = ( + f"Not allowed to specify both default and default_factory. " + f"default:{default} default_factory:{default_factory}" + ) + + super().__init__(message) + + +class AutoFieldsNotInBaseModelError(Exception): + def __init__(self, fields: List[str], cls_name: str, model: Type[BaseModel]): + message = ( + f"{cls_name} defines {fields} with strawberry.auto. " + f"Field(s) not present in {model.__name__} BaseModel." + ) + + super().__init__(message) diff --git a/strawberry/experimental/pydantic2/fields.py b/strawberry/experimental/pydantic2/fields.py new file mode 100644 index 0000000000..00bea43c61 --- /dev/null +++ b/strawberry/experimental/pydantic2/fields.py @@ -0,0 +1,97 @@ +import builtins +from typing import Any, NewType, Type +from typing_extensions import get_args, get_origin + +import pydantic +from pydantic import BaseModel +from pydantic._internal._typing_extra import is_new_type +from pydantic._internal._utils import lenient_issubclass + +from strawberry.experimental.pydantic2.exceptions import ( + UnregisteredTypeException, + UnsupportedTypeError, +) +from strawberry.types.types import TypeDefinition + +try: + from typing import GenericAlias as TypingGenericAlias # type: ignore +except ImportError: + import sys + + # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on) + # we do this under a conditional to avoid a mypy :) + if sys.version_info < (3, 9): + TypingGenericAlias = () + else: + raise + +# NOTE: To investigate the annotated types +ATTR_TO_TYPE_MAP = {} + +FIELDS_MAP = { + getattr(pydantic, field_name): type + for field_name, type in ATTR_TO_TYPE_MAP.items() + if hasattr(pydantic, field_name) +} + + +def new_type_supertype(type_: NewType) -> Type[Any]: + return type_.__supertype__ + + +def get_basic_type(type_: Any) -> Type[Any]: + # if lenient_issubclass(type_, pydantic.ConstrainedInt): + # if lenient_issubclass(type_, pydantic.ConstrainedFloat): + # if lenient_issubclass(type_, pydantic.ConstrainedStr): + # if lenient_issubclass(type_, pydantic.ConstrainedList): + + if type_ in FIELDS_MAP: + type_ = FIELDS_MAP.get(type_) + + if type_ is None: + raise UnsupportedTypeError() + + if is_new_type(type_): + return new_type_supertype(type_) + + return type_ + + +def replace_pydantic_types(type_: Any, is_input: bool) -> Any: + if lenient_issubclass(type_, BaseModel): + attr = "_strawberry_input_type" if is_input else "_strawberry_type" + if hasattr(type_, attr): + return getattr(type_, attr) + else: + raise UnregisteredTypeException(type_) + return type_ + + +def replace_types_recursively(type_: Any, is_input: bool) -> Any: + """Runs the conversions recursively into the arguments of generic types if any""" + basic_type = get_basic_type(type_) + replaced_type = replace_pydantic_types(basic_type, is_input) + + origin = get_origin(type_) + if not origin or not hasattr(type_, "__args__"): + return replaced_type + + converted = tuple( + replace_types_recursively(t, is_input=is_input) for t in get_args(replaced_type) + ) + + if isinstance(replaced_type, TypingGenericAlias): + return TypingGenericAlias(origin, converted) + + replaced_type = replaced_type.copy_with(converted) + + if isinstance(replaced_type, TypeDefinition): + # TODO: Not sure if this is necessary. No coverage in tests + # TODO: Unnecessary with StrawberryObject + replaced_type = builtins.type( + replaced_type.name, + (), + {"_type_definition": replaced_type}, + ) + + return replaced_type diff --git a/strawberry/experimental/pydantic2/object_type.py b/strawberry/experimental/pydantic2/object_type.py new file mode 100644 index 0000000000..5b829d18b7 --- /dev/null +++ b/strawberry/experimental/pydantic2/object_type.py @@ -0,0 +1,337 @@ +from __future__ import annotations + +import dataclasses +import sys +import warnings +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + Type, + cast, +) + +from pydantic._internal._fields import PydanticUndefined + +from strawberry.annotation import StrawberryAnnotation +from strawberry.auto import StrawberryAuto +from strawberry.experimental.pydantic2.conversion import ( + convert_pydantic_model_to_strawberry_class, + convert_strawberry_class_to_pydantic_model, +) +from strawberry.experimental.pydantic2.exceptions import MissingFieldsListError +from strawberry.experimental.pydantic2.fields import replace_types_recursively +from strawberry.experimental.pydantic2.utils import ( + DataclassCreationFields, + ensure_all_auto_fields_in_pydantic, + get_default_factory_for_field, + get_private_fields, +) +from strawberry.field import StrawberryField +from strawberry.object_type import _process_type, _wrap_dataclass +from strawberry.types.type_resolver import _get_fields +from strawberry.utils.dataclasses import add_custom_init_fn + +if TYPE_CHECKING: + from graphql import GraphQLResolveInfo + from pydantic.fields import FieldInfo + + +def is_required(field: FieldInfo) -> bool: + return field.default is PydanticUndefined and field.default_factory is None + + +def get_type_for_field(field: FieldInfo, is_input: bool): # noqa: ANN201 + outer_type = field.annotation + replaced_type = replace_types_recursively(outer_type, is_input) + # Note that unlike pydantic v1, pydantic v2 does not add a default of None when + # the field is Optional[something] + return replaced_type + + +def _build_dataclass_creation_fields( + field_name: str, + field: FieldInfo, + is_input: bool, + existing_fields: Dict[str, StrawberryField], + auto_fields_set: Set[str], + use_pydantic_alias: bool, +) -> DataclassCreationFields: + field_type = ( + get_type_for_field(field, is_input) + if field_name in auto_fields_set + else existing_fields[field_name].type + ) + + if ( + field_name in existing_fields + and existing_fields[field_name].base_resolver is not None + ): + # if the user has defined a resolver for this field, always use it + strawberry_field = existing_fields[field_name] + else: + # otherwise we build an appropriate strawberry field that resolves it + existing_field = existing_fields.get(field_name) + graphql_name = None + if existing_field and existing_field.graphql_name: + graphql_name = existing_field.graphql_name + elif field.alias and use_pydantic_alias: + graphql_name = field.alias + + strawberry_field = StrawberryField( + python_name=field_name, + graphql_name=graphql_name, + # always unset because we use default_factory instead + default=dataclasses.MISSING, + default_factory=get_default_factory_for_field(field), + type_annotation=StrawberryAnnotation.from_annotation(field_type), + description=field.description, + deprecation_reason=( + existing_field.deprecation_reason if existing_field else None + ), + permission_classes=( + existing_field.permission_classes if existing_field else [] + ), + directives=existing_field.directives if existing_field else (), + metadata=existing_field.metadata if existing_field else {}, + ) + + return DataclassCreationFields( + name=field_name, + field_type=field_type, + field=strawberry_field, + ) + + +if TYPE_CHECKING: + from strawberry.experimental.pydantic2.conversion_types import ( + PydanticModel, + StrawberryTypeFromPydantic, + ) + + +def type( + model: Type[PydanticModel], + *, + name: Optional[str] = None, + is_input: bool = False, + is_interface: bool = False, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, + use_pydantic_alias: bool = True, +) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: + def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: + model_fields: Dict[str, FieldInfo] = model.model_fields + + existing_fields = getattr(cls, "__annotations__", {}) + + # these are the fields that matched a field name in the pydantic model + # and should copy their alias from the pydantic model + fields_set = { + name for name, _ in existing_fields.items() if name in model_fields + } + # these are the fields that were marked with strawberry.auto and + # should copy their type from the pydantic model + auto_fields_set = { + name + for name, type_ in existing_fields.items() + if isinstance(type_, StrawberryAuto) + } + + if all_fields: + if fields_set: + warnings.warn( + "Using all_fields overrides any explicitly defined fields " + "in the model, using both is likely a bug", + stacklevel=2, + ) + fields_set = set(model_fields.keys()) + auto_fields_set = set(model_fields.keys()) + + if not fields_set: + raise MissingFieldsListError(cls) + + ensure_all_auto_fields_in_pydantic( + model=model, auto_fields=auto_fields_set, cls_name=cls.__name__ + ) + + wrapped = _wrap_dataclass(cls) + extra_strawberry_fields = _get_fields(wrapped) + extra_fields = cast(List[dataclasses.Field], extra_strawberry_fields) + private_fields = get_private_fields(wrapped) + + extra_fields_dict = {field.name: field for field in extra_strawberry_fields} + + all_model_fields: List[DataclassCreationFields] = [ + _build_dataclass_creation_fields( + field_name=field_name, + field=field, + is_input=is_input, + existing_fields=extra_fields_dict, + auto_fields_set=auto_fields_set, + use_pydantic_alias=use_pydantic_alias, + ) + for field_name, field in model_fields.items() + if field_name in fields_set + ] + + all_model_fields = [ + DataclassCreationFields( + name=field.name, + field_type=field.type, + field=field, + ) + for field in extra_fields + private_fields + if field.name not in fields_set + ] + all_model_fields + + # Implicitly define `is_type_of` to support interfaces/unions that use + # pydantic objects (not the corresponding strawberry type) + @classmethod # type: ignore + def is_type_of(cls: Type, obj: Any, _info: GraphQLResolveInfo) -> bool: + return isinstance(obj, (cls, model)) + + namespace = {"is_type_of": is_type_of} + # We need to tell the difference between a from_pydantic method that is + # inherited from a base class and one that is defined by the user in the + # decorated class. We want to override the method only if it is + # inherited. To tell the difference, we compare the class name to the + # fully qualified name of the method, which will end in .from_pydantic + has_custom_from_pydantic = hasattr( + cls, "from_pydantic" + ) and cls.from_pydantic.__qualname__.endswith(f"{cls.__name__}.from_pydantic") + has_custom_to_pydantic = hasattr( + cls, "to_pydantic" + ) and cls.to_pydantic.__qualname__.endswith(f"{cls.__name__}.to_pydantic") + + if has_custom_from_pydantic: + namespace["from_pydantic"] = cls.from_pydantic + if has_custom_to_pydantic: + namespace["to_pydantic"] = cls.to_pydantic + + if hasattr(cls, "resolve_reference"): + namespace["resolve_reference"] = cls.resolve_reference + + kwargs: Dict[str, object] = {} + + # Python 3.10.1 introduces the kw_only param to `make_dataclass`. + # If we're on an older version then generate our own custom init function + # Note: Python 3.10.0 added the `kw_only` param to dataclasses, it was + # just missed from the `make_dataclass` function: + # https://github.com/python/cpython/issues/89961 + if sys.version_info >= (3, 10, 1): + kwargs["kw_only"] = dataclasses.MISSING + else: + kwargs["init"] = False + + cls = dataclasses.make_dataclass( + cls.__name__, + [field.to_tuple() for field in all_model_fields], + bases=cls.__bases__, + namespace=namespace, + **kwargs, # type: ignore + ) + + if sys.version_info < (3, 10, 1): + add_custom_init_fn(cls) + + _process_type( + cls, + name=name, + is_input=is_input, + is_interface=is_interface, + description=description, + directives=directives, + ) + + if is_input: + model._strawberry_input_type = cls # type: ignore + else: + model._strawberry_type = cls # type: ignore + cls._pydantic_type = model + + def from_pydantic_default( + instance: PydanticModel, extra: Optional[Dict[str, Any]] = None + ) -> StrawberryTypeFromPydantic[PydanticModel]: + ret = convert_pydantic_model_to_strawberry_class( + cls=cls, model_instance=instance, extra=extra + ) + ret._original_model = instance + return ret + + def to_pydantic_default(self: Any, **kwargs: Any) -> PydanticModel: + instance_kwargs = { + f.name: convert_strawberry_class_to_pydantic_model( + getattr(self, f.name) + ) + for f in dataclasses.fields(self) + } + instance_kwargs.update(kwargs) + return model(**instance_kwargs) + + if not has_custom_from_pydantic: + cls.from_pydantic = staticmethod(from_pydantic_default) + if not has_custom_to_pydantic: + cls.to_pydantic = to_pydantic_default + + return cls + + return wrap + + +def input( + model: Type[PydanticModel], + *, + name: Optional[str] = None, + is_interface: bool = False, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, + use_pydantic_alias: bool = True, +) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: + """Convenience decorator for creating an input type from a Pydantic model. + Equal to partial(type, is_input=True) + See https://github.com/strawberry-graphql/strawberry/issues/1830 + """ + return type( + model=model, + name=name, + is_input=True, + is_interface=is_interface, + description=description, + directives=directives, + all_fields=all_fields, + use_pydantic_alias=use_pydantic_alias, + ) + + +def interface( + model: Type[PydanticModel], + *, + name: Optional[str] = None, + is_input: bool = False, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + all_fields: bool = False, + use_pydantic_alias: bool = True, +) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: + """Convenience decorator for creating an interface type from a Pydantic model. + Equal to partial(type, is_interface=True) + See https://github.com/strawberry-graphql/strawberry/issues/1830 + """ + return type( + model=model, + name=name, + is_input=is_input, + is_interface=True, + description=description, + directives=directives, + all_fields=all_fields, + use_pydantic_alias=use_pydantic_alias, + ) diff --git a/strawberry/experimental/pydantic2/utils.py b/strawberry/experimental/pydantic2/utils.py new file mode 100644 index 0000000000..3061e55d50 --- /dev/null +++ b/strawberry/experimental/pydantic2/utils.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import dataclasses +from typing import ( + TYPE_CHECKING, + Any, + List, + NamedTuple, + NoReturn, + Set, + Tuple, + Type, + Union, + cast, +) + +from pydantic._internal._utils import smart_deepcopy +from pydantic_core import PydanticUndefinedType + +from strawberry.experimental.pydantic2.exceptions import ( + AutoFieldsNotInBaseModelError, + BothDefaultAndDefaultFactoryDefinedError, + UnregisteredTypeException, +) +from strawberry.private import is_private +from strawberry.unset import UNSET +from strawberry.utils.typing import ( + get_list_annotation, + get_optional_annotation, + is_list, + is_optional, +) + +if TYPE_CHECKING: + from pydantic import BaseModel + from pydantic.fields import ModelField + from pydantic.typing import NoArgAnyCallable + + +def normalize_type(type_: Type) -> Any: + if is_list(type_): + return List[normalize_type(get_list_annotation(type_))] # type: ignore + + if is_optional(type_): + return get_optional_annotation(type_) + + return type_ + + +def get_strawberry_type_from_model(type_: Any) -> Any: + if hasattr(type_, "_strawberry_type"): + return type_._strawberry_type + else: + raise UnregisteredTypeException(type_) + + +def get_private_fields(cls: Type) -> List[dataclasses.Field]: + private_fields: List[dataclasses.Field] = [] + + for field in dataclasses.fields(cls): + if is_private(field.type): + private_fields.append(field) + + return private_fields + + +class DataclassCreationFields(NamedTuple): + """Fields required for the fields parameter of make_dataclass""" + + name: str + field_type: Type + field: dataclasses.Field + + def to_tuple(self) -> Tuple[str, Type, dataclasses.Field]: + # fields parameter wants (name, type, Field) + return self.name, self.field_type, self.field + + +def is_required(field: ModelField) -> bool: + return field.default is PydanticUndefinedType and field.default_factory is None + + +def get_default_factory_for_field( + field: ModelField, +) -> Union[NoArgAnyCallable, dataclasses._MISSING_TYPE]: + """ + Gets the default factory for a pydantic field. + + Handles mutable defaults when making the dataclass by + using pydantic's smart_deepcopy + + Returns optionally a NoArgAnyCallable representing a default_factory parameter + """ + # replace dataclasses.MISSING with our own UNSET to make comparisons easier + default_factory = ( + field.default_factory if field.default_factory is not None else UNSET + ) + default = ( + field.default if not isinstance(field.default, PydanticUndefinedType) else UNSET + ) + + has_factory = default_factory is not UNSET + has_default = default is not UNSET + + # defining both default and default_factory is not supported + + if has_factory and has_default: + default_factory = cast("NoArgAnyCallable", default_factory) + + raise BothDefaultAndDefaultFactoryDefinedError( + default=default, default_factory=default_factory + ) + + # if we have a default_factory, we should return it + + if has_factory: + default_factory = cast("NoArgAnyCallable", default_factory) + + return default_factory + + # if we have a default, we should return it + + if has_default: + return lambda: smart_deepcopy(default) + + # Note that unlike pydantic v1, pydantic v2 does not add a default of None when + # the field is Optional[something] + # so there is no need to handle that case here + + return dataclasses.MISSING + + +def ensure_all_auto_fields_in_pydantic( + model: Type[BaseModel], auto_fields: Set[str], cls_name: str +) -> Union[NoReturn, None]: + # Raise error if user defined a strawberry.auto field not present in the model + non_existing_fields = list(auto_fields - model.model_fields.keys()) + + if non_existing_fields: + raise AutoFieldsNotInBaseModelError( + fields=non_existing_fields, cls_name=cls_name, model=model + ) + else: + return None diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index 300520d4ef..5303be0562 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -984,7 +984,9 @@ class UserType: password: strawberry.auto @staticmethod - def from_pydantic(instance: User, extra: Dict[str, Any] = None) -> "UserType": + def from_pydantic( + instance: User, extra: Optional[Dict[str, Any]] = None + ) -> "UserType": return UserType( age=str(instance.age), password=base64.b64encode(instance.password.encode()).decode() @@ -1024,7 +1026,9 @@ class UserType: password: strawberry.auto @staticmethod - def from_pydantic(instance: User, extra: Dict[str, Any] = None) -> "UserType": + def from_pydantic( + instance: User, extra: Optional[Dict[str, Any]] = None + ) -> "UserType": return UserType( age=str(instance.age), password=base64.b64encode(instance.password.encode()).decode() diff --git a/tests/experimental/pydantic2/__init__.py b/tests/experimental/pydantic2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/experimental/pydantic2/schema/__init__.py b/tests/experimental/pydantic2/schema/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/experimental/pydantic2/schema/test_basic.py b/tests/experimental/pydantic2/schema/test_basic.py new file mode 100644 index 0000000000..e72acbe6a4 --- /dev/null +++ b/tests/experimental/pydantic2/schema/test_basic.py @@ -0,0 +1,429 @@ +import textwrap +from enum import Enum +from typing import List, Optional, Union + +import pydantic + +import strawberry + + +def test_all_fields(): + class UserModel(pydantic.BaseModel): + age: int + password: Optional[str] + + @strawberry.experimental.pydantic2.type(UserModel, all_fields=True) + class User: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1, password="ABC") + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + type Query { + user: User! + } + + type User { + age: Int! + password: String + } + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + query = "{ user { age } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["age"] == 1 + + +def test_auto_fields(): + class UserModel(pydantic.BaseModel): + age: int + password: Optional[str] + other: float + + @strawberry.experimental.pydantic2.type(UserModel) + class User: + age: strawberry.auto + password: strawberry.auto + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1, password="ABC") + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + type Query { + user: User! + } + + type User { + age: Int! + password: String + } + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + query = "{ user { age } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["age"] == 1 + + +def test_basic_alias_type(): + class UserModel(pydantic.BaseModel): + age_: int = pydantic.Field(..., alias="age") + password: Optional[str] + + @strawberry.experimental.pydantic2.type(UserModel, all_fields=True) + class User: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1, password="ABC") + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + type Query { + user: User! + } + + type User { + age: Int! + password: String + } + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + +def test_basic_type_with_list(): + class UserModel(pydantic.BaseModel): + age: int + friend_names: List[str] + + @strawberry.experimental.pydantic2.type(UserModel, all_fields=True) + class User: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1, friend_names=["A", "B"]) + + schema = strawberry.Schema(query=Query) + + query = "{ user { friendNames } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["friendNames"] == ["A", "B"] + + +def test_basic_type_with_nested_model(): + class Hobby(pydantic.BaseModel): + name: str + + @strawberry.experimental.pydantic2.type(Hobby, all_fields=True) + class HobbyType: + pass + + class User(pydantic.BaseModel): + hobby: Hobby + + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> UserType: + return UserType(hobby=HobbyType(name="Skii")) + + schema = strawberry.Schema(query=Query) + + query = "{ user { hobby { name } } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["hobby"]["name"] == "Skii" + + +def test_basic_type_with_list_of_nested_model(): + class Hobby(pydantic.BaseModel): + name: str + + @strawberry.experimental.pydantic2.type(Hobby, all_fields=True) + class HobbyType: + pass + + class User(pydantic.BaseModel): + hobbies: List[Hobby] + + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> UserType: + return UserType( + hobbies=[ + HobbyType(name="Skii"), + HobbyType(name="Cooking"), + ] + ) + + schema = strawberry.Schema(query=Query) + + query = "{ user { hobbies { name } } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["hobbies"] == [ + {"name": "Skii"}, + {"name": "Cooking"}, + ] + + +def test_basic_type_with_extended_fields(): + class UserModel(pydantic.BaseModel): + age: int + + @strawberry.experimental.pydantic2.type(UserModel, all_fields=True) + class User: + name: str + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(name="Marco", age=100) + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + type Query { + user: User! + } + + type User { + name: String! + age: Int! + } + """ + assert str(schema) == textwrap.dedent(expected_schema).strip() + + query = "{ user { name age } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["name"] == "Marco" + assert result.data["user"]["age"] == 100 + + +def test_type_with_custom_resolver(): + class UserModel(pydantic.BaseModel): + age: int + + def get_age_in_months(root): + return root.age * 12 + + @strawberry.experimental.pydantic2.type(UserModel, all_fields=True) + class User: + age_in_months: int = strawberry.field(resolver=get_age_in_months) + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=20) + + schema = strawberry.Schema(query=Query) + + query = "{ user { age ageInMonths } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["age"] == 20 + assert result.data["user"]["ageInMonths"] == 240 + + +def test_basic_type_with_union(): + class BranchA(pydantic.BaseModel): + field_a: str + + class BranchB(pydantic.BaseModel): + field_b: int + + class User(pydantic.BaseModel): + union_field: Union[BranchA, BranchB] + + @strawberry.experimental.pydantic2.type(BranchA, all_fields=True) + class BranchAType: + pass + + @strawberry.experimental.pydantic2.type(BranchB, all_fields=True) + class BranchBType: + pass + + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> UserType: + return UserType(union_field=BranchBType(field_b=10)) + + schema = strawberry.Schema(query=Query) + + query = "{ user { unionField { ... on BranchBType { fieldB } } } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["unionField"]["fieldB"] == 10 + + +def test_basic_type_with_union_pydantic_types(): + class BranchA(pydantic.BaseModel): + field_a: str + + class BranchB(pydantic.BaseModel): + field_b: int + + class User(pydantic.BaseModel): + union_field: Union[BranchA, BranchB] + + @strawberry.experimental.pydantic2.type(BranchA, all_fields=True) + class BranchAType: + pass + + @strawberry.experimental.pydantic2.type(BranchB, all_fields=True) + class BranchBType: + pass + + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> UserType: + # note that BranchB is a pydantic type, not a strawberry type + return UserType(union_field=BranchB(field_b=10)) + + schema = strawberry.Schema(query=Query) + + query = "{ user { unionField { ... on BranchBType { fieldB } } } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["unionField"]["fieldB"] == 10 + + +def test_basic_type_with_enum(): + @strawberry.enum + class UserKind(Enum): + user = 0 + admin = 1 + + class User(pydantic.BaseModel): + age: int + kind: UserKind + + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> UserType: + return UserType(age=10, kind=UserKind.admin) + + schema = strawberry.Schema(query=Query) + + query = "{ user { kind } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["kind"] == "admin" + + +def test_basic_type_with_interface(): + class Base(pydantic.BaseModel): + base_field: str + + class BranchA(Base): + field_a: str + + class BranchB(Base): + field_b: int + + class User(pydantic.BaseModel): + interface_field: Base + + @strawberry.experimental.pydantic2.interface(Base, all_fields=True) + class BaseType: + pass + + @strawberry.experimental.pydantic2.type(BranchA, all_fields=True) + class BranchAType(BaseType): + pass + + @strawberry.experimental.pydantic2.type(BranchB, all_fields=True) + class BranchBType(BaseType): + pass + + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> UserType: + return UserType(interface_field=BranchBType(base_field="abc", field_b=10)) + + schema = strawberry.Schema(query=Query, types=[BranchAType, BranchBType]) + + query = "{ user { interfaceField { baseField, ... on BranchBType { fieldB } } } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["interfaceField"]["baseField"] == "abc" + assert result.data["user"]["interfaceField"]["fieldB"] == 10 diff --git a/tests/experimental/pydantic2/test_basic.py b/tests/experimental/pydantic2/test_basic.py new file mode 100644 index 0000000000..fbb6fd45e2 --- /dev/null +++ b/tests/experimental/pydantic2/test_basic.py @@ -0,0 +1,894 @@ +import dataclasses +from enum import Enum +from typing import Any, List, Optional, Union + +import pydantic +import pytest + +import strawberry +from strawberry.enum import EnumDefinition +from strawberry.schema_directive import Location +from strawberry.type import StrawberryList, StrawberryOptional +from strawberry.types.types import TypeDefinition +from strawberry.union import StrawberryUnion + +if pydantic.__version__ >= "2.0.0": + # pydantic v2 imports need to be here to avoid import errors when running + # noxfile tests with pydantic v1 + # otherwise you need to add explicit directory exclusions for this folder + from strawberry.experimental.pydantic2.exceptions import MissingFieldsListError + + +def test_basic_type_all_fields(): + class User(pydantic.BaseModel): + age: int + password: Optional[str] + + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + pass + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name is None + assert field1.type is int + + assert field2.python_name == "password" + assert field2.graphql_name is None + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + +@pytest.mark.filterwarnings("error") +def test_basic_type_all_fields_warn(): + class User(pydantic.BaseModel): + age: int + password: Optional[str] + + with pytest.raises( + UserWarning, + match=("Using all_fields overrides any explicitly defined fields"), + ): + + @strawberry.experimental.pydantic2.type(User, all_fields=True) + class UserType: + age: strawberry.auto + + +def test_basic_type_auto_fields(): + class User(pydantic.BaseModel): + age: int + password: Optional[str] + other: float + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto + password: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name is None + assert field1.type is int + + assert field2.python_name == "password" + assert field2.graphql_name is None + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + +def test_auto_fields_other_sentinel(): + class other_sentinel: + pass + + class User(pydantic.BaseModel): + age: int + password: Optional[str] + other: int + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto + password: strawberry.auto + other: other_sentinel # this should be a private field, not an auto field + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2, field3] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name is None + assert field1.type is int + + assert field2.python_name == "password" + assert field2.graphql_name is None + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + assert field3.python_name == "other" + assert field3.graphql_name is None + assert field3.type is other_sentinel + + +def test_referencing_other_models_fails_when_not_registered(): + class Group(pydantic.BaseModel): + name: str + + class User(pydantic.BaseModel): + age: int + password: Optional[str] + group: Group + + with pytest.raises( + strawberry.experimental.pydantic2.UnregisteredTypeException, + match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), + ): + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto + password: strawberry.auto + group: strawberry.auto + + +def test_referencing_other_input_models_fails_when_not_registered(): + class Group(pydantic.BaseModel): + name: str + + class User(pydantic.BaseModel): + age: int + password: Optional[str] + group: Group + + @strawberry.experimental.pydantic2.type(Group) + class GroupType: + name: strawberry.auto + + with pytest.raises( + strawberry.experimental.pydantic2.UnregisteredTypeException, + match=("Cannot find a Strawberry Type for (.*) did you forget to register it?"), + ): + + @strawberry.experimental.pydantic2.input(User) + class UserInputType: + age: strawberry.auto + password: strawberry.auto + group: strawberry.auto + + +def test_referencing_other_registered_models(): + class Group(pydantic.BaseModel): + name: str + + class User(pydantic.BaseModel): + age: int + group: Group + + @strawberry.experimental.pydantic2.type(Group) + class GroupType: + name: strawberry.auto + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto + group: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.type is int + + assert field2.python_name == "group" + assert field2.type is GroupType + + +def test_list(): + class User(pydantic.BaseModel): + friend_names: List[str] + + @strawberry.experimental.pydantic2.type(User) + class UserType: + friend_names: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field] = definition.fields + + assert field.python_name == "friend_names" + assert isinstance(field.type, StrawberryList) + assert field.type.of_type is str + + +def test_list_of_types(): + class Friend(pydantic.BaseModel): + name: str + + class User(pydantic.BaseModel): + friends: Optional[List[Optional[Friend]]] + + @strawberry.experimental.pydantic2.type(Friend) + class FriendType: + name: strawberry.auto + + @strawberry.experimental.pydantic2.type(User) + class UserType: + friends: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field] = definition.fields + + assert field.python_name == "friends" + assert isinstance(field.type, StrawberryOptional) + assert isinstance(field.type.of_type, StrawberryList) + assert isinstance(field.type.of_type.of_type, StrawberryOptional) + assert field.type.of_type.of_type.of_type is FriendType + + +def test_basic_type_without_fields_throws_an_error(): + class User(pydantic.BaseModel): + age: int + password: Optional[str] + + with pytest.raises(MissingFieldsListError): + + @strawberry.experimental.pydantic2.type(User) + class UserType: + pass + + +def test_type_with_fields_coming_from_strawberry_and_pydantic(): + class User(pydantic.BaseModel): + age: int + password: Optional[str] + + @strawberry.experimental.pydantic2.type(User) + class UserType: + name: str + age: strawberry.auto + password: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2, field3] = definition.fields + + assert field1.python_name == "name" + assert field1.type is str + + assert field2.python_name == "age" + assert field2.type is int + + assert field3.python_name == "password" + assert isinstance(field3.type, StrawberryOptional) + assert field3.type.of_type is str + + +def test_default_and_default_factory(): + class User1(pydantic.BaseModel): + friend: Optional[str] = "friend_value" + + @strawberry.experimental.pydantic2.type(User1) + class UserType1: + friend: strawberry.auto + + assert UserType1().friend == "friend_value" + assert UserType1().to_pydantic().friend == "friend_value" + + class User2(pydantic.BaseModel): + friend: Optional[str] = None + + @strawberry.experimental.pydantic2.type(User2) + class UserType2: + friend: strawberry.auto + + assert UserType2().friend is None + assert UserType2().to_pydantic().friend is None + + # Test instantiation using default_factory + + class User3(pydantic.BaseModel): + friend: Optional[str] = pydantic.Field(default_factory=lambda: "friend_value") + + @strawberry.experimental.pydantic2.type(User3) + class UserType3: + friend: strawberry.auto + + assert UserType3().friend == "friend_value" + assert UserType3().to_pydantic().friend == "friend_value" + + class User4(pydantic.BaseModel): + friend: Optional[str] = pydantic.Field(default_factory=lambda: None) + + @strawberry.experimental.pydantic2.type(User4) + class UserType4: + friend: strawberry.auto + + assert UserType4().friend is None + assert UserType4().to_pydantic().friend is None + + +def test_optional_and_default(): + class UserModel(pydantic.BaseModel): + age: int + name: str = pydantic.Field("Michael", description="The user name") + password: Optional[str] = pydantic.Field(default="ABC") + passwordtwo: Optional[str] = None + some_list: Optional[List[str]] = pydantic.Field(default_factory=list) + check: Optional[bool] = False + + @strawberry.experimental.pydantic2.type(UserModel, all_fields=True) + class User: + pass + + definition: TypeDefinition = User._type_definition + assert definition.name == "User" + + [ + age_field, + name_field, + password_field, + passwordtwo_field, + some_list_field, + check_field, + ] = definition.fields + + assert age_field.python_name == "age" + assert age_field.type is int + + assert name_field.python_name == "name" + assert name_field.type is str + + assert password_field.python_name == "password" + assert isinstance(password_field.type, StrawberryOptional) + assert password_field.type.of_type is str + + assert passwordtwo_field.python_name == "passwordtwo" + assert isinstance(passwordtwo_field.type, StrawberryOptional) + assert passwordtwo_field.type.of_type is str + + assert some_list_field.python_name == "some_list" + assert isinstance(some_list_field.type, StrawberryOptional) + assert isinstance(some_list_field.type.of_type, StrawberryList) + assert some_list_field.type.of_type.of_type is str + + assert check_field.python_name == "check" + assert isinstance(check_field.type, StrawberryOptional) + assert check_field.type.of_type is bool + + +def test_type_with_fields_mutable_default(): + empty_list = [] + + class User(pydantic.BaseModel): + groups: List[str] + friends: List[str] = empty_list + + @strawberry.experimental.pydantic2.type(User) + class UserType: + groups: strawberry.auto + friends: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [groups_field, friends_field] = definition.fields + + assert groups_field.default is dataclasses.MISSING + assert friends_field.default is dataclasses.MISSING + + # check that we really made a copy + assert friends_field.default_factory() is not empty_list + assert UserType(groups=["groups"]).friends is not empty_list + UserType(groups=["groups"]).friends.append("joe") + assert empty_list == [] + + +@pytest.mark.xfail( + reason=( + "passing default values when extending types from pydantic is not" + "supported. https://github.com/strawberry-graphql/strawberry/issues/829" + ) +) +def test_type_with_fields_coming_from_strawberry_and_pydantic_with_default(): + class User(pydantic.BaseModel): + age: int + password: Optional[str] + + @strawberry.experimental.pydantic2.type(User) + class UserType: + name: str = "Michael" + age: strawberry.auto + password: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2, field3] = definition.fields + + assert field1.python_name == "age" + assert field1.type is int + + assert field2.python_name == "password" + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + assert field3.python_name == "name" + assert field3.type is str + assert field3.default == "Michael" + + +def test_type_with_nested_fields_coming_from_strawberry_and_pydantic(): + @strawberry.type + class Name: + first_name: str + last_name: str + + class User(pydantic.BaseModel): + age: int + password: Optional[str] + + @strawberry.experimental.pydantic2.type(User) + class UserType: + name: Name + age: strawberry.auto + password: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2, field3] = definition.fields + + assert field1.python_name == "name" + assert field1.type is Name + + assert field2.python_name == "age" + assert field2.type is int + + assert field3.python_name == "password" + assert isinstance(field3.type, StrawberryOptional) + assert field3.type.of_type is str + + +def test_type_with_aliased_pydantic_field(): + class UserModel(pydantic.BaseModel): + age_: int = pydantic.Field(..., alias="age") + password: Optional[str] + + @strawberry.experimental.pydantic2.type(UserModel) + class User: + age_: strawberry.auto + password: strawberry.auto + + definition: TypeDefinition = User._type_definition + assert definition.name == "User" + + [field1, field2] = definition.fields + + assert field1.python_name == "age_" + assert field1.type is int + assert field1.graphql_name == "age" + + assert field2.python_name == "password" + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + +def test_union(): + class BranchA(pydantic.BaseModel): + field_a: str + + class BranchB(pydantic.BaseModel): + field_b: int + + class User(pydantic.BaseModel): + age: int + union_field: Union[BranchA, BranchB] + + @strawberry.experimental.pydantic2.type(BranchA) + class BranchAType: + field_a: strawberry.auto + + @strawberry.experimental.pydantic2.type(BranchB) + class BranchBType: + field_b: strawberry.auto + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto + union_field: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.type is int + + assert field2.python_name == "union_field" + assert isinstance(field2.type, StrawberryUnion) + assert field2.type.types[0] is BranchAType + assert field2.type.types[1] is BranchBType + + +def test_enum(): + @strawberry.enum + class UserKind(Enum): + user = 0 + admin = 1 + + class User(pydantic.BaseModel): + age: int + kind: UserKind + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto + kind: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.type is int + + assert field2.python_name == "kind" + assert isinstance(field2.type, EnumDefinition) + assert field2.type.wrapped_cls is UserKind + + +def test_interface(): + class Base(pydantic.BaseModel): + base_field: str + + class BranchA(Base): + field_a: str + + class BranchB(Base): + field_b: int + + class User(pydantic.BaseModel): + age: int + interface_field: Base + + @strawberry.experimental.pydantic2.interface(Base) + class BaseType: + base_field: strawberry.auto + + @strawberry.experimental.pydantic2.type(BranchA) + class BranchAType(BaseType): + field_a: strawberry.auto + + @strawberry.experimental.pydantic2.type(BranchB) + class BranchBType(BaseType): + field_b: strawberry.auto + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto + interface_field: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.type is int + + assert field2.python_name == "interface_field" + assert field2.type is BaseType + + +def test_both_output_and_input_type(): + class Work(pydantic.BaseModel): + time: float + + class User(pydantic.BaseModel): + name: str + work: Optional[Work] + + class Group(pydantic.BaseModel): + users: List[User] + + # Test both definition orders + @strawberry.experimental.pydantic2.input(Work) + class WorkInput: + time: strawberry.auto + + @strawberry.experimental.pydantic2.type(Work) + class WorkOutput: + time: strawberry.auto + + @strawberry.experimental.pydantic2.type(User) + class UserOutput: + name: strawberry.auto + work: strawberry.auto + + @strawberry.experimental.pydantic2.input(User) + class UserInput: + name: strawberry.auto + work: strawberry.auto + + @strawberry.experimental.pydantic2.input(Group) + class GroupInput: + users: strawberry.auto + + @strawberry.experimental.pydantic2.type(Group) + class GroupOutput: + users: strawberry.auto + + @strawberry.type + class Query: + groups: List[GroupOutput] + + @strawberry.type + class Mutation: + @strawberry.mutation + def updateGroup(group: GroupInput) -> GroupOutput: + pass + + # This triggers the exception from #1504 + schema = strawberry.Schema(query=Query, mutation=Mutation) + expected_schema = """ +input GroupInput { + users: [UserInput!]! +} + +type GroupOutput { + users: [UserOutput!]! +} + +type Mutation { + updateGroup(group: GroupInput!): GroupOutput! +} + +type Query { + groups: [GroupOutput!]! +} + +input UserInput { + name: String! + work: WorkInput +} + +type UserOutput { + name: String! + work: WorkOutput +} + +input WorkInput { + time: Float! +} + +type WorkOutput { + time: Float! +}""".strip() + result_schema = schema.as_str().strip() + assert result_schema == expected_schema + + assert Group._strawberry_type == GroupOutput + assert Group._strawberry_input_type == GroupInput + assert User._strawberry_type == UserOutput + assert User._strawberry_input_type == UserInput + assert Work._strawberry_type == WorkOutput + assert Work._strawberry_input_type == WorkInput + + +def test_single_field_changed_type(): + class User(pydantic.BaseModel): + age: int + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: str + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name is None + assert field1.type is str + + +def test_type_with_aliased_pydantic_field_changed_type(): + class UserModel(pydantic.BaseModel): + age_: int = pydantic.Field(..., alias="age") + password: Optional[str] + + @strawberry.experimental.pydantic2.type(UserModel) + class User: + age_: str + password: strawberry.auto + + definition: TypeDefinition = User._type_definition + assert definition.name == "User" + + [field1, field2] = definition.fields + + assert field1.python_name == "age_" + assert field1.type is str + assert field1.graphql_name == "age" + + assert field2.python_name == "password" + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + +def test_deprecated_fields(): + class User(pydantic.BaseModel): + age: int + password: Optional[str] + other: float + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto = strawberry.field(deprecation_reason="Because") + password: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name is None + assert field1.type is int + assert field1.deprecation_reason == "Because" + + assert field2.python_name == "password" + assert field2.graphql_name is None + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + +def test_permission_classes(): + class IsAuthenticated(strawberry.BasePermission): + message = "User is not authenticated" + + def has_permission( + self, source: Any, info: strawberry.types.Info, **kwargs + ) -> bool: + return False + + class User(pydantic.BaseModel): + age: int + password: Optional[str] + other: float + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto = strawberry.field(permission_classes=[IsAuthenticated]) + password: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name is None + assert field1.type is int + assert field1.permission_classes == [IsAuthenticated] + + assert field2.python_name == "password" + assert field2.graphql_name is None + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + +def test_field_directives(): + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class Sensitive: + reason: str + + class User(pydantic.BaseModel): + age: int + password: Optional[str] + other: float + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto = strawberry.field(directives=[Sensitive(reason="GDPR")]) + password: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name is None + assert field1.type is int + assert field1.directives == [Sensitive(reason="GDPR")] + + assert field2.python_name == "password" + assert field2.graphql_name is None + assert isinstance(field2.type, StrawberryOptional) + assert field2.type.of_type is str + + +def test_alias_fields(): + class User(pydantic.BaseModel): + age: int + + @strawberry.experimental.pydantic2.type(User) + class UserType: + age: strawberry.auto = strawberry.field(name="ageAlias") + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + field1 = definition.fields[0] + + assert field1.python_name == "age" + assert field1.graphql_name == "ageAlias" + assert field1.type is int + + +def test_alias_fields_with_use_pydantic_alias(): + class User(pydantic.BaseModel): + age: int + state: str = pydantic.Field(alias="statePydantic") + country: str = pydantic.Field(alias="countryPydantic") + + @strawberry.experimental.pydantic2.type(User, use_pydantic_alias=True) + class UserType: + age: strawberry.auto = strawberry.field(name="ageAlias") + state: strawberry.auto = strawberry.field(name="state") + country: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2, field3] = definition.fields + + assert field1.python_name == "age" + assert field1.graphql_name == "ageAlias" + + assert field2.python_name == "state" + assert field2.graphql_name == "state" + + assert field3.python_name == "country" + assert field3.graphql_name == "countryPydantic" + + +def test_field_metadata(): + class User(pydantic.BaseModel): + private: bool + public: bool + + @strawberry.experimental.pydantic2.type(User) + class UserType: + private: strawberry.auto = strawberry.field(metadata={"admin_only": True}) + public: strawberry.auto + + definition: TypeDefinition = UserType._type_definition + assert definition.name == "UserType" + + [field1, field2] = definition.fields + + assert field1.python_name == "private" + assert field1.metadata["admin_only"] + + assert field2.python_name == "public" + assert not field2.metadata diff --git a/tests/schema/test_pydantic.py b/tests/schema/test_pydantic.py index 793862429b..f4dc07551a 100644 --- a/tests/schema/test_pydantic.py +++ b/tests/schema/test_pydantic.py @@ -1,10 +1,7 @@ -import pytest from pydantic import BaseModel, Field import strawberry -pytestmark = pytest.mark.pydantic - def test_use_alias_as_gql_name(): class UserModel(BaseModel):