Skip to content

Commit

Permalink
feat: New Paginated generic to be used as a wrapped for paginated res…
Browse files Browse the repository at this point in the history
…ults
  • Loading branch information
bellini666 committed Oct 13, 2024
1 parent 73fdf0a commit 7279deb
Show file tree
Hide file tree
Showing 7 changed files with 562 additions and 68 deletions.
16 changes: 15 additions & 1 deletion strawberry_django/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,18 @@ def is_async(self) -> bool:

@functools.cached_property
def django_type(self) -> type[WithStrawberryDjangoObjectDefinition] | None:
from strawberry_django.pagination import Paginated

origin = self.type

if isinstance(origin, LazyType):
origin = origin.resolve_type()

object_definition = get_object_definition(origin)

if object_definition and issubclass(object_definition.origin, relay.Connection):
if object_definition and issubclass(
object_definition.origin, (relay.Connection, Paginated)
):
origin_specialized_type_var_map = (
get_specialized_type_var_map(cast(type, origin)) or {}
)
Expand Down Expand Up @@ -148,6 +152,16 @@ def is_list(self) -> bool:

return isinstance(type_, StrawberryList)

@functools.cached_property
def is_paginated(self) -> bool:
from strawberry_django.pagination import Paginated

type_ = self.type
if isinstance(type_, StrawberryOptional):
type_ = type_.of_type

return isinstance(type_, type) and issubclass(type_, Paginated)

@functools.cached_property
def is_connection(self) -> bool:
type_ = self.type
Expand Down
17 changes: 10 additions & 7 deletions strawberry_django/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dataclasses
import inspect
from functools import cached_property
from functools import cached_property, partial
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -229,9 +229,12 @@ async def async_resolver():
if "info" not in kwargs:
kwargs["info"] = info

resolved = await sync_to_async(self.get_queryset_hook(**kwargs))(
resolved,
)
@sync_to_async
def resolve():
inner_resolved = self.get_queryset_hook(**kwargs)(resolved)
return self.get_wrapped_result(inner_resolved, **kwargs)

resolved = await resolve()

return resolved

Expand All @@ -245,15 +248,15 @@ async def async_resolver():
kwargs["info"] = info

result = django_resolver(
lambda obj: obj,
partial(self.get_wrapped_result, **kwargs),
qs_hook=self.get_queryset_hook(**kwargs),
)(result)

return result

def get_queryset_hook(self, info: Info, **kwargs):
if self.is_connection:
# We don't want to fetch results yet, those will be done by the connection
if self.is_connection or self.is_paginated:
# We don't want to fetch results yet, those will be done by the connection/pagination
def qs_hook(qs: models.QuerySet): # type: ignore
return self.get_queryset(qs, info, **kwargs)

Expand Down
1 change: 1 addition & 0 deletions strawberry_django/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def arguments(self) -> List[StrawberryArgument]:
and is_root_query
and not self.is_list
and not self.is_connection
and not self.is_paginated
):
settings = strawberry_django_settings()
arguments.append(
Expand Down
139 changes: 115 additions & 24 deletions strawberry_django/pagination.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,64 @@
import sys
from typing import TYPE_CHECKING, List, Optional, TypeVar, Union
import warnings
from typing import Generic, List, Optional, TypeVar, Union, cast

import strawberry
from django.db import DEFAULT_DB_ALIAS
from django.db.models import Count, Window
from django.db.models import Count, QuerySet, Window
from django.db.models.functions import RowNumber
from strawberry.types import Info
from strawberry.types.arguments import StrawberryArgument
from strawberry.types.unset import UNSET, UnsetType
from typing_extensions import Self

from strawberry_django.fields.base import StrawberryDjangoFieldBase
from strawberry_django.resolvers import django_resolver

from .arguments import argument

if TYPE_CHECKING:
from django.db.models import QuerySet
NodeType = TypeVar("NodeType")
_T = TypeVar("_T")
_QS = TypeVar("_QS", bound=QuerySet)

_QS = TypeVar("_QS", bound="QuerySet")
DEFAULT_OFFSET: int = 0
DEFAULT_LIMIT: int = -1


@strawberry.input
class OffsetPaginationInput:
offset: int = 0
limit: int = -1
offset: int = DEFAULT_OFFSET
limit: int = DEFAULT_LIMIT


@strawberry.type
class Paginated(Generic[NodeType]):
queryset: strawberry.Private[QuerySet]
pagination: strawberry.Private[OffsetPaginationInput]

@strawberry.field
def limit(self) -> int:
return self.pagination.limit

@strawberry.field
def offset(self) -> int:
return self.pagination.limit

@strawberry.field(description="Total count of existing results.")
@django_resolver
def total_count(self) -> int:
return get_total_count(self.queryset)

@strawberry.field(description="List of paginated results.")
@django_resolver
def results(self) -> List[NodeType]:
from strawberry_django.optimizer import is_optimized_by_prefetching

if is_optimized_by_prefetching(self.queryset):
results = self.queryset._result_cache # type: ignore
else:
results = apply(self.pagination, self.queryset)

return cast(List[NodeType], results)


def apply(
Expand Down Expand Up @@ -59,8 +94,11 @@ def apply(
)
else:
start = pagination.offset
stop = start + pagination.limit
queryset = queryset[start:stop]
if pagination.limit >= 0:
stop = start + pagination.limit
queryset = queryset[start:stop]
else:
queryset = queryset[start:]

return queryset

Expand Down Expand Up @@ -116,6 +154,32 @@ def apply_window_pagination(
return queryset


def get_total_count(queryset: QuerySet) -> int:
"""Get the total count of a queryset.
Try to get the total count from the queryset cache, if it's optimized by
prefetching. Otherwise, fallback to the `QuerySet.count()` method.
"""
from strawberry_django.optimizer import is_optimized_by_prefetching

if is_optimized_by_prefetching(queryset):
results = queryset._result_cache # type: ignore

try:
return results[0]._strawberry_total_count if results else 0
except AttributeError:
warnings.warn(
(
"Pagination annotations not found, falling back to QuerySet resolution. "
"This might cause n+1 issues..."
),
RuntimeWarning,
stacklevel=2,
)

return queryset.count()


class StrawberryDjangoPagination(StrawberryDjangoFieldBase):
def __init__(self, pagination: Union[bool, UnsetType] = UNSET, **kwargs):
self.pagination = pagination
Expand All @@ -126,10 +190,25 @@ def __copy__(self) -> Self:
new_field.pagination = self.pagination
return new_field

def _has_pagination(self) -> bool:
if isinstance(self.pagination, bool):
return self.pagination

if self.is_paginated:
return True

django_type = self.django_type
if django_type is not None and not issubclass(
django_type, strawberry.relay.Node
):
return django_type.__strawberry_django_definition__.pagination

return False

@property
def arguments(self) -> List[StrawberryArgument]:
arguments = []
if self.base_resolver is None and self.is_list:
if self.base_resolver is None and (self.is_list or self.is_paginated):
pagination = self.get_pagination()
if pagination is not None:
arguments.append(
Expand All @@ -143,20 +222,7 @@ def arguments(self, value: List[StrawberryArgument]):
return args_prop.fset(self, value) # type: ignore

def get_pagination(self) -> Optional[type]:
has_pagination = self.pagination

if isinstance(has_pagination, UnsetType):
django_type = self.django_type
has_pagination = (
django_type.__strawberry_django_definition__.pagination
if (
django_type is not None
and not issubclass(django_type, strawberry.relay.Node)
)
else False
)

return OffsetPaginationInput if has_pagination else None
return OffsetPaginationInput if self._has_pagination() else None

def apply_pagination(
self,
Expand All @@ -182,3 +248,28 @@ def get_queryset(
pagination,
related_field_id=_strawberry_related_field_id,
)

def get_wrapped_result(
self,
result: _T,
info: Info,
*,
pagination: Optional[object] = None,
**kwargs,
) -> Union[_T, Paginated[_T]]:
if not self.is_paginated:
return result

if not isinstance(result, QuerySet):
raise TypeError(f"Result expected to be a queryset, got {result!r}")

if (
pagination not in (None, UNSET) # noqa: PLR6201
and not isinstance(pagination, OffsetPaginationInput)
):
raise TypeError(f"Don't know how to resolve pagination {pagination!r}")

return Paginated(
queryset=result,
pagination=pagination or OffsetPaginationInput(),
)
36 changes: 4 additions & 32 deletions strawberry_django/relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from strawberry.utils.await_maybe import AwaitableOrValue
from typing_extensions import Literal, Self

from strawberry_django.pagination import get_total_count
from strawberry_django.queryset import run_type_get_queryset
from strawberry_django.resolvers import django_getattr, django_resolver
from strawberry_django.utils.typing import (
Expand All @@ -51,41 +52,12 @@ class ListConnectionWithTotalCount(relay.ListConnection[relay.NodeType]):
@strawberry.field(description="Total quantity of existing nodes.")
@django_resolver
def total_count(self) -> Optional[int]:
from .optimizer import is_optimized_by_prefetching

assert self.nodes is not None

if isinstance(self.nodes, models.QuerySet) and is_optimized_by_prefetching(
self.nodes
):
result = cast(List[relay.NodeType], self.nodes._result_cache) # type: ignore
try:
return (
result[0]._strawberry_total_count # type: ignore
if result
else 0
)
except AttributeError:
warnings.warn(
(
"Pagination annotations not found, falling back to QuerySet resolution. "
"This might cause n+1 issues..."
),
RuntimeWarning,
stacklevel=2,
)
if isinstance(self.nodes, models.QuerySet):
return get_total_count(self.nodes)

total_count = None
try:
total_count = cast(
"models.QuerySet[models.Model]",
self.nodes,
).count()
except (AttributeError, ValueError, TypeError):
if isinstance(self.nodes, Sized):
total_count = len(self.nodes)

return total_count
return len(self.nodes) if isinstance(self.nodes, Sized) else None

@classmethod
def resolve_connection(
Expand Down
8 changes: 4 additions & 4 deletions strawberry_django/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@
from .fields.types import get_model_field, resolve_model_field_name
from .settings import strawberry_django_settings as django_settings

__all = [
"StrawberryDjangoType",
"type",
"interface",
__all__ = [
"StrawberryDjangoDefinition",
"input",
"interface",
"partial",
"type",
]

_T = TypeVar("_T", bound=type)
Expand Down
Loading

0 comments on commit 7279deb

Please sign in to comment.