Skip to content

Commit

Permalink
fix: make sure custom fields are kept during inheritance (#415)
Browse files Browse the repository at this point in the history
When a type is using a custom field, inheriting from it should not
change that field back to a `StrawberryDjangoField`

Fix #414
  • Loading branch information
bellini666 authored Nov 13, 2023
1 parent d0a32e4 commit fb19948
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
8 changes: 7 additions & 1 deletion strawberry_django/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from strawberry.relay.types import NodeIterableType
from strawberry.types.info import Info
from strawberry.unset import UnsetType
from typing_extensions import Literal
from typing_extensions import Literal, Self

from strawberry_django.utils.typing import (
AnnotateType,
Expand Down Expand Up @@ -115,6 +115,12 @@ def __init__(
)
super().__init__(*args, **kwargs)

def __copy__(self) -> Self:
new_field = super().__copy__()
new_field.disable_optimization = self.disable_optimization
new_field.store = self.store.copy()
return new_field

@cached_property
def _need_remove_filters_argument(self):
if not self.base_resolver or not self.is_connection:
Expand Down
2 changes: 2 additions & 0 deletions strawberry_django/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,8 @@ def _process_type(
# seeing it, just update its annotations/description/etc
f.type_annotation = type_annotation
f.description = description
elif isinstance(f, StrawberryDjangoField):
f = copy.copy(f) # noqa: PLW2901
elif (
not isinstance(f, StrawberryDjangoField)
and getattr(f, "base_resolver", None) is not None
Expand Down
21 changes: 21 additions & 0 deletions tests/test_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

import strawberry
from django.db import models
from strawberry.type import get_object_definition

import strawberry_django
from strawberry_django.fields.field import StrawberryDjangoField
from strawberry_django.utils.typing import get_django_definition


Expand Down Expand Up @@ -109,3 +111,22 @@ class SomeModelType:
assert store.select_related == ["other"]
assert store.prefetch_related == ["other"]
assert store.annotate == {"other_name": models.F("other__name")}


def test_custom_field_kept_on_inheritance():
class SomeModel(models.Model):
foo = models.CharField(max_length=255)

class CustomField(StrawberryDjangoField): ...

@strawberry_django.type(SomeModel)
class SomeModelType:
foo: strawberry.auto = CustomField()

@strawberry_django.type(SomeModel)
class SomeModelSubclassType(SomeModelType): ...

for type_ in [SomeModelType, SomeModelSubclassType]:
object_definition = get_object_definition(type_, strict=True)
field = object_definition.get_field("foo")
assert isinstance(field, CustomField)

0 comments on commit fb19948

Please sign in to comment.