From 9705ab00e8c992b4d087729a8530459973838315 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sun, 20 Oct 2024 14:49:27 -0300 Subject: [PATCH] Refactor Paginated to be used similarly to how Connection is used --- docs/guide/pagination.md | 85 ++++- strawberry_django/__init__.py | 3 +- strawberry_django/fields/field.py | 343 +++++++++++++++-- strawberry_django/pagination.py | 34 +- strawberry_django/permissions.py | 8 +- tests/projects/schema.py | 8 +- tests/projects/snapshots/schema.gql | 2 +- .../snapshots/schema_with_inheritance.gql | 4 +- tests/test_paginated_type.py | 355 +++++++++++++++++- 9 files changed, 751 insertions(+), 91 deletions(-) diff --git a/docs/guide/pagination.md b/docs/guide/pagination.md index 0e8453ee..67beb920 100644 --- a/docs/guide/pagination.md +++ b/docs/guide/pagination.md @@ -28,7 +28,7 @@ type Fruit { input OffsetPaginationInput { offset: Int! = 0 - limit: Int! = -1 + limit: Int = null } type Query { @@ -107,7 +107,7 @@ class Fruit: @strawberry.type class Query: - fruits: OffsetPaginated[Fruit] = strawberry_django.field() + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() ``` Would produce the following schema: @@ -118,7 +118,7 @@ type Fruit { } type PaginationInfo { - limit: Int! + limit: Int = null offset: Int! } @@ -130,7 +130,7 @@ type FruitOffsetPaginated { input OffsetPaginationInput { offset: Int! = 0 - limit: Int! = -1 + limit: Int = null } type Query { @@ -159,6 +159,77 @@ query { > OffsetPaginated follow the same rules for the default pagination limit, and can be configured > in the same way as explained above. +### Customizing queryset resolver + +It is possible to define a custom resolver for the queryset to either provide a custom +queryset for it, or even to receive extra arguments alongside the pagination arguments. + +Suppose we want to pre-filter a queryset of fruits for only available ones, +while also adding [ordering](./ordering.md) to it. This can be achieved like: + +```python title="types.py" + +@strawberry_django.type(models.Fruit) +class Fruit: + name: auto + price: auto + + +@strawberry_django.order(models.Fruit) +class FruitOrder: + name: auto + price: auto + + +@strawberry.type +class Query: + @straberry.offset_paginated(OffsetPaginated[Fruit], order=order) + def fruits(self, only_available: bool = True) -> QuerySet[Fruit]: + queryset = models.Fruit.objects.all() + if only_available: + queryset = queryset.filter(available=True) + + return queryset +``` + +This would produce the following schema: + +```graphql title="schema.graphql" +type Fruit { + name: String! + price: Decimal! +} + +type FruitOrder { + name: Ordering + price: Ordering +} + +type PaginationInfo { + limit: Int! + offset: Int! +} + +type FruitOffsetPaginated { + pageInfo: PaginationInfo! + totalCount: Int! + results: [Fruit]! +} + +input OffsetPaginationInput { + offset: Int! = 0 + limit: Int = null +} + +type Query { + fruits( + onlyAvailable: Boolean! = true + pagination: OffsetPaginationInput + order: FruitOrder + ): [FruitOffsetPaginated!]! +} +``` + ### Customizing the pagination Like other generics, `OffsetPaginated` can be customized to modify its behavior or to @@ -195,7 +266,7 @@ class FruitOffsetPaginated(OffsetPaginated[Fruit]): @strawberry.type class Query: - fruits: FruitOffsetPaginated = strawberry_django.field() + fruits: FruitOffsetPaginated = strawberry_django.offset_paginated() ``` Would produce the following schema: @@ -206,7 +277,7 @@ type Fruit { } type PaginationInfo { - limit: Int! + limit: Int = null offset: Int! } @@ -220,7 +291,7 @@ type FruitOffsetPaginated { input OffsetPaginationInput { offset: Int! = 0 - limit: Int! = -1 + limit: Int = null } type Query { diff --git a/strawberry_django/__init__.py b/strawberry_django/__init__.py index 30e2ac69..74e0a557 100644 --- a/strawberry_django/__init__.py +++ b/strawberry_django/__init__.py @@ -1,5 +1,5 @@ from . import auth, filters, mutations, ordering, pagination, relay -from .fields.field import connection, field, node +from .fields.field import connection, field, node, offset_paginated from .fields.filter_order import filter_field, order_field from .fields.filter_types import ( BaseFilterLookup, @@ -60,6 +60,7 @@ "mutation", "mutations", "node", + "offset_paginated", "order", "order_field", "ordering", diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index 864d36fa..7806e77a 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -2,13 +2,21 @@ import dataclasses import inspect -from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property, partial +from collections.abc import ( + AsyncIterable, + AsyncIterator, + Iterable, + Iterator, + Mapping, + Sequence, +) +from functools import cached_property from typing import ( TYPE_CHECKING, Any, Callable, TypeVar, + Union, cast, overload, ) @@ -26,9 +34,12 @@ from django.db.models.query_utils import DeferredAttribute from strawberry import UNSET, relay from strawberry.annotation import StrawberryAnnotation +from strawberry.extensions.field_extension import FieldExtension +from strawberry.types.field import _RESOLVER_TYPE # noqa: PLC2701 from strawberry.types.fields.resolver import StrawberryResolver from strawberry.types.info import Info # noqa: TCH002 from strawberry.utils.await_maybe import await_maybe +from typing_extensions import TypeAlias from strawberry_django import optimizer from strawberry_django.arguments import argument @@ -37,7 +48,12 @@ from strawberry_django.filters import FILTERS_ARG, StrawberryDjangoFieldFilters from strawberry_django.optimizer import OptimizerStore, is_optimized_by_prefetching from strawberry_django.ordering import ORDER_ARG, StrawberryDjangoFieldOrdering -from strawberry_django.pagination import StrawberryDjangoPagination +from strawberry_django.pagination import ( + PAGINATION_ARG, + OffsetPaginated, + OffsetPaginationInput, + StrawberryDjangoPagination, +) from strawberry_django.permissions import filter_with_perms from strawberry_django.queryset import run_type_get_queryset from strawberry_django.relay import resolve_model_nodes @@ -51,13 +67,11 @@ if TYPE_CHECKING: from graphql.pyutils import AwaitableOrValue from strawberry import BasePermission - from strawberry.extensions.field_extension import ( - FieldExtension, - SyncExtensionResolver, - ) + from strawberry.extensions.field_extension import SyncExtensionResolver from strawberry.relay.types import NodeIterableType from strawberry.types.arguments import StrawberryArgument - from strawberry.types.field import _RESOLVER_TYPE, StrawberryField + from strawberry.types.base import WithStrawberryObjectDefinition + from strawberry.types.field import StrawberryField from strawberry.types.unset import UnsetType from typing_extensions import Literal, Self @@ -228,12 +242,9 @@ async def async_resolver(): if "info" not in kwargs: kwargs["info"] = info - @sync_to_async - def resolve(resolved=resolved): - inner_resolved = self.get_queryset_hook(**kwargs)(resolved) - return self.get_wrapped_result(inner_resolved, **kwargs) - - resolved = await resolve() + resolved = await sync_to_async(self.get_queryset_hook(**kwargs))( + resolved + ) return resolved @@ -248,7 +259,7 @@ def resolve(resolved=resolved): result = django_resolver( self.get_queryset_hook(**kwargs), - qs_hook=partial(self.get_wrapped_result, **kwargs), + qs_hook=lambda qs: qs, )(result) return result @@ -300,6 +311,44 @@ def get_queryset(self, queryset, info, **kwargs): return queryset +def _get_field_arguments_for_extensions( + field: StrawberryDjangoField, + *, + add_filters: bool = True, + add_order: bool = True, + add_pagination: bool = True, +) -> list[StrawberryArgument]: + """Get a list of arguments to be set to fields using extensions. + + Because we have a base_resolver defined in those, our parents will not add + order/filters/pagination resolvers in here, so we need to add them by hand (unless they + are somewhat in there). We are not adding pagination because it doesn't make + sense together with a Connection + """ + args: dict[str, StrawberryArgument] = {a.python_name: a for a in field.arguments} + + if add_filters and FILTERS_ARG not in args: + filters = field.get_filters() + if filters not in (None, UNSET): # noqa: PLR6201 + args[FILTERS_ARG] = argument(FILTERS_ARG, filters, is_optional=True) + + if add_order and ORDER_ARG not in args: + order = field.get_order() + if order not in (None, UNSET): # noqa: PLR6201 + args[ORDER_ARG] = argument(ORDER_ARG, order, is_optional=True) + + if add_pagination and PAGINATION_ARG not in args: + pagination = field.get_pagination() + if pagination not in (None, UNSET): # noqa: PLR6201 + args[PAGINATION_ARG] = argument( + PAGINATION_ARG, + pagination, + is_optional=True, + ) + + return list(args.values()) + + class StrawberryDjangoConnectionExtension(relay.ConnectionExtension): def apply(self, field: StrawberryField) -> None: if not isinstance(field, StrawberryDjangoField): @@ -307,24 +356,10 @@ def apply(self, field: StrawberryField) -> None: "The extension can only be applied to StrawberryDjangoField" ) - # NOTE: Because we have a base_resolver defined, our parents will not add - # order/filters resolvers in here, so we need to add them by hand (unless they - # are somewhat in there). We are not adding pagination because it doesn't make - # sense together with a Connection - args: dict[str, StrawberryArgument] = { - a.python_name: a for a in field.arguments - } - - if FILTERS_ARG not in args: - filters = field.get_filters() - if filters not in (None, UNSET): # noqa: PLR6201 - args[FILTERS_ARG] = argument(FILTERS_ARG, filters, is_optional=True) - if ORDER_ARG not in args: - order = field.get_order() - if order not in (None, UNSET): # noqa: PLR6201 - args[ORDER_ARG] = argument(ORDER_ARG, order, is_optional=True) - - field.arguments = list(args.values()) + field.arguments = _get_field_arguments_for_extensions( + field, + add_pagination=False, + ) if field.base_resolver is None: @@ -412,6 +447,67 @@ async def async_resolver(): ) +class StrawberryOffsetPaginatedExtension(FieldExtension): + paginated_type: type[OffsetPaginated] + + def apply(self, field: StrawberryField) -> None: + if not isinstance(field, StrawberryDjangoField): + raise TypeError( + "The extension can only be applied to StrawberryDjangoField" + ) + + field.arguments = _get_field_arguments_for_extensions(field) + self.paginated_type = cast(type[OffsetPaginated], field.type) + + def resolve( + self, + next_: SyncExtensionResolver, + source: Any, + info: Info, + *, + pagination: OffsetPaginationInput | None = None, + order: WithStrawberryObjectDefinition | None = None, + filters: WithStrawberryObjectDefinition | None = None, + **kwargs: Any, + ) -> Any: + assert self.paginated_type is not None + queryset = cast(models.QuerySet, next_(source, info, **kwargs)) + + def get_queryset(queryset): + return cast(StrawberryDjangoField, info._field).get_queryset( + queryset, + info, + pagination=pagination, + order=order, + filters=filters, + ) + + # We have a single resolver for both sync and async, so we need to check if + # nodes is awaitable or not and resolve it accordingly + if inspect.isawaitable(queryset): + + async def async_resolver(queryset=queryset): + resolved = self.paginated_type.resolve_paginated( + get_queryset(await queryset), + info=info, + pagination=pagination, + **kwargs, + ) + if inspect.isawaitable(resolved): + resolved = await resolved + + return resolved + + return async_resolver() + + return self.paginated_type.resolve_paginated( + get_queryset(queryset), + info=info, + pagination=pagination, + **kwargs, + ) + + @overload def field( *, @@ -805,3 +901,184 @@ def connection( f = f(resolver) return f + + +_OFFSET_PAGINATED_RESOLVER_TYPE: TypeAlias = _RESOLVER_TYPE[ + Union[ + Iterator[models.Model], + Iterable[models.Model], + AsyncIterator[models.Model], + AsyncIterable[models.Model], + ] +] + + +@overload +def offset_paginated( + graphql_type: type[OffsetPaginated] | None = None, + *, + field_cls: type[StrawberryDjangoField] = StrawberryDjangoField, + name: str | None = None, + field_name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, + default: Any = dataclasses.MISSING, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: Sequence[FieldExtension] = (), + filters: type | None = UNSET, + order: type | None = UNSET, + only: TypeOrSequence[str] | None = None, + select_related: TypeOrSequence[str] | None = None, + prefetch_related: TypeOrSequence[PrefetchType] | None = None, + annotate: TypeOrMapping[AnnotateType] | None = None, + disable_optimization: bool = False, +) -> Any: ... + + +@overload +def offset_paginated( + graphql_type: type[OffsetPaginated] | None = None, + *, + field_cls: type[StrawberryDjangoField] = StrawberryDjangoField, + resolver: _OFFSET_PAGINATED_RESOLVER_TYPE | None = None, + name: str | None = None, + field_name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + init: Literal[True] = True, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, + default: Any = dataclasses.MISSING, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: Sequence[FieldExtension] = (), + filters: type | None = UNSET, + order: type | None = UNSET, + only: TypeOrSequence[str] | None = None, + select_related: TypeOrSequence[str] | None = None, + prefetch_related: TypeOrSequence[PrefetchType] | None = None, + annotate: TypeOrMapping[AnnotateType] | None = None, + disable_optimization: bool = False, +) -> Any: ... + + +def offset_paginated( + graphql_type: type[OffsetPaginated] | None = None, + *, + field_cls: type[StrawberryDjangoField] = StrawberryDjangoField, + resolver: _OFFSET_PAGINATED_RESOLVER_TYPE | None = None, + name: str | None = None, + field_name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, + default: Any = dataclasses.MISSING, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: Sequence[FieldExtension] = (), + filters: type | None = UNSET, + order: type | None = UNSET, + only: TypeOrSequence[str] | None = None, + select_related: TypeOrSequence[str] | None = None, + prefetch_related: TypeOrSequence[PrefetchType] | None = None, + annotate: TypeOrMapping[AnnotateType] | None = None, + disable_optimization: bool = False, + # This init parameter is used by pyright to determine whether this field + # is added in the constructor or not. It is not used to change + # any behavior at the moment. + init: Literal[True, False, None] = None, +) -> Any: + """Annotate a property or a method to create a relay connection field. + + Relay connections_ are mostly used for pagination purposes. This decorator + helps creating a complete relay endpoint that provides default arguments + and has a default implementation for the connection slicing. + + Note that when setting a resolver to this field, it is expected for this + resolver to return an iterable of the expected node type, not the connection + itself. That iterable will then be paginated accordingly. So, the main use + case for this is to provide a filtered iterable of nodes by using some custom + filter arguments. + + Examples + -------- + Annotating something like this: + + >>> @strawberry.type + >>> class X: + ... some_node: relay.Connection[SomeType] = relay.connection( + ... description="ABC", + ... ) + ... + ... @relay.connection(description="ABC") + ... def get_some_nodes(self, age: int) -> Iterable[SomeType]: + ... ... + + Will produce a query like this: + + ``` + query { + someNode ( + before: String + after: String + first: String + after: String + age: Int + ) { + totalCount + pageInfo { + hasNextPage + hasPreviousPage + startCursor + endCursor + } + edges { + cursor + node { + id + ... + } + } + } + } + ``` + + .. _Relay connections: + https://relay.dev/graphql/connections.htm + + """ + extensions = [*extensions, StrawberryOffsetPaginatedExtension()] + f = field_cls( + python_name=None, + django_name=field_name, + graphql_name=name, + type_annotation=StrawberryAnnotation.from_annotation(graphql_type), + description=description, + is_subscription=is_subscription, + permission_classes=permission_classes or [], + deprecation_reason=deprecation_reason, + default=default, + default_factory=default_factory, + metadata=metadata, + directives=directives or (), + filters=filters, + order=order, + extensions=extensions, + only=only, + select_related=select_related, + prefetch_related=prefetch_related, + annotate=annotate, + disable_optimization=disable_optimization, + ) + + if resolver: + f = f(resolver) + + return f diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index 0e9b72f2..32a00fb4 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -17,9 +17,10 @@ from .arguments import argument NodeType = TypeVar("NodeType") -_T = TypeVar("_T") _QS = TypeVar("_QS", bound=QuerySet) +PAGINATION_ARG = "pagination" + @strawberry.input class OffsetPaginationInput: @@ -338,34 +339,3 @@ def get_queryset( pagination, related_field_id=_strawberry_related_field_id, ) - - def get_wrapped_result( - self, - result: _T, - info: Info, - *, - pagination: Optional[OffsetPaginationInput] = None, - **kwargs, - ) -> Union[_T, OffsetPaginated[_T]]: - if not self.is_paginated: - return result - - if not isinstance(result, QuerySet): - raise TypeError(f"Result expected to be a queryset, got {result!r}") - - if ( - pagination not in (None, UNSET) # noqa: PLR6201 - and not isinstance(pagination, OffsetPaginationInput) - ): - raise TypeError(f"Don't know how to resolve pagination {pagination!r}") - - paginated_type = self.type - assert isinstance(paginated_type, type) - assert issubclass(paginated_type, OffsetPaginated) - - return paginated_type.resolve_paginated( - result, - info=info, - pagination=pagination, - **kwargs, - ) diff --git a/strawberry_django/permissions.py b/strawberry_django/permissions.py index 7b56b2d1..2b1a60f8 100644 --- a/strawberry_django/permissions.py +++ b/strawberry_django/permissions.py @@ -39,7 +39,7 @@ from strawberry_django.auth.utils import aget_current_user, get_current_user from strawberry_django.fields.types import OperationInfo, OperationMessage -from strawberry_django.pagination import OffsetPaginated, OffsetPaginationInput +from strawberry_django.pagination import OffsetPaginated from strawberry_django.resolvers import django_resolver from .utils.query import filter_for_user @@ -48,6 +48,8 @@ if TYPE_CHECKING: from strawberry.django.context import StrawberryDjangoContext + from strawberry_django.fields.field import StrawberryDjangoField + _M = TypeVar("_M", bound=Model) @@ -407,7 +409,9 @@ def handle_no_permission(self, exception: BaseException, *, info: Info): return [] if isinstance(ret_type, type) and issubclass(ret_type, OffsetPaginated): - return OffsetPaginated(queryset=None, pagination=OffsetPaginationInput()) + django_model = cast("StrawberryDjangoField", info._field).django_model + assert django_model + return django_model._default_manager.none() # If it is a Connection, try to return an empty connection, but only if # it is the only possibility available... diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 7502f482..d87b3f35 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -159,7 +159,7 @@ class MilestoneType(relay.Node, Named): order=IssueOrder, pagination=True, ) - issues_paginated: OffsetPaginated["IssueType"] = strawberry_django.field( + issues_paginated: OffsetPaginated["IssueType"] = strawberry_django.offset_paginated( field_name="issues", order=IssueOrder, ) @@ -380,7 +380,7 @@ class Query: staff_list: list[Optional[StaffType]] = strawberry_django.node() issue_list: list[IssueType] = strawberry_django.field() - issues_paginated: OffsetPaginated[IssueType] = strawberry_django.field() + issues_paginated: OffsetPaginated[IssueType] = strawberry_django.offset_paginated() milestone_list: list[MilestoneType] = strawberry_django.field( order=MilestoneOrder, filters=MilestoneFilter, @@ -436,7 +436,7 @@ class Query: extensions=[HasPerm(perms=["projects.view_issue"])], ) issues_paginated_perm_required: OffsetPaginated[IssueType] = ( - strawberry_django.field( + strawberry_django.offset_paginated( extensions=[HasPerm(perms=["projects.view_issue"])], ) ) @@ -459,7 +459,7 @@ class Query: extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True ) issues_paginated_obj_perm_required: OffsetPaginated[IssueType] = ( - strawberry_django.field( + strawberry_django.offset_paginated( extensions=[HasRetvalPerm(perms=["projects.view_issue"])], ) ) diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index 68c7936f..b5f31d33 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -383,7 +383,7 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! - issuesPaginated(pagination: OffsetPaginationInput): IssueTypeOffsetPaginated! + issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index f104ce92..352c8e1b 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -173,7 +173,7 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! - issuesPaginated(pagination: OffsetPaginationInput): IssueTypeOffsetPaginated! + issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter @@ -201,7 +201,7 @@ type MilestoneTypeSubclass implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! - issuesPaginated(pagination: OffsetPaginationInput): IssueTypeOffsetPaginated! + issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter diff --git a/tests/test_paginated_type.py b/tests/test_paginated_type.py index 16813bb0..74537727 100644 --- a/tests/test_paginated_type.py +++ b/tests/test_paginated_type.py @@ -2,6 +2,7 @@ import pytest import strawberry +from django.db.models import QuerySet import strawberry_django from strawberry_django.pagination import OffsetPaginated, OffsetPaginationInput @@ -22,8 +23,8 @@ class Color: @strawberry.type class Query: - fruits: OffsetPaginated[Fruit] = strawberry_django.field() - colors: OffsetPaginated[Color] = strawberry_django.field() + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() + colors: OffsetPaginated[Color] = strawberry_django.offset_paginated() schema = strawberry.Schema(query=Query) @@ -87,7 +88,7 @@ class Fruit: @strawberry.type class Query: - fruits: OffsetPaginated[Fruit] = strawberry_django.field() + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() models.Fruit.objects.create(name="Apple") models.Fruit.objects.create(name="Banana") @@ -145,7 +146,7 @@ class Fruit: @strawberry.type class Query: - fruits: OffsetPaginated[Fruit] = strawberry_django.field() + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() await models.Fruit.objects.acreate(name="Apple") await models.Fruit.objects.acreate(name="Banana") @@ -205,11 +206,11 @@ class Fruit: class Color: id: int name: str - fruits: OffsetPaginated[Fruit] + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() @strawberry.type class Query: - colors: OffsetPaginated[Color] = strawberry_django.field() + colors: OffsetPaginated[Color] = strawberry_django.offset_paginated() red = models.Color.objects.create(name="Red") yellow = models.Color.objects.create(name="Yellow") @@ -316,11 +317,11 @@ class Fruit: class Color: id: int name: str - fruits: OffsetPaginated[Fruit] + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() @strawberry.type class Query: - colors: OffsetPaginated[Color] = strawberry_django.field() + colors: OffsetPaginated[Color] = strawberry_django.offset_paginated() red = await models.Color.objects.acreate(name="Red") yellow = await models.Color.objects.acreate(name="Yellow") @@ -441,7 +442,7 @@ def resolve_paginated(cls, queryset, *, info, pagination=None, **kwargs): @strawberry.type class Query: - fruits: FruitPaginated = strawberry_django.field() + fruits: FruitPaginated = strawberry_django.offset_paginated() models.Fruit.objects.create(name="Apple") models.Fruit.objects.create(name="Banana") @@ -492,3 +493,339 @@ class Query: "results": [{"name": "Strawberry"}], } } + + +@pytest.mark.django_db(transaction=True) +def test_pagination_query_with_resolver_schema(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry_django.filter(models.Fruit) + class FruitFilter: + name: str + + @strawberry_django.order(models.Fruit) + class FruitOrder: + name: str + + @strawberry.type + class Query: + @strawberry_django.offset_paginated(OffsetPaginated[Fruit]) + def fruits(self) -> QuerySet[models.Fruit]: ... + + @strawberry_django.offset_paginated( + OffsetPaginated[Fruit], + filters=FruitFilter, + order=FruitOrder, + ) + def fruits_with_order_and_filter(self) -> QuerySet[models.Fruit]: ... + + schema = strawberry.Schema(query=Query) + + expected = ''' + type Fruit { + id: Int! + name: String! + } + + input FruitFilter { + name: String! + AND: FruitFilter + OR: FruitFilter + NOT: FruitFilter + DISTINCT: Boolean + } + + type FruitOffsetPaginated { + pageInfo: OffsetPaginationInfo! + + """Total count of existing results.""" + totalCount: Int! + + """List of paginated results.""" + results: [Fruit!]! + } + + input FruitOrder { + name: String + } + + type OffsetPaginationInfo { + offset: Int! + limit: Int + } + + input OffsetPaginationInput { + offset: Int! = 0 + limit: Int = null + } + + type Query { + fruits(pagination: OffsetPaginationInput): FruitOffsetPaginated! + fruitsWithOrderAndFilter(filters: FruitFilter, order: FruitOrder, pagination: OffsetPaginationInput): FruitOffsetPaginated! + } + ''' + + assert textwrap.dedent(str(schema)) == textwrap.dedent(expected).strip() + + +@pytest.mark.django_db(transaction=True) +def test_pagination_query_with_resolver(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry_django.filter(models.Fruit) + class FruitFilter: + name: strawberry.auto + + @strawberry_django.order(models.Fruit) + class FruitOrder: + name: strawberry.auto + + @strawberry.type + class Query: + @strawberry_django.offset_paginated(OffsetPaginated[Fruit]) + def fruits(self) -> QuerySet[models.Fruit]: + return models.Fruit.objects.filter(name__startswith="S") + + @strawberry_django.offset_paginated( + OffsetPaginated[Fruit], + filters=FruitFilter, + order=FruitOrder, + ) + def fruits_with_order_and_filter(self) -> QuerySet[models.Fruit]: + return models.Fruit.objects.filter(name__startswith="S") + + models.Fruit.objects.create(name="Apple") + models.Fruit.objects.create(name="Strawberry") + models.Fruit.objects.create(name="Banana") + models.Fruit.objects.create(name="Sugar Apple") + models.Fruit.objects.create(name="Starfruit") + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetFruits ( + $pagination: OffsetPaginationInput + $filters: FruitFilter + $order: FruitOrder + ) { + fruits (pagination: $pagination) { + totalCount + results { + name + } + } + fruitsWithOrderAndFilter ( + pagination: $pagination + filters: $filters + order: $order + ) { + totalCount + results { + name + } + } + } + """ + + res = schema.execute_sync(query) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + {"name": "Starfruit"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + {"name": "Starfruit"}, + ], + }, + } + + res = schema.execute_sync(query, variable_values={"pagination": {"limit": 1}}) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + ], + }, + } + + res = schema.execute_sync( + query, + variable_values={ + "pagination": {"limit": 2}, + "order": {"name": "ASC"}, + "filters": {"name": "Strawberry"}, + }, + ) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 1, + "results": [ + {"name": "Strawberry"}, + ], + }, + } + + +@pytest.mark.django_db(transaction=True) +def test_pagination_query_with_resolver_arguments(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry_django.filter(models.Fruit) + class FruitFilter: + name: strawberry.auto + + @strawberry_django.order(models.Fruit) + class FruitOrder: + name: strawberry.auto + + @strawberry.type + class Query: + @strawberry_django.offset_paginated(OffsetPaginated[Fruit]) + def fruits(self, starts_with: str) -> QuerySet[models.Fruit]: + return models.Fruit.objects.filter(name__startswith=starts_with) + + @strawberry_django.offset_paginated( + OffsetPaginated[Fruit], + filters=FruitFilter, + order=FruitOrder, + ) + def fruits_with_order_and_filter( + self, starts_with: str + ) -> QuerySet[models.Fruit]: + return models.Fruit.objects.filter(name__startswith=starts_with) + + models.Fruit.objects.create(name="Apple") + models.Fruit.objects.create(name="Strawberry") + models.Fruit.objects.create(name="Banana") + models.Fruit.objects.create(name="Sugar Apple") + models.Fruit.objects.create(name="Starfruit") + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetFruits ( + $pagination: OffsetPaginationInput + $filters: FruitFilter + $order: FruitOrder + $startsWith: String! + ) { + fruits (startsWith: $startsWith, pagination: $pagination) { + totalCount + results { + name + } + } + fruitsWithOrderAndFilter ( + startsWith: $startsWith + pagination: $pagination + filters: $filters + order: $order + ) { + totalCount + results { + name + } + } + } + """ + + res = schema.execute_sync(query, variable_values={"startsWith": "S"}) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + {"name": "Starfruit"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + {"name": "Starfruit"}, + ], + }, + } + + res = schema.execute_sync( + query, + variable_values={"startsWith": "S", "pagination": {"limit": 1}}, + ) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + ], + }, + } + + res = schema.execute_sync( + query, + variable_values={ + "startsWith": "S", + "pagination": {"limit": 2}, + "order": {"name": "ASC"}, + "filters": {"name": "Strawberry"}, + }, + ) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 1, + "results": [ + {"name": "Strawberry"}, + ], + }, + }