Skip to content

Commit

Permalink
fix: Return null on empty files/images
Browse files Browse the repository at this point in the history
Fix #453
  • Loading branch information
bellini666 committed Jan 27, 2024
1 parent e751a2d commit 9f46e8c
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 3 deletions.
8 changes: 8 additions & 0 deletions strawberry_django/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from asgiref.sync import sync_to_async
from django.core.exceptions import ObjectDoesNotExist
from django.db import models
from django.db.models.fields.files import FileDescriptor
from django.db.models.fields.related import (
ForwardManyToOneDescriptor,
ReverseManyToOneDescriptor,
Expand Down Expand Up @@ -202,7 +203,14 @@ def get_result(
# Reversed OneToOne will raise ObjectDoesNotExist when
# trying to access it if the relation doesn't exist.
except_as_none=(ObjectDoesNotExist,) if self.is_optional else None,
empty_file_descriptor_as_null=True,
)
else:
# FileField/ImageField will always return a FileDescriptor, even when the
# field is "null". If it is falsy (i.e. doesn't have a file) we should
# return `None` instead.
if isinstance(attr, FileDescriptor) and not result:
result = None

if is_awaitable:

Expand Down
28 changes: 25 additions & 3 deletions strawberry_django/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from asgiref.sync import sync_to_async
from django.db import models
from django.db.models.fields.files import FileDescriptor
from django.db.models.manager import BaseManager
from strawberry.utils.inspect import in_async_context
from typing_extensions import ParamSpec
Expand Down Expand Up @@ -127,6 +128,7 @@ def django_getattr(
*,
qs_hook: Callable[[models.QuerySet[_M]], Any] = default_qs_hook,
except_as_none: tuple[type[Exception], ...] | None = None,
empty_file_descriptor_as_null: bool = False,
) -> AwaitableOrValue[Any]: ...


Expand All @@ -138,6 +140,7 @@ def django_getattr(
*,
qs_hook: Callable[[models.QuerySet[_M]], Any] = default_qs_hook,
except_as_none: tuple[type[Exception], ...] | None = None,
empty_file_descriptor_as_null: bool = False,
) -> AwaitableOrValue[Any]: ...


Expand All @@ -148,10 +151,29 @@ def django_getattr(
*,
qs_hook: Callable[[models.QuerySet[_M]], Any] = default_qs_hook,
except_as_none: tuple[type[Exception], ...] | None = None,
empty_file_descriptor_as_null: bool = False,
):
args = (default,) if default is not _SENTINEL else ()
return django_resolver(getattr, qs_hook=qs_hook, except_as_none=except_as_none)(
return django_resolver(
_django_getattr,
qs_hook=qs_hook,
except_as_none=except_as_none,
)(
obj,
name,
*args,
default,
empty_file_descriptor_as_null=empty_file_descriptor_as_null,
)


def _django_getattr(
obj: Any,
name: str,
default: Any = _SENTINEL,
*,
empty_file_descriptor_as_null: bool = False,
):
args = (default,) if default is not _SENTINEL else ()
result = getattr(obj, name, *args)
if empty_file_descriptor_as_null and isinstance(result, FileDescriptor):
result = None
return result
73 changes: 73 additions & 0 deletions tests/test_queries.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import io
import textwrap
from typing import List, Optional, cast

import pytest
import strawberry
from asgiref.sync import sync_to_async
from django.core.files.uploadedfile import SimpleUploadedFile
from graphql import GraphQLError
from PIL import Image
from strawberry import auto

import strawberry_django
Expand Down Expand Up @@ -31,6 +35,13 @@ class Group:
users: List[User]


@strawberry_django.type(models.Fruit)
class Fruit:
id: auto
name: auto
picture: auto


@strawberry_django.type(models.Fruit)
class BerryFruit:
id: auto
Expand Down Expand Up @@ -62,6 +73,7 @@ class Query:
users: List[User] = strawberry_django.field()
group: Group = strawberry_django.field()
groups: List[Group] = strawberry_django.field()
fruit: Fruit = strawberry_django.field()
berries: List[BerryFruit] = strawberry_django.field()
bananas: List[BananaFruit] = strawberry_django.field()

Expand Down Expand Up @@ -155,6 +167,67 @@ async def test_model_properties(query, fruits):
]


async def test_query_file_field(query):
img_f = io.BytesIO()
img = Image.new(mode="RGB", size=(1, 1), color="red")
img.save(img_f, format="jpeg")
upload = SimpleUploadedFile("strawberry-picture.png", img_f.getvalue())
fruit = await sync_to_async(models.Fruit.objects.create)(
name="Strawberry",
picture=upload,
)

result = await query(
"""\
query Fruit ($pk: ID!) {
fruit (pk: $pk) {
id
name
picture {
name
}
}
}
""",
{"pk": fruit.pk},
)

assert not result.errors
assert result.data is not None
assert result.data["fruit"] == {
"id": str(fruit.pk),
"name": "Strawberry",
"picture": {"name": ".tmp_upload/strawberry-picture.png"},
}


async def test_query_file_field_when_null(query):
fruit = await sync_to_async(models.Fruit.objects.create)(name="Strawberry")

result = await query(
"""\
query Fruit ($pk: ID!) {
fruit (pk: $pk) {
id
name
picture {
name
}
}
}
""",
{"pk": fruit.pk},
)

assert not result.errors
assert result.data is not None
assert result.data["fruit"] == {
"id": str(fruit.pk),
"name": "Strawberry",
"picture": None,
}


def test_field_name():
"""Make sure that field_name overriding is not ignored."""

Expand Down

0 comments on commit 9f46e8c

Please sign in to comment.