From 3f81116f56f622535cd79a58df108196f5661757 Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Tue, 17 Dec 2024 16:30:50 +0100 Subject: [PATCH] Improve prefetching for single/optional fields by not prefetching the whole table --- strawberry_django/fields/field.py | 18 ++++++++---------- strawberry_django/pagination.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index 847b1611..bd70120e 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -31,6 +31,7 @@ ReverseOneToOneDescriptor, ) from django.db.models.manager import BaseManager +from django.db.models.query import MAX_GET_RESULTS from django.db.models.query_utils import DeferredAttribute from strawberry import UNSET, relay from strawberry.annotation import StrawberryAnnotation @@ -282,30 +283,27 @@ def qs_hook(qs: models.QuerySet): # type: ignore def qs_hook(qs: models.QuerySet): # type: ignore qs = self.get_queryset(qs, info, **kwargs) - # Don't use qs.first() if the queryset is optimized by prefetching. - # Calling first in that case would disregard the prefetched results, because first implicitly - # adds a limit to the query - if is_optimized_by_prefetching(qs): - return next(iter(qs), None) return qs.first() else: def qs_hook(qs: models.QuerySet): qs = self.get_queryset(qs, info, **kwargs) - # See comment above about qs.first(), the same applies for get() + # Don't use qs.get() if the queryset is optimized by prefetching. + # Calling first in that case would disregard the prefetched results, because first implicitly + # adds a limit to the query if is_optimized_by_prefetching(qs): # mimic behavior of get() - qs_len = len( - qs - ) # the queryset is already prefetched, no issue with just using len() + # the queryset is already prefetched, no issue with just using len() + qs_len = len(qs) if qs_len == 0: raise qs.model.DoesNotExist( f"{qs.model._meta.object_name} matching query does not exist." ) if qs_len != 1: raise qs.model.MultipleObjectsReturned( - f"get() returned more than one {qs.model._meta.object_name} -- it returned {qs_len}!" + f"get() returned more than one {qs.model._meta.object_name} -- it returned " + f"{qs_len if qs_len < MAX_GET_RESULTS else f'more than {qs_len - 1}'}!" ) return qs[0] return qs.get() diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index 6d4f7a56..31be9dd2 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -6,6 +6,7 @@ from django.db import DEFAULT_DB_ALIAS from django.db.models import Count, QuerySet, Window from django.db.models.functions import RowNumber +from django.db.models.query import MAX_GET_RESULTS from strawberry.types import Info from strawberry.types.arguments import StrawberryArgument from strawberry.types.unset import UNSET, UnsetType @@ -343,6 +344,19 @@ def get_queryset( if self.is_paginated: return queryset + # Add implicit pagination if this field is not a list + # that way when first() / get() is called on the QuerySet it does not cause extra queries + if not pagination and not ( + self.is_list or self.is_paginated or self.is_connection + ): + if self.is_optional: + # first() applies order by pk if not ordered + if not queryset.ordered: + queryset = queryset.order_by("pk") + pagination = OffsetPaginationInput(offset=0, limit=1) + else: + pagination = OffsetPaginationInput(offset=0, limit=MAX_GET_RESULTS) + return self.apply_pagination( queryset, pagination,