Skip to content

Commit

Permalink
add test for #2782
Browse files Browse the repository at this point in the history
  • Loading branch information
thejaminator committed Jul 7, 2023
1 parent ae295f5 commit f582d41
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 17 deletions.
2 changes: 2 additions & 0 deletions strawberry/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
try:
from . import pydantic

__all__ = ["pydantic"]
except ImportError:
pass
try:
from . import pydantic2

# Support for pydantic2 is highly experimental and the interface will change
# We don't recommend using it yet
__all__ = ["pydantic2"]
Expand Down
12 changes: 3 additions & 9 deletions strawberry/experimental/pydantic2/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,9 @@ def is_required(field: FieldInfo) -> bool:
def get_type_for_field(field: FieldInfo, is_input: bool): # noqa: ANN201
outer_type = field.annotation
replaced_type = replace_types_recursively(outer_type, is_input)

default_defined: bool = (
field.default_factory is not None or field.default is not None
)
should_add_optional: bool = not (is_required(field) or default_defined)
if should_add_optional:
return Optional[replaced_type]
else:
return replaced_type
# Note that unlike pydantic v1, pydantic v2 does not add a default of None when
# the field is Optional[something]
return replaced_type


def _build_dataclass_creation_fields(
Expand Down
16 changes: 8 additions & 8 deletions strawberry/experimental/pydantic2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def is_required(field: ModelField) -> bool:


def get_default_factory_for_field(
field: ModelField,
field: ModelField,
) -> Union[NoArgAnyCallable, dataclasses._MISSING_TYPE]:
"""
Gets the default factory for a pydantic field.
Expand All @@ -95,7 +95,9 @@ def get_default_factory_for_field(
default_factory = (
field.default_factory if field.default_factory is not None else UNSET
)
default = field.default if not isinstance(field.default, PydanticUndefinedType) else UNSET
default = (
field.default if not isinstance(field.default, PydanticUndefinedType) else UNSET
)

has_factory = default_factory is not UNSET
has_default = default is not UNSET
Expand All @@ -121,17 +123,15 @@ def get_default_factory_for_field(
if has_default:
return lambda: smart_deepcopy(default)

# if we don't have default or default_factory, but the field is not required,
# we should return a factory that returns None

if not is_required(field):
return lambda: None
# Note that unlike pydantic v1, pydantic v2 does not add a default of None when
# the field is Optional[something]
# so there is no need to handle that case here

return dataclasses.MISSING


def ensure_all_auto_fields_in_pydantic(
model: Type[BaseModel], auto_fields: Set[str], cls_name: str
model: Type[BaseModel], auto_fields: Set[str], cls_name: str
) -> Union[NoReturn, None]:
# Raise error if user defined a strawberry.auto field not present in the model
non_existing_fields = list(auto_fields - model.model_fields.keys())
Expand Down
49 changes: 49 additions & 0 deletions tests/experimental/pydantic2/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,55 @@ class UserType4:
assert UserType4().to_pydantic().friend is None


def test_optional_and_default():
class UserModel(pydantic.BaseModel):
age: int
name: str = pydantic.Field("Michael", description="The user name")
password: Optional[str] = pydantic.Field(default="ABC")
passwordtwo: Optional[str] = None
some_list: Optional[List[str]] = pydantic.Field(default_factory=list)
check: Optional[bool] = False

@strawberry.experimental.pydantic2.type(UserModel, all_fields=True)
class User:
pass

definition: TypeDefinition = User._type_definition
assert definition.name == "User"

[
age_field,
name_field,
password_field,
passwordtwo_field,
some_list_field,
check_field,
] = definition.fields

assert age_field.python_name == "age"
assert age_field.type is int

assert name_field.python_name == "name"
assert name_field.type is str

assert password_field.python_name == "password"
assert isinstance(password_field.type, StrawberryOptional)
assert password_field.type.of_type is str

assert passwordtwo_field.python_name == "passwordtwo"
assert isinstance(passwordtwo_field.type, StrawberryOptional)
assert passwordtwo_field.type.of_type is str

assert some_list_field.python_name == "some_list"
assert isinstance(some_list_field.type, StrawberryOptional)
assert isinstance(some_list_field.type.of_type, StrawberryList)
assert some_list_field.type.of_type.of_type is str

assert check_field.python_name == "check"
assert isinstance(check_field.type, StrawberryOptional)
assert check_field.type.of_type is bool


def test_type_with_fields_mutable_default():
empty_list = []

Expand Down

0 comments on commit f582d41

Please sign in to comment.