From c7c50e9e20af08572ba9924e02a51c2ca8711c0b Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sun, 14 Jul 2024 10:01:52 -0300 Subject: [PATCH] fix(optimizer): Convert select_related into Prefetch when the type defines a custom get_queryset (#583) --- strawberry_django/fields/field.py | 2 + strawberry_django/optimizer.py | 68 ++++++++++++++++--- tests/projects/schema.py | 14 ++-- tests/projects/snapshots/schema.gql | 2 +- .../snapshots/schema_with_inheritance.gql | 4 +- tests/test_optimizer.py | 54 ++++++++++++++- 6 files changed, 122 insertions(+), 22 deletions(-) diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index ad0eea2a..87ed5e84 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -359,6 +359,8 @@ def default_resolver( return cast(Iterable[Any], retval) + default_resolver.can_optimize = True # type: ignore + field.base_resolver = StrawberryResolver(default_resolver) return super().apply(field) diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index 683e9d48..268822ac 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -68,6 +68,7 @@ PrefetchType, TypeOrMapping, TypeOrSequence, + WithStrawberryDjangoObjectDefinition, get_django_definition, has_django_definition, ) @@ -422,6 +423,20 @@ def _apply_annotate( return qs.annotate(**to_annotate) +def _get_django_type( + field: StrawberryField, +) -> type[WithStrawberryDjangoObjectDefinition] | None: + f_type = field.type + if isinstance(f_type, LazyType): + f_type = f_type.resolve_type() + if isinstance(f_type, StrawberryContainer): + f_type = f_type.of_type + if isinstance(f_type, LazyType): + f_type = f_type.resolve_type() + + return f_type if has_django_definition(f_type) else None + + def _get_prefetch_queryset( remote_model: type[models.Model], schema: Schema, @@ -443,15 +458,7 @@ def _get_prefetch_queryset( else: qs = remote_model._base_manager.all() # type: ignore - f_type = field.type - if isinstance(f_type, LazyType): - f_type = f_type.resolve_type() - while isinstance(f_type, StrawberryContainer): - f_type = f_type.of_type - if isinstance(f_type, LazyType): - f_type = f_type.resolve_type() - - if has_django_definition(f_type): + if f_type := _get_django_type(field): qs = run_type_get_queryset( qs, f_type, @@ -690,16 +697,39 @@ def _get_hints_from_model_property( def _get_hints_from_django_foreign_key( field: StrawberryField, field_definition: GraphQLObjectType, + field_selection: FieldNode, model_field: models.ForeignKey | OneToOneRel, model_fieldname: str, schema: Schema, *, config: OptimizerConfig, + parent_type: GraphQLObjectType | GraphQLInterfaceType, field_info: GraphQLResolveInfo, path: str, cache: dict[type[models.Model], list[tuple[int, OptimizerStore]]], level: int = 0, ) -> OptimizerStore: + f_type = _get_django_type(field) + if f_type and hasattr(f_type, "get_queryset"): + # If the field has a get_queryset method, change strategy to Prefetch + # so it will be respected + store = _get_hints_from_django_relation( + field, + field_definition=field_definition, + field_selection=field_selection, + model_field=model_field, + model_fieldname=model_fieldname, + schema=schema, + config=config, + parent_type=parent_type, + field_info=field_info, + path=path, + cache=cache, + level=level, + ) + store.only.append(path) + return store + store = OptimizerStore.with_hints( only=[path], select_related=[path], @@ -737,7 +767,12 @@ def _get_hints_from_django_relation( field_definition: GraphQLObjectType, field_selection: FieldNode, model_field: ( - models.ManyToManyField | ManyToManyRel | ManyToOneRel | GenericRelation + models.ManyToManyField + | ManyToManyRel + | ManyToOneRel + | GenericRelation + | OneToOneRel + | models.ForeignKey ), model_fieldname: str, schema: Schema, @@ -861,6 +896,17 @@ def _get_hints_from_django_field( GenericRelation, ) + # If the field has a base resolver, don't try to optimize it. The user should + # be defining custom hints in this case, which should already be in the store + # GlobalID and special cases setting `can_optimize` are ok though, as those resolvers + # are auto generated by us + if ( + field.base_resolver is not None + and field.type != relay.GlobalID + and not getattr(field.base_resolver.wrapped_func, "can_optimize", False) + ): + return None + model_fieldname: str = getattr(field, "django_name", None) or field.python_name if (model_field := get_model_field(model, model_fieldname)) is None: return None @@ -871,10 +917,12 @@ def _get_hints_from_django_field( store = _get_hints_from_django_foreign_key( field, field_definition=field_definition, + field_selection=field_selection, model_field=model_field, model_fieldname=model_fieldname, schema=schema, config=config, + parent_type=parent_type, field_info=field_info, path=path, cache=cache, diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 6acb4155..83097832 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -158,6 +158,12 @@ class MilestoneType(relay.Node, Named): order=IssueOrder, pagination=True, ) + issues_with_filters: ListConnectionWithTotalCount["IssueType"] = ( + strawberry_django.connection( + field_name="issues", + filters=IssueFilter, + ) + ) @strawberry_django.field( prefetch_related=[ @@ -178,14 +184,6 @@ class MilestoneType(relay.Node, Named): def my_issues(self) -> List["IssueType"]: return self._my_issues # type: ignore - @strawberry_django.connection( - ListConnectionWithTotalCount["IssueType"], - field_name="issues", - filters=IssueFilter, - ) - def issues_with_filters(self) -> List["IssueType"]: - return self.issues.all() # type: ignore - @strawberry_django.field( annotate={ "_my_bugs_count": lambda info: Count( diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index 6e8d5d6a..16533658 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -373,7 +373,6 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! - myIssues: [IssueType!]! issuesWithFilters( filters: IssueFilter @@ -389,6 +388,7 @@ type MilestoneType implements Node & Named { """Returns the items in the list that come after the specified cursor.""" last: Int = null ): IssueTypeConnection! + myIssues: [IssueType!]! myBugsCount: Int! asyncField(value: String!): String! } diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index f443e807..8c891a65 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -163,7 +163,6 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! - myIssues: [IssueType!]! issuesWithFilters( filters: IssueFilter @@ -179,6 +178,7 @@ type MilestoneType implements Node & Named { """Returns the items in the list that come after the specified cursor.""" last: Int = null ): IssueTypeConnection! + myIssues: [IssueType!]! myBugsCount: Int! asyncField(value: String!): String! } @@ -190,7 +190,6 @@ type MilestoneTypeSubclass implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! - myIssues: [IssueType!]! issuesWithFilters( filters: IssueFilter @@ -206,6 +205,7 @@ type MilestoneTypeSubclass implements Node & Named { """Returns the items in the list that come after the specified cursor.""" last: Int = null ): IssueTypeConnection! + myIssues: [IssueType!]! myBugsCount: Int! asyncField(value: String!): String! } diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 97b5842d..9298a1d9 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1,5 +1,5 @@ import datetime -from typing import Any, List, cast +from typing import Any, List, Optional, cast import pytest import strawberry @@ -1365,3 +1365,55 @@ def test_nested_prefetch_with_get_queryset( }, } mock_get_queryset.assert_called_once() + + +@pytest.mark.django_db(transaction=True) +def test_select_related_fallsback_to_prefetch_when_type_defines_get_queryset( + db, +): + @strawberry_django.type(Milestone) + class MilestoneType: + pk: strawberry.ID + + @classmethod + def get_queryset(cls, queryset, info, **kwargs): + return queryset.filter(name__startswith="Foo") + + @strawberry_django.type(Issue) + class IssueType: + pk: strawberry.ID + milestone: Optional[MilestoneType] + + @strawberry.type + class Query: + issues: List[IssueType] = strawberry_django.field() + + schema = strawberry.Schema(query=Query, extensions=[DjangoOptimizerExtension]) + + milestone1 = MilestoneFactory.create(name="Foo") + milestone2 = MilestoneFactory.create(name="Bar") + + issue1 = IssueFactory.create(milestone=milestone1) + issue2 = IssueFactory.create(milestone=milestone2) + + query = """\ + query TestQuery { + issues { + pk + milestone { + pk + } + } + } + """ + + with assert_num_queries(2): + res = schema.execute_sync(query) + + assert res.errors is None + assert res.data == { + "issues": [ + {"pk": str(issue1.pk), "milestone": {"pk": str(milestone1.pk)}}, + {"pk": str(issue2.pk), "milestone": None}, + ], + }