From 3ac223a354a993ec2083f3f81547b6a6af6905d5 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sat, 13 Jul 2024 13:59:54 -0300 Subject: [PATCH] fix(optimizer): Avoid extra queries for prefetches with existing prefetch hints Fix #559 --- strawberry_django/optimizer.py | 7 ++++ tests/test_optimizer.py | 60 ++++++++++++++++++++++++++++++++++ tests/utils.py | 5 +-- 3 files changed, 70 insertions(+), 2 deletions(-) diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index 683e9d48d..1988c3c9f 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -763,6 +763,13 @@ def _get_hints_from_django_relation( store.prefetch_related.append(model_fieldname) return store + field_store = getattr(field, "store", None) + if field_store and field_store.prefetch_related: + # When we have a prefetch_related in the field, skip the optimization as + # it will probably be filtering/annotating the queryset differently, and doing + # it here would result in an extra unused query + return store + remote_field = model_field.remote_field remote_model = remote_field.model field_store = _get_model_hints( diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 97b5842d7..90f1970b4 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1365,3 +1365,63 @@ def test_nested_prefetch_with_get_queryset( }, } mock_get_queryset.assert_called_once() + + +@pytest.mark.django_db(transaction=True) +def test_ensure_prefetch_hint_with_field_as_same_name_as_the_model_doesnt_cause_extra_db_queries( + db, +): + @strawberry_django.type(Issue) + class IssueType: + pk: strawberry.ID + + @strawberry_django.type(Milestone) + class MilestoneType: + pk: strawberry.ID + + @strawberry_django.field( + prefetch_related=[ + lambda info: Prefetch( + "issues", + queryset=Issue.objects.filter(name__startswith="Foo"), + to_attr="_my_issues", + ), + ], + ) + def issues(self) -> List[IssueType]: + return self._my_issues # type: ignore + + @strawberry.type + class Query: + milestone: MilestoneType = strawberry_django.field() + + schema = strawberry.Schema(query=Query, extensions=[DjangoOptimizerExtension]) + + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + issue1 = IssueFactory.create(name="Foo", milestone=milestone1) + IssueFactory.create(name="Bar", milestone=milestone1) + IssueFactory.create(name="Foo", milestone=milestone2) + + query = """\ + query TestQuery ($pk: ID!) { + milestone(pk: $pk) { + pk + issues { + pk + } + } + } + """ + + with assert_num_queries(2): + res = schema.execute_sync(query, {"pk": milestone1.pk}) + + assert res.errors is None + assert res.data == { + "milestone": { + "pk": str(milestone1.pk), + "issues": [{"pk": str(issue1.pk)}], + }, + } diff --git a/tests/utils.py b/tests/utils.py index b054eceb2..80416bbb9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -97,8 +97,9 @@ def assert_num_queries(n: int, *, using=DEFAULT_DB_ALIAS): # FIXME: Async will not have access to the correct number of queries without # execing CaptureQueriesContext.(__enter__|__exit__) wrapped in sync_to_async # How can we fix this? - if _client.get().is_async and executed == 0: - return + with contextlib.suppress(LookupError): + if _client.get().is_async and executed == 0: + return assert ( executed == n