diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index cdb038e5..8d842a6f 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -18,6 +18,7 @@ from asgiref.sync import sync_to_async from django.core.exceptions import ObjectDoesNotExist from django.db import models +from django.db.models.fields.files import FileDescriptor from django.db.models.fields.related import ( ForwardManyToOneDescriptor, ReverseManyToOneDescriptor, @@ -202,7 +203,14 @@ def get_result( # Reversed OneToOne will raise ObjectDoesNotExist when # trying to access it if the relation doesn't exist. except_as_none=(ObjectDoesNotExist,) if self.is_optional else None, + empty_file_descriptor_as_null=True, ) + else: + # FileField/ImageField will always return a FileDescriptor, even when the + # field is "null". If it is falsy (i.e. doesn't have a file) we should + # return `None` instead. + if isinstance(attr, FileDescriptor) and not result: + result = None if is_awaitable: diff --git a/strawberry_django/resolvers.py b/strawberry_django/resolvers.py index af39a231..911b0b7b 100644 --- a/strawberry_django/resolvers.py +++ b/strawberry_django/resolvers.py @@ -7,6 +7,7 @@ from asgiref.sync import sync_to_async from django.db import models +from django.db.models.fields.files import FileDescriptor from django.db.models.manager import BaseManager from strawberry.utils.inspect import in_async_context from typing_extensions import ParamSpec @@ -127,6 +128,7 @@ def django_getattr( *, qs_hook: Callable[[models.QuerySet[_M]], Any] = default_qs_hook, except_as_none: tuple[type[Exception], ...] | None = None, + empty_file_descriptor_as_null: bool = False, ) -> AwaitableOrValue[Any]: ... @@ -138,6 +140,7 @@ def django_getattr( *, qs_hook: Callable[[models.QuerySet[_M]], Any] = default_qs_hook, except_as_none: tuple[type[Exception], ...] | None = None, + empty_file_descriptor_as_null: bool = False, ) -> AwaitableOrValue[Any]: ... @@ -148,10 +151,29 @@ def django_getattr( *, qs_hook: Callable[[models.QuerySet[_M]], Any] = default_qs_hook, except_as_none: tuple[type[Exception], ...] | None = None, + empty_file_descriptor_as_null: bool = False, ): - args = (default,) if default is not _SENTINEL else () - return django_resolver(getattr, qs_hook=qs_hook, except_as_none=except_as_none)( + return django_resolver( + _django_getattr, + qs_hook=qs_hook, + except_as_none=except_as_none, + )( obj, name, - *args, + default, + empty_file_descriptor_as_null=empty_file_descriptor_as_null, ) + + +def _django_getattr( + obj: Any, + name: str, + default: Any = _SENTINEL, + *, + empty_file_descriptor_as_null: bool = False, +): + args = (default,) if default is not _SENTINEL else () + result = getattr(obj, name, *args) + if empty_file_descriptor_as_null and isinstance(result, FileDescriptor): + result = None + return result diff --git a/tests/test_queries.py b/tests/test_queries.py index c9015dfa..f1486e83 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -1,9 +1,13 @@ +import io import textwrap from typing import List, Optional, cast import pytest import strawberry +from asgiref.sync import sync_to_async +from django.core.files.uploadedfile import SimpleUploadedFile from graphql import GraphQLError +from PIL import Image from strawberry import auto import strawberry_django @@ -31,6 +35,13 @@ class Group: users: List[User] +@strawberry_django.type(models.Fruit) +class Fruit: + id: auto + name: auto + picture: auto + + @strawberry_django.type(models.Fruit) class BerryFruit: id: auto @@ -62,6 +73,7 @@ class Query: users: List[User] = strawberry_django.field() group: Group = strawberry_django.field() groups: List[Group] = strawberry_django.field() + fruit: Fruit = strawberry_django.field() berries: List[BerryFruit] = strawberry_django.field() bananas: List[BananaFruit] = strawberry_django.field() @@ -155,6 +167,67 @@ async def test_model_properties(query, fruits): ] +async def test_query_file_field(query): + img_f = io.BytesIO() + img = Image.new(mode="RGB", size=(1, 1), color="red") + img.save(img_f, format="jpeg") + upload = SimpleUploadedFile("strawberry-picture.png", img_f.getvalue()) + fruit = await sync_to_async(models.Fruit.objects.create)( + name="Strawberry", + picture=upload, + ) + + result = await query( + """\ + query Fruit ($pk: ID!) { + fruit (pk: $pk) { + id + name + picture { + name + } + } + } + """, + {"pk": fruit.pk}, + ) + + assert not result.errors + assert result.data is not None + assert result.data["fruit"] == { + "id": str(fruit.pk), + "name": "Strawberry", + "picture": {"name": ".tmp_upload/strawberry-picture.png"}, + } + + +async def test_query_file_field_when_null(query): + fruit = await sync_to_async(models.Fruit.objects.create)(name="Strawberry") + + result = await query( + """\ + query Fruit ($pk: ID!) { + fruit (pk: $pk) { + id + name + picture { + name + } + } + } + """, + {"pk": fruit.pk}, + ) + + assert not result.errors + assert result.data is not None + assert result.data["fruit"] == { + "id": str(fruit.pk), + "name": "Strawberry", + "picture": None, + } + + def test_field_name(): """Make sure that field_name overriding is not ignored."""