Skip to content

Commit

Permalink
fix(optimizer): Avoid extra queries for prefetches with existing pref…
Browse files Browse the repository at this point in the history
…etch hints

Fix #559
  • Loading branch information
bellini666 committed Jul 13, 2024
1 parent cb9f63a commit 3ac223a
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 2 deletions.
7 changes: 7 additions & 0 deletions strawberry_django/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,13 @@ def _get_hints_from_django_relation(
store.prefetch_related.append(model_fieldname)
return store

field_store = getattr(field, "store", None)
if field_store and field_store.prefetch_related:
# When we have a prefetch_related in the field, skip the optimization as
# it will probably be filtering/annotating the queryset differently, and doing
# it here would result in an extra unused query
return store

remote_field = model_field.remote_field
remote_model = remote_field.model
field_store = _get_model_hints(
Expand Down
60 changes: 60 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,3 +1365,63 @@ def test_nested_prefetch_with_get_queryset(
},
}
mock_get_queryset.assert_called_once()


@pytest.mark.django_db(transaction=True)
def test_ensure_prefetch_hint_with_field_as_same_name_as_the_model_doesnt_cause_extra_db_queries(
db,
):
@strawberry_django.type(Issue)
class IssueType:
pk: strawberry.ID

@strawberry_django.type(Milestone)
class MilestoneType:
pk: strawberry.ID

@strawberry_django.field(
prefetch_related=[
lambda info: Prefetch(
"issues",
queryset=Issue.objects.filter(name__startswith="Foo"),
to_attr="_my_issues",
),
],
)
def issues(self) -> List[IssueType]:
return self._my_issues # type: ignore

@strawberry.type
class Query:
milestone: MilestoneType = strawberry_django.field()

schema = strawberry.Schema(query=Query, extensions=[DjangoOptimizerExtension])

milestone1 = MilestoneFactory.create()
milestone2 = MilestoneFactory.create()

issue1 = IssueFactory.create(name="Foo", milestone=milestone1)
IssueFactory.create(name="Bar", milestone=milestone1)
IssueFactory.create(name="Foo", milestone=milestone2)

query = """\
query TestQuery ($pk: ID!) {
milestone(pk: $pk) {
pk
issues {
pk
}
}
}
"""

with assert_num_queries(2):
res = schema.execute_sync(query, {"pk": milestone1.pk})

assert res.errors is None
assert res.data == {
"milestone": {
"pk": str(milestone1.pk),
"issues": [{"pk": str(issue1.pk)}],
},
}
5 changes: 3 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,9 @@ def assert_num_queries(n: int, *, using=DEFAULT_DB_ALIAS):
# FIXME: Async will not have access to the correct number of queries without
# execing CaptureQueriesContext.(__enter__|__exit__) wrapped in sync_to_async
# How can we fix this?
if _client.get().is_async and executed == 0:
return
with contextlib.suppress(LookupError):
if _client.get().is_async and executed == 0:
return

assert (
executed == n
Expand Down

0 comments on commit 3ac223a

Please sign in to comment.