From 6d8fec378fa0409f6a9e98d631a4ec2ff3bcae2e Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sat, 9 Nov 2024 23:17:49 +0100 Subject: [PATCH] feat: New Paginated generic to be used as a wrapped for paginated results (#642) --- docs/guide/pagination.md | 296 ++++++- docs/guide/settings.md | 7 +- strawberry_django/__init__.py | 3 +- strawberry_django/fields/base.py | 16 +- strawberry_django/fields/field.py | 340 ++++++- strawberry_django/filters.py | 1 + strawberry_django/optimizer.py | 79 +- strawberry_django/pagination.py | 213 ++++- strawberry_django/permissions.py | 8 + strawberry_django/relay.py | 36 +- strawberry_django/resolvers.py | 4 +- strawberry_django/settings.py | 7 +- strawberry_django/type.py | 8 +- tests/projects/schema.py | 16 + tests/projects/snapshots/schema.gql | 21 +- .../snapshots/schema_with_inheritance.gql | 19 +- tests/test_optimizer.py | 180 ++++ tests/test_paginated_type.py | 831 ++++++++++++++++++ tests/test_permissions.py | 167 +++- tests/test_settings.py | 2 + 20 files changed, 2142 insertions(+), 112 deletions(-) create mode 100644 tests/test_paginated_type.py diff --git a/docs/guide/pagination.md b/docs/guide/pagination.md index 339c199a..89296d84 100644 --- a/docs/guide/pagination.md +++ b/docs/guide/pagination.md @@ -12,20 +12,306 @@ An interface for limit/offset pagination can be use for basic pagination needs: @strawberry_django.type(models.Fruit, pagination=True) class Fruit: name: auto + + +@strawberry.type +class Query: + fruits: list[Fruit] = strawberry_django.field() +``` + +Would produce the following schema: + +```graphql title="schema.graphql" +type Fruit { + name: String! +} + +input OffsetPaginationInput { + offset: Int! = 0 + limit: Int = null +} + +type Query { + fruits(pagination: OffsetPaginationInput): [Fruit!]! +} ``` +And can be queried like: + ```graphql title="schema.graphql" query { fruits(pagination: { offset: 0, limit: 2 }) { name - color } } ``` -There is not default limit defined. All elements are returned if no pagination limit is defined. +The `pagination` argument can be given to the type, which will enforce the pagination +argument every time the field is annotated as a list, but you can also give it directly +to the field for more control, like: + +```python title="types.py" +@strawberry_django.type(models.Fruit) +class Fruit: + name: auto + + +@strawberry.type +class Query: + fruits: list[Fruit] = strawberry_django.field(pagination=True) +``` + +Which will produce the exact same schema. + +### Default limit for pagination + +The default limit for pagination is set to `100`. This can be changed in the +[strawberry django settings](./settings.md) to increase or decrease that number, +or even set to `None` to set it to unlimited. + +To configure it on a per field basis, you can define your own `OffsetPaginationInput` +subclass and modify its default value, like: + +```python +@strawberry.input +def MyOffsetPaginationInput(OffsetPaginationInput): + limit: int = 250 + + +# Pass it to the pagination argument when defining the type +@strawberry_django.type(models.Fruit, pagination=MyOffsetPaginationInput) +class Fruit: + ... + + +@strawberry.type +class Query: + # Or pass it to the pagination argument when defining the field + fruits: list[Fruit] = strawberry_django.field(pagination=MyOffsetPaginationInput) +``` + +## OffsetPaginated Generic + +For more complex pagination needs, you can use the `OffsetPaginated` generic, which alongside +the `pagination` argument, will wrap the results in an object that contains the results +and the pagination information, together with the `totalCount` of elements excluding pagination. + +```python title="types.py" +from strawberry_django.pagination import OffsetPaginated + + +@strawberry_django.type(models.Fruit) +class Fruit: + name: auto + + +@strawberry.type +class Query: + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() +``` + +Would produce the following schema: + +```graphql title="schema.graphql" +type Fruit { + name: String! +} + +type PaginationInfo { + limit: Int = null + offset: Int! +} + +type FruitOffsetPaginated { + pageInfo: PaginationInfo! + totalCount: Int! + results: [Fruit]! +} + +input OffsetPaginationInput { + offset: Int! = 0 + limit: Int = null +} + +type Query { + fruits(pagination: OffsetPaginationInput): [FruitOffsetPaginated!]! +} +``` + +Which can be queried like: + +```graphql title="schema.graphql" +query { + fruits(pagination: { offset: 0, limit: 2 }) { + totalCount + pageInfo { + limit + offset + } + results { + name + } + } +} +``` + +> [!NOTE] +> OffsetPaginated follow the same rules for the default pagination limit, and can be configured +> in the same way as explained above. + +### Customizing queryset resolver + +It is possible to define a custom resolver for the queryset to either provide a custom +queryset for it, or even to receive extra arguments alongside the pagination arguments. + +Suppose we want to pre-filter a queryset of fruits for only available ones, +while also adding [ordering](./ordering.md) to it. This can be achieved with: + +```python title="types.py" + +@strawberry_django.type(models.Fruit) +class Fruit: + name: auto + price: auto + + +@strawberry_django.order(models.Fruit) +class FruitOrder: + name: auto + price: auto + + +@strawberry.type +class Query: + @strawberry_django.offset_paginated(OffsetPaginated[Fruit], order=order) + def fruits(self, only_available: bool = True) -> QuerySet[Fruit]: + queryset = models.Fruit.objects.all() + if only_available: + queryset = queryset.filter(available=True) + + return queryset +``` + +This would produce the following schema: + +```graphql title="schema.graphql" +type Fruit { + name: String! + price: Decimal! +} + +type FruitOrder { + name: Ordering + price: Ordering +} + +type PaginationInfo { + limit: Int! + offset: Int! +} + +type FruitOffsetPaginated { + pageInfo: PaginationInfo! + totalCount: Int! + results: [Fruit]! +} + +input OffsetPaginationInput { + offset: Int! = 0 + limit: Int = null +} + +type Query { + fruits( + onlyAvailable: Boolean! = true + pagination: OffsetPaginationInput + order: FruitOrder + ): [FruitOffsetPaginated!]! +} +``` + +### Customizing the pagination + +Like other generics, `OffsetPaginated` can be customized to modify its behavior or to +add extra functionality in it. For example, suppose we want to add the average +price of the fruits in the pagination: + +```python title="types.py" +from strawberry_django.pagination import OffsetPaginated + + +@strawberry_django.type(models.Fruit) +class Fruit: + name: auto + price: auto + + +@strawberry.type +class FruitOffsetPaginated(OffsetPaginated[Fruit]): + @strawberry_django.field + def average_price(self) -> Decimal: + if self.queryset is None: + return Decimal(0) + + return self.queryset.aggregate(Avg("price"))["price__avg"] + + @strawberry_django.field + def paginated_average_price(self) -> Decimal: + paginated_queryset = self.get_paginated_queryset() + if paginated_queryset is None: + return Decimal(0) + + return paginated_queryset.aggregate(Avg("price"))["price__avg"] + + +@strawberry.type +class Query: + fruits: FruitOffsetPaginated = strawberry_django.offset_paginated() +``` + +Would produce the following schema: + +```graphql title="schema.graphql" +type Fruit { + name: String! +} + +type PaginationInfo { + limit: Int = null + offset: Int! +} + +type FruitOffsetPaginated { + pageInfo: PaginationInfo! + totalCount: Int! + results: [Fruit]! + averagePrice: Decimal! + paginatedAveragePrice: Decimal! +} + +input OffsetPaginationInput { + offset: Int! = 0 + limit: Int = null +} + +type Query { + fruits(pagination: OffsetPaginationInput): [FruitOffsetPaginated!]! +} +``` + +The following attributes/methods can be accessed in the `OffsetPaginated` class: + +- `queryset`: The queryset original queryset with any filters/ordering applied, + but not paginated yet +- `pagination`: The `OffsetPaginationInput` object, with the `offset` and `limit` for pagination +- `get_total_count()`: Returns the total count of elements in the queryset without pagination +- `get_paginated_queryset()`: Returns the queryset with pagination applied +- `resolve_paginated(queryset, *, info, pagiantion, **kwargs)`: The classmethod that + strawberry-django calls to create an instance of the `OffsetPaginated` class/subclass. -## Relay pagination +## Cursor pagination (aka Relay style pagination) -For more complex scenarios, a cursor pagination would be better. For this, -use the [relay integration](./relay.md) to define those. +Another option for pagination is to use a +[relay style cursor pagination](https://graphql.org/learn/pagination). For this, +you can leverage the [relay integration](./relay.md) provided by strawberry +to create a relay connection. diff --git a/docs/guide/settings.md b/docs/guide/settings.md index d8e41973..afaf8a75 100644 --- a/docs/guide/settings.md +++ b/docs/guide/settings.md @@ -62,7 +62,11 @@ A dictionary with the following optional keys: If True, [legacy filters](filters.md#legacy-filtering) are enabled. This is usefull for migrating from previous version. -These features can be enabled by adding this code to your `settings.py` file. +- **`PAGINATION_DEFAULT_LIMIT`** (default: `100`) + + Defualt limit for [pagination](pagination.md) when one is not provided by the client. Can be set to `None` to set it to unlimited. + +These features can be enabled by adding this code to your `settings.py` file, like: ```python title="settings.py" STRAWBERRY_DJANGO = { @@ -73,5 +77,6 @@ STRAWBERRY_DJANGO = { "GENERATE_ENUMS_FROM_CHOICES": False, "MAP_AUTO_ID_AS_GLOBAL_ID": True, "DEFAULT_PK_FIELD_NAME": "id", + "PAGINATION_DEFAULT_LIMIT": 250, } ``` diff --git a/strawberry_django/__init__.py b/strawberry_django/__init__.py index 30e2ac69..74e0a557 100644 --- a/strawberry_django/__init__.py +++ b/strawberry_django/__init__.py @@ -1,5 +1,5 @@ from . import auth, filters, mutations, ordering, pagination, relay -from .fields.field import connection, field, node +from .fields.field import connection, field, node, offset_paginated from .fields.filter_order import filter_field, order_field from .fields.filter_types import ( BaseFilterLookup, @@ -60,6 +60,7 @@ "mutation", "mutations", "node", + "offset_paginated", "order", "order_field", "ordering", diff --git a/strawberry_django/fields/base.py b/strawberry_django/fields/base.py index 07b11891..c2e5e2e4 100644 --- a/strawberry_django/fields/base.py +++ b/strawberry_django/fields/base.py @@ -85,6 +85,8 @@ def is_async(self) -> bool: @functools.cached_property def django_type(self) -> type[WithStrawberryDjangoObjectDefinition] | None: + from strawberry_django.pagination import OffsetPaginated + origin = self.type if isinstance(origin, LazyType): @@ -92,7 +94,9 @@ def django_type(self) -> type[WithStrawberryDjangoObjectDefinition] | None: object_definition = get_object_definition(origin) - if object_definition and issubclass(object_definition.origin, relay.Connection): + if object_definition and issubclass( + object_definition.origin, (relay.Connection, OffsetPaginated) + ): origin_specialized_type_var_map = ( get_specialized_type_var_map(cast(type, origin)) or {} ) @@ -148,6 +152,16 @@ def is_list(self) -> bool: return isinstance(type_, StrawberryList) + @functools.cached_property + def is_paginated(self) -> bool: + from strawberry_django.pagination import OffsetPaginated + + type_ = self.type + if isinstance(type_, StrawberryOptional): + type_ = type_.of_type + + return isinstance(type_, type) and issubclass(type_, OffsetPaginated) + @functools.cached_property def is_connection(self) -> bool: type_ = self.type diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index e1d921f5..7806e77a 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -2,13 +2,21 @@ import dataclasses import inspect -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import ( + AsyncIterable, + AsyncIterator, + Iterable, + Iterator, + Mapping, + Sequence, +) from functools import cached_property from typing import ( TYPE_CHECKING, Any, Callable, TypeVar, + Union, cast, overload, ) @@ -26,9 +34,12 @@ from django.db.models.query_utils import DeferredAttribute from strawberry import UNSET, relay from strawberry.annotation import StrawberryAnnotation +from strawberry.extensions.field_extension import FieldExtension +from strawberry.types.field import _RESOLVER_TYPE # noqa: PLC2701 from strawberry.types.fields.resolver import StrawberryResolver from strawberry.types.info import Info # noqa: TCH002 from strawberry.utils.await_maybe import await_maybe +from typing_extensions import TypeAlias from strawberry_django import optimizer from strawberry_django.arguments import argument @@ -37,7 +48,12 @@ from strawberry_django.filters import FILTERS_ARG, StrawberryDjangoFieldFilters from strawberry_django.optimizer import OptimizerStore, is_optimized_by_prefetching from strawberry_django.ordering import ORDER_ARG, StrawberryDjangoFieldOrdering -from strawberry_django.pagination import StrawberryDjangoPagination +from strawberry_django.pagination import ( + PAGINATION_ARG, + OffsetPaginated, + OffsetPaginationInput, + StrawberryDjangoPagination, +) from strawberry_django.permissions import filter_with_perms from strawberry_django.queryset import run_type_get_queryset from strawberry_django.relay import resolve_model_nodes @@ -51,13 +67,11 @@ if TYPE_CHECKING: from graphql.pyutils import AwaitableOrValue from strawberry import BasePermission - from strawberry.extensions.field_extension import ( - FieldExtension, - SyncExtensionResolver, - ) + from strawberry.extensions.field_extension import SyncExtensionResolver from strawberry.relay.types import NodeIterableType from strawberry.types.arguments import StrawberryArgument - from strawberry.types.field import _RESOLVER_TYPE, StrawberryField + from strawberry.types.base import WithStrawberryObjectDefinition + from strawberry.types.field import StrawberryField from strawberry.types.unset import UnsetType from typing_extensions import Literal, Self @@ -229,7 +243,7 @@ async def async_resolver(): kwargs["info"] = info resolved = await sync_to_async(self.get_queryset_hook(**kwargs))( - resolved, + resolved ) return resolved @@ -244,15 +258,15 @@ async def async_resolver(): kwargs["info"] = info result = django_resolver( - lambda obj: obj, - qs_hook=self.get_queryset_hook(**kwargs), + self.get_queryset_hook(**kwargs), + qs_hook=lambda qs: qs, )(result) return result def get_queryset_hook(self, info: Info, **kwargs): - if self.is_connection: - # We don't want to fetch results yet, those will be done by the connection + if self.is_connection or self.is_paginated: + # We don't want to fetch results yet, those will be done by the connection/pagination def qs_hook(qs: models.QuerySet): # type: ignore return self.get_queryset(qs, info, **kwargs) @@ -297,6 +311,44 @@ def get_queryset(self, queryset, info, **kwargs): return queryset +def _get_field_arguments_for_extensions( + field: StrawberryDjangoField, + *, + add_filters: bool = True, + add_order: bool = True, + add_pagination: bool = True, +) -> list[StrawberryArgument]: + """Get a list of arguments to be set to fields using extensions. + + Because we have a base_resolver defined in those, our parents will not add + order/filters/pagination resolvers in here, so we need to add them by hand (unless they + are somewhat in there). We are not adding pagination because it doesn't make + sense together with a Connection + """ + args: dict[str, StrawberryArgument] = {a.python_name: a for a in field.arguments} + + if add_filters and FILTERS_ARG not in args: + filters = field.get_filters() + if filters not in (None, UNSET): # noqa: PLR6201 + args[FILTERS_ARG] = argument(FILTERS_ARG, filters, is_optional=True) + + if add_order and ORDER_ARG not in args: + order = field.get_order() + if order not in (None, UNSET): # noqa: PLR6201 + args[ORDER_ARG] = argument(ORDER_ARG, order, is_optional=True) + + if add_pagination and PAGINATION_ARG not in args: + pagination = field.get_pagination() + if pagination not in (None, UNSET): # noqa: PLR6201 + args[PAGINATION_ARG] = argument( + PAGINATION_ARG, + pagination, + is_optional=True, + ) + + return list(args.values()) + + class StrawberryDjangoConnectionExtension(relay.ConnectionExtension): def apply(self, field: StrawberryField) -> None: if not isinstance(field, StrawberryDjangoField): @@ -304,24 +356,10 @@ def apply(self, field: StrawberryField) -> None: "The extension can only be applied to StrawberryDjangoField" ) - # NOTE: Because we have a base_resolver defined, our parents will not add - # order/filters resolvers in here, so we need to add them by hand (unless they - # are somewhat in there). We are not adding pagination because it doesn't make - # sense together with a Connection - args: dict[str, StrawberryArgument] = { - a.python_name: a for a in field.arguments - } - - if FILTERS_ARG not in args: - filters = field.get_filters() - if filters not in (None, UNSET): # noqa: PLR6201 - args[FILTERS_ARG] = argument(FILTERS_ARG, filters, is_optional=True) - if ORDER_ARG not in args: - order = field.get_order() - if order not in (None, UNSET): # noqa: PLR6201 - args[ORDER_ARG] = argument(ORDER_ARG, order, is_optional=True) - - field.arguments = list(args.values()) + field.arguments = _get_field_arguments_for_extensions( + field, + add_pagination=False, + ) if field.base_resolver is None: @@ -409,6 +447,67 @@ async def async_resolver(): ) +class StrawberryOffsetPaginatedExtension(FieldExtension): + paginated_type: type[OffsetPaginated] + + def apply(self, field: StrawberryField) -> None: + if not isinstance(field, StrawberryDjangoField): + raise TypeError( + "The extension can only be applied to StrawberryDjangoField" + ) + + field.arguments = _get_field_arguments_for_extensions(field) + self.paginated_type = cast(type[OffsetPaginated], field.type) + + def resolve( + self, + next_: SyncExtensionResolver, + source: Any, + info: Info, + *, + pagination: OffsetPaginationInput | None = None, + order: WithStrawberryObjectDefinition | None = None, + filters: WithStrawberryObjectDefinition | None = None, + **kwargs: Any, + ) -> Any: + assert self.paginated_type is not None + queryset = cast(models.QuerySet, next_(source, info, **kwargs)) + + def get_queryset(queryset): + return cast(StrawberryDjangoField, info._field).get_queryset( + queryset, + info, + pagination=pagination, + order=order, + filters=filters, + ) + + # We have a single resolver for both sync and async, so we need to check if + # nodes is awaitable or not and resolve it accordingly + if inspect.isawaitable(queryset): + + async def async_resolver(queryset=queryset): + resolved = self.paginated_type.resolve_paginated( + get_queryset(await queryset), + info=info, + pagination=pagination, + **kwargs, + ) + if inspect.isawaitable(resolved): + resolved = await resolved + + return resolved + + return async_resolver() + + return self.paginated_type.resolve_paginated( + get_queryset(queryset), + info=info, + pagination=pagination, + **kwargs, + ) + + @overload def field( *, @@ -802,3 +901,184 @@ def connection( f = f(resolver) return f + + +_OFFSET_PAGINATED_RESOLVER_TYPE: TypeAlias = _RESOLVER_TYPE[ + Union[ + Iterator[models.Model], + Iterable[models.Model], + AsyncIterator[models.Model], + AsyncIterable[models.Model], + ] +] + + +@overload +def offset_paginated( + graphql_type: type[OffsetPaginated] | None = None, + *, + field_cls: type[StrawberryDjangoField] = StrawberryDjangoField, + name: str | None = None, + field_name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, + default: Any = dataclasses.MISSING, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: Sequence[FieldExtension] = (), + filters: type | None = UNSET, + order: type | None = UNSET, + only: TypeOrSequence[str] | None = None, + select_related: TypeOrSequence[str] | None = None, + prefetch_related: TypeOrSequence[PrefetchType] | None = None, + annotate: TypeOrMapping[AnnotateType] | None = None, + disable_optimization: bool = False, +) -> Any: ... + + +@overload +def offset_paginated( + graphql_type: type[OffsetPaginated] | None = None, + *, + field_cls: type[StrawberryDjangoField] = StrawberryDjangoField, + resolver: _OFFSET_PAGINATED_RESOLVER_TYPE | None = None, + name: str | None = None, + field_name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + init: Literal[True] = True, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, + default: Any = dataclasses.MISSING, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: Sequence[FieldExtension] = (), + filters: type | None = UNSET, + order: type | None = UNSET, + only: TypeOrSequence[str] | None = None, + select_related: TypeOrSequence[str] | None = None, + prefetch_related: TypeOrSequence[PrefetchType] | None = None, + annotate: TypeOrMapping[AnnotateType] | None = None, + disable_optimization: bool = False, +) -> Any: ... + + +def offset_paginated( + graphql_type: type[OffsetPaginated] | None = None, + *, + field_cls: type[StrawberryDjangoField] = StrawberryDjangoField, + resolver: _OFFSET_PAGINATED_RESOLVER_TYPE | None = None, + name: str | None = None, + field_name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, + default: Any = dataclasses.MISSING, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: Sequence[FieldExtension] = (), + filters: type | None = UNSET, + order: type | None = UNSET, + only: TypeOrSequence[str] | None = None, + select_related: TypeOrSequence[str] | None = None, + prefetch_related: TypeOrSequence[PrefetchType] | None = None, + annotate: TypeOrMapping[AnnotateType] | None = None, + disable_optimization: bool = False, + # This init parameter is used by pyright to determine whether this field + # is added in the constructor or not. It is not used to change + # any behavior at the moment. + init: Literal[True, False, None] = None, +) -> Any: + """Annotate a property or a method to create a relay connection field. + + Relay connections_ are mostly used for pagination purposes. This decorator + helps creating a complete relay endpoint that provides default arguments + and has a default implementation for the connection slicing. + + Note that when setting a resolver to this field, it is expected for this + resolver to return an iterable of the expected node type, not the connection + itself. That iterable will then be paginated accordingly. So, the main use + case for this is to provide a filtered iterable of nodes by using some custom + filter arguments. + + Examples + -------- + Annotating something like this: + + >>> @strawberry.type + >>> class X: + ... some_node: relay.Connection[SomeType] = relay.connection( + ... description="ABC", + ... ) + ... + ... @relay.connection(description="ABC") + ... def get_some_nodes(self, age: int) -> Iterable[SomeType]: + ... ... + + Will produce a query like this: + + ``` + query { + someNode ( + before: String + after: String + first: String + after: String + age: Int + ) { + totalCount + pageInfo { + hasNextPage + hasPreviousPage + startCursor + endCursor + } + edges { + cursor + node { + id + ... + } + } + } + } + ``` + + .. _Relay connections: + https://relay.dev/graphql/connections.htm + + """ + extensions = [*extensions, StrawberryOffsetPaginatedExtension()] + f = field_cls( + python_name=None, + django_name=field_name, + graphql_name=name, + type_annotation=StrawberryAnnotation.from_annotation(graphql_type), + description=description, + is_subscription=is_subscription, + permission_classes=permission_classes or [], + deprecation_reason=deprecation_reason, + default=default, + default_factory=default_factory, + metadata=metadata, + directives=directives or (), + filters=filters, + order=order, + extensions=extensions, + only=only, + select_related=select_related, + prefetch_related=prefetch_related, + annotate=annotate, + disable_optimization=disable_optimization, + ) + + if resolver: + f = f(resolver) + + return f diff --git a/strawberry_django/filters.py b/strawberry_django/filters.py index dd5d789a..9084a850 100644 --- a/strawberry_django/filters.py +++ b/strawberry_django/filters.py @@ -309,6 +309,7 @@ def arguments(self) -> list[StrawberryArgument]: and is_root_query and not self.is_list and not self.is_connection + and not self.is_paginated ): settings = strawberry_django_settings() arguments.append( diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index cf3b6d86..956627fa 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -49,7 +49,7 @@ from typing_extensions import assert_never, assert_type from strawberry_django.fields.types import resolve_model_field_name -from strawberry_django.pagination import apply_window_pagination +from strawberry_django.pagination import OffsetPaginated, apply_window_pagination from strawberry_django.queryset import get_queryset_config, run_type_get_queryset from strawberry_django.relay import ListConnectionWithTotalCount from strawberry_django.resolvers import django_fetch @@ -528,7 +528,7 @@ def _optimize_prefetch_queryset( ) field_kwargs.pop("info", None) - # Disable the optimizer to avoid doint double optimization while running get_queryset + # Disable the optimizer to avoid doing double optimization while running get_queryset with DjangoOptimizerExtension.disabled(): qs = field.get_queryset( qs, @@ -574,6 +574,15 @@ def _optimize_prefetch_queryset( else: mark_optimized = False + if isinstance(field.type, type) and issubclass(field.type, OffsetPaginated): + pagination = field_kwargs.get("pagination") + qs = apply_window_pagination( + qs, + related_field_id=related_field_id, + offset=pagination.offset if pagination else 0, + limit=pagination.limit if pagination else -1, + ) + if mark_optimized: qs = mark_optimized_by_prefetching(qs) @@ -977,8 +986,7 @@ def _get_model_hints( ) -> OptimizerStore | None: cache = cache or {} - # In case this is a relay field, find the selected edges/nodes, the selected fields - # are actually inside edges -> node selection... + # In case this is a relay field, the selected fields are inside edges -> node selection if issubclass(object_definition.origin, relay.Connection): return _get_model_hints_from_connection( model, @@ -992,6 +1000,20 @@ def _get_model_hints( level=level, ) + # In case this is a Paginated field, the selected fields are inside results selection + if issubclass(object_definition.origin, OffsetPaginated): + return _get_model_hints_from_paginated( + model, + schema, + object_definition, + parent_type=parent_type, + info=info, + config=config, + prefix=prefix, + cache=cache, + level=level, + ) + store = OptimizerStore() config = config or OptimizerConfig() @@ -1156,6 +1178,55 @@ def _get_model_hints_from_connection( return store +def _get_model_hints_from_paginated( + model: type[models.Model], + schema: Schema, + object_definition: StrawberryObjectDefinition, + *, + parent_type: GraphQLObjectType | GraphQLInterfaceType, + info: GraphQLResolveInfo, + config: OptimizerConfig | None = None, + prefix: str = "", + cache: dict[type[models.Model], list[tuple[int, OptimizerStore]]] | None = None, + level: int = 0, +) -> OptimizerStore | None: + store = None + + n_type = object_definition.type_var_map.get("NodeType") + n_definition = get_object_definition(n_type, strict=True) + n_gql_definition = _get_gql_definition( + schema, + get_object_definition(n_type, strict=True), + ) + assert isinstance(n_gql_definition, (GraphQLObjectType, GraphQLInterfaceType)) + + for selections in _get_selections(info, parent_type).values(): + selection = selections[0] + if selection.name.value != "results": + continue + + n_info = _generate_selection_resolve_info( + info, + selections, + n_gql_definition, + n_gql_definition, + ) + + store = _get_model_hints( + model=model, + schema=schema, + object_definition=n_definition, + parent_type=n_gql_definition, + info=n_info, + config=config, + prefix=prefix, + cache=cache, + level=level, + ) + + return store + + def optimize( qs: QuerySet[_M] | BaseManager[_M], info: GraphQLResolveInfo | Info, diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index ef224db0..32a00fb4 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -1,9 +1,10 @@ import sys -from typing import TYPE_CHECKING, Optional, TypeVar, Union +import warnings +from typing import Generic, Optional, TypeVar, Union, cast import strawberry from django.db import DEFAULT_DB_ALIAS -from django.db.models import Count, Window +from django.db.models import Count, QuerySet, Window from django.db.models.functions import RowNumber from strawberry.types import Info from strawberry.types.arguments import StrawberryArgument @@ -11,19 +12,100 @@ from typing_extensions import Self from strawberry_django.fields.base import StrawberryDjangoFieldBase +from strawberry_django.resolvers import django_resolver from .arguments import argument -if TYPE_CHECKING: - from django.db.models import QuerySet +NodeType = TypeVar("NodeType") +_QS = TypeVar("_QS", bound=QuerySet) -_QS = TypeVar("_QS", bound="QuerySet") +PAGINATION_ARG = "pagination" @strawberry.input class OffsetPaginationInput: offset: int = 0 - limit: int = -1 + limit: Optional[int] = None + + +@strawberry.type +class OffsetPaginationInfo: + offset: int = 0 + limit: Optional[int] = None + + +@strawberry.type +class OffsetPaginated(Generic[NodeType]): + queryset: strawberry.Private[Optional[QuerySet]] + pagination: strawberry.Private[OffsetPaginationInput] + + @strawberry.field + def page_info(self) -> OffsetPaginationInfo: + return OffsetPaginationInfo( + limit=self.pagination.limit, + offset=self.pagination.offset, + ) + + @strawberry.field(description="Total count of existing results.") + @django_resolver + def total_count(self) -> int: + return self.get_total_count() + + @strawberry.field(description="List of paginated results.") + @django_resolver + def results(self) -> list[NodeType]: + paginated_queryset = self.get_paginated_queryset() + + return cast( + list[NodeType], paginated_queryset if paginated_queryset is not None else [] + ) + + @classmethod + def resolve_paginated( + cls, + queryset: QuerySet, + *, + info: Info, + pagination: Optional[OffsetPaginationInput] = None, + **kwargs, + ) -> Self: + """Resolve the paginated queryset. + + Args: + queryset: The queryset to be paginated. + info: The strawberry execution info resolve the type name from. + pagination: The pagination input to be applied. + kwargs: Additional arguments passed to the resolver. + + Returns: + The resolved `OffsetPaginated` + + """ + return cls( + queryset=queryset, + pagination=pagination or OffsetPaginationInput(), + ) + + def get_total_count(self) -> int: + """Retrieve tht total count of the queryset without pagination.""" + return get_total_count(self.queryset) if self.queryset is not None else 0 + + def get_paginated_queryset(self) -> Optional[QuerySet]: + """Retrieve the queryset with pagination applied. + + This will apply the paginated arguments to the queryset and return it. + To use the original queryset, access `.queryset` directly. + """ + from strawberry_django.optimizer import is_optimized_by_prefetching + + if self.queryset is None: + return None + + return ( + self.queryset._result_cache # type: ignore + if is_optimized_by_prefetching(self.queryset) + else apply(self.pagination, self.queryset) + ) def apply( @@ -59,18 +141,30 @@ def apply( ) else: start = pagination.offset - stop = start + pagination.limit - queryset = queryset[start:stop] + if pagination.limit is not None and pagination.limit >= 0: + stop = start + pagination.limit + queryset = queryset[start:stop] + else: + queryset = queryset[start:] return queryset +class _PaginationWindow(Window): + """Window function to be used for pagination. + + This is the same as django's `Window` function, but we can easily identify + it in case we need to remove it from the queryset, as there might be other + window functions in the queryset and no other way to identify ours. + """ + + def apply_window_pagination( queryset: _QS, *, related_field_id: str, offset: int = 0, - limit: int = -1, + limit: Optional[int] = None, ) -> _QS: """Apply pagination using window functions. @@ -93,13 +187,14 @@ def apply_window_pagination( using=queryset._db or DEFAULT_DB_ALIAS # type: ignore ).get_order_by() ] + queryset = queryset.annotate( - _strawberry_row_number=Window( + _strawberry_row_number=_PaginationWindow( RowNumber(), partition_by=related_field_id, order_by=order_by, ), - _strawberry_total_count=Window( + _strawberry_total_count=_PaginationWindow( Count(1), partition_by=related_field_id, ), @@ -110,12 +205,64 @@ def apply_window_pagination( # Limit == -1 means no limit. sys.maxsize is set by relay when paginating # from the end to as a way to mimic a "not limit" as well - if limit >= 0 and limit != sys.maxsize: + if limit is not None and limit >= 0 and limit != sys.maxsize: queryset = queryset.filter(_strawberry_row_number__lte=offset + limit) return queryset +def remove_window_pagination(queryset: _QS) -> _QS: + """Remove pagination window functions from a queryset. + + Utility function to remove the pagination `WHERE` clause added by + the `apply_window_pagination` function. + + Args: + ---- + queryset: The queryset to apply pagination to. + + """ + queryset = queryset._chain() # type: ignore + queryset.query.where.children = [ + child + for child in queryset.query.where.children + if (not hasattr(child, "lhs") or not isinstance(child.lhs, _PaginationWindow)) + ] + return queryset + + +def get_total_count(queryset: QuerySet) -> int: + """Get the total count of a queryset. + + Try to get the total count from the queryset cache, if it's optimized by + prefetching. Otherwise, fallback to the `QuerySet.count()` method. + """ + from strawberry_django.optimizer import is_optimized_by_prefetching + + if is_optimized_by_prefetching(queryset): + results = queryset._result_cache # type: ignore + + if results: + try: + return results[0]._strawberry_total_count + except AttributeError: + warnings.warn( + ( + "Pagination annotations not found, falling back to QuerySet resolution. " + "This might cause n+1 issues..." + ), + RuntimeWarning, + stacklevel=2, + ) + + # If we have no results, we can't get the total count from the cache. + # In this case we will remove the pagination filter to be able to `.count()` + # the whole queryset with its original filters. + queryset = remove_window_pagination(queryset) + + return queryset.count() + + class StrawberryDjangoPagination(StrawberryDjangoFieldBase): def __init__(self, pagination: Union[bool, UnsetType] = UNSET, **kwargs): self.pagination = pagination @@ -126,10 +273,25 @@ def __copy__(self) -> Self: new_field.pagination = self.pagination return new_field + def _has_pagination(self) -> bool: + if isinstance(self.pagination, bool): + return self.pagination + + if self.is_paginated: + return True + + django_type = self.django_type + if django_type is not None and not issubclass( + django_type, strawberry.relay.Node + ): + return django_type.__strawberry_django_definition__.pagination + + return False + @property def arguments(self) -> list[StrawberryArgument]: arguments = [] - if self.base_resolver is None and self.is_list: + if self.base_resolver is None and (self.is_list or self.is_paginated): pagination = self.get_pagination() if pagination is not None: arguments.append( @@ -143,20 +305,7 @@ def arguments(self, value: list[StrawberryArgument]): return args_prop.fset(self, value) # type: ignore def get_pagination(self) -> Optional[type]: - has_pagination = self.pagination - - if isinstance(has_pagination, UnsetType): - django_type = self.django_type - has_pagination = ( - django_type.__strawberry_django_definition__.pagination - if ( - django_type is not None - and not issubclass(django_type, strawberry.relay.Node) - ) - else False - ) - - return OffsetPaginationInput if has_pagination else None + return OffsetPaginationInput if self._has_pagination() else None def apply_pagination( self, @@ -172,11 +321,19 @@ def get_queryset( queryset: _QS, info: Info, *, - pagination: Optional[object] = None, + pagination: Optional[OffsetPaginationInput] = None, _strawberry_related_field_id: Optional[str] = None, **kwargs, ) -> _QS: queryset = super().get_queryset(queryset, info, **kwargs) + + # This is counter intuitive, but in case we are returning a `Paginated` + # result, we want to set the original queryset _as is_ as it will apply + # the pagination later on when resolving its `.results` field. + # Check `get_wrapped_result` below for more details. + if self.is_paginated: + return queryset + return self.apply_pagination( queryset, pagination, diff --git a/strawberry_django/permissions.py b/strawberry_django/permissions.py index dab1995a..2b1a60f8 100644 --- a/strawberry_django/permissions.py +++ b/strawberry_django/permissions.py @@ -39,6 +39,7 @@ from strawberry_django.auth.utils import aget_current_user, get_current_user from strawberry_django.fields.types import OperationInfo, OperationMessage +from strawberry_django.pagination import OffsetPaginated from strawberry_django.resolvers import django_resolver from .utils.query import filter_for_user @@ -47,6 +48,8 @@ if TYPE_CHECKING: from strawberry.django.context import StrawberryDjangoContext + from strawberry_django.fields.field import StrawberryDjangoField + _M = TypeVar("_M", bound=Model) @@ -405,6 +408,11 @@ def handle_no_permission(self, exception: BaseException, *, info: Info): if isinstance(ret_type, StrawberryList): return [] + if isinstance(ret_type, type) and issubclass(ret_type, OffsetPaginated): + django_model = cast("StrawberryDjangoField", info._field).django_model + assert django_model + return django_model._default_manager.none() + # If it is a Connection, try to return an empty connection, but only if # it is the only possibility available... for ret_possibility in ret_types: diff --git a/strawberry_django/relay.py b/strawberry_django/relay.py index b8db964d..17750e74 100644 --- a/strawberry_django/relay.py +++ b/strawberry_django/relay.py @@ -22,6 +22,7 @@ from strawberry.utils.await_maybe import AwaitableOrValue from typing_extensions import Literal, Self +from strawberry_django.pagination import get_total_count from strawberry_django.queryset import run_type_get_queryset from strawberry_django.resolvers import django_getattr, django_resolver from strawberry_django.utils.typing import ( @@ -48,41 +49,12 @@ class ListConnectionWithTotalCount(relay.ListConnection[relay.NodeType]): @strawberry.field(description="Total quantity of existing nodes.") @django_resolver def total_count(self) -> Optional[int]: - from .optimizer import is_optimized_by_prefetching - assert self.nodes is not None - if isinstance(self.nodes, models.QuerySet) and is_optimized_by_prefetching( - self.nodes - ): - result = cast(list[relay.NodeType], self.nodes._result_cache) # type: ignore - try: - return ( - result[0]._strawberry_total_count # type: ignore - if result - else 0 - ) - except AttributeError: - warnings.warn( - ( - "Pagination annotations not found, falling back to QuerySet resolution. " - "This might cause n+1 issues..." - ), - RuntimeWarning, - stacklevel=2, - ) + if isinstance(self.nodes, models.QuerySet): + return get_total_count(self.nodes) - total_count = None - try: - total_count = cast( - "models.QuerySet[models.Model]", - self.nodes, - ).count() - except (AttributeError, ValueError, TypeError): - if isinstance(self.nodes, Sized): - total_count = len(self.nodes) - - return total_count + return len(self.nodes) if isinstance(self.nodes, Sized) else None @classmethod def resolve_connection( diff --git a/strawberry_django/resolvers.py b/strawberry_django/resolvers.py index c4e186b0..491b486e 100644 --- a/strawberry_django/resolvers.py +++ b/strawberry_django/resolvers.py @@ -193,7 +193,7 @@ def resolve_base_manager(manager: BaseManager) -> Any: # prevents us from importing and checking isinstance on them directly. try: # ManyRelatedManager - return list(prefetched_cache[manager.prefetch_cache_name]) # type: ignore + return prefetched_cache[manager.prefetch_cache_name] # type: ignore except (AttributeError, KeyError): try: # RelatedManager @@ -203,7 +203,7 @@ def resolve_base_manager(manager: BaseManager) -> Any: getattr(result_field.remote_field, "cache_name", None) or result_field.remote_field.get_cache_name() ) - return list(prefetched_cache[cache_name]) + return prefetched_cache[cache_name] except (AttributeError, KeyError): pass diff --git a/strawberry_django/settings.py b/strawberry_django/settings.py index e0a0606c..72a4099d 100644 --- a/strawberry_django/settings.py +++ b/strawberry_django/settings.py @@ -1,6 +1,6 @@ """Code for interacting with Django settings.""" -from typing import cast +from typing import Optional, cast from django.conf import settings from typing_extensions import TypedDict @@ -42,6 +42,10 @@ class StrawberryDjangoSettings(TypedDict): #: If True, deprecated way of using filters will be working USE_DEPRECATED_FILTERS: bool + #: The default limit for pagination when not provided. Can be set to `None` + #: to set it to unlimited. + PAGINATION_DEFAULT_LIMIT: Optional[int] + DEFAULT_DJANGO_SETTINGS = StrawberryDjangoSettings( FIELD_DESCRIPTION_FROM_HELP_TEXT=False, @@ -52,6 +56,7 @@ class StrawberryDjangoSettings(TypedDict): MAP_AUTO_ID_AS_GLOBAL_ID=False, DEFAULT_PK_FIELD_NAME="pk", USE_DEPRECATED_FILTERS=False, + PAGINATION_DEFAULT_LIMIT=100, ) diff --git a/strawberry_django/type.py b/strawberry_django/type.py index 3ca5c2a4..8d234a7e 100644 --- a/strawberry_django/type.py +++ b/strawberry_django/type.py @@ -55,12 +55,12 @@ from .fields.types import get_model_field, resolve_model_field_name from .settings import strawberry_django_settings as django_settings -__all = [ - "StrawberryDjangoType", - "type", - "interface", +__all__ = [ + "StrawberryDjangoDefinition", "input", + "interface", "partial", + "type", ] _T = TypeVar("_T", bound=type) diff --git a/tests/projects/schema.py b/tests/projects/schema.py index bedd5934..d87b3f35 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -29,6 +29,7 @@ from strawberry_django.fields.types import ListInput, NodeInput, NodeInputPartial from strawberry_django.mutations import resolvers from strawberry_django.optimizer import DjangoOptimizerExtension +from strawberry_django.pagination import OffsetPaginated from strawberry_django.permissions import ( HasPerm, HasRetvalPerm, @@ -158,6 +159,10 @@ class MilestoneType(relay.Node, Named): order=IssueOrder, pagination=True, ) + issues_paginated: OffsetPaginated["IssueType"] = strawberry_django.offset_paginated( + field_name="issues", + order=IssueOrder, + ) issues_with_filters: ListConnectionWithTotalCount["IssueType"] = ( strawberry_django.connection( field_name="issues", @@ -375,6 +380,7 @@ class Query: staff_list: list[Optional[StaffType]] = strawberry_django.node() issue_list: list[IssueType] = strawberry_django.field() + issues_paginated: OffsetPaginated[IssueType] = strawberry_django.offset_paginated() milestone_list: list[MilestoneType] = strawberry_django.field( order=MilestoneOrder, filters=MilestoneFilter, @@ -429,6 +435,11 @@ class Query: issue_list_perm_required: list[IssueType] = strawberry_django.field( extensions=[HasPerm(perms=["projects.view_issue"])], ) + issues_paginated_perm_required: OffsetPaginated[IssueType] = ( + strawberry_django.offset_paginated( + extensions=[HasPerm(perms=["projects.view_issue"])], + ) + ) issue_conn_perm_required: ListConnectionWithTotalCount[IssueType] = ( strawberry_django.connection( extensions=[HasPerm(perms=["projects.view_issue"])], @@ -447,6 +458,11 @@ class Query: issue_list_obj_perm_required_paginated: list[IssueType] = strawberry_django.field( extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True ) + issues_paginated_obj_perm_required: OffsetPaginated[IssueType] = ( + strawberry_django.offset_paginated( + extensions=[HasRetvalPerm(perms=["projects.view_issue"])], + ) + ) issue_conn_obj_perm_required: ListConnectionWithTotalCount[IssueType] = ( strawberry_django.connection( extensions=[HasRetvalPerm(perms=["projects.view_issue"])], diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index 16533658..b5f31d33 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -336,6 +336,16 @@ type IssueTypeEdge { node: IssueType! } +type IssueTypeOffsetPaginated { + pageInfo: OffsetPaginationInfo! + + """Total count of existing results.""" + totalCount: Int! + + """List of paginated results.""" + results: [IssueType!]! +} + input MilestoneFilter { name: StrFilterLookup project: DjangoModelFilterInput @@ -373,6 +383,7 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! + issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter @@ -452,9 +463,14 @@ input NodeInputPartial { id: GlobalID } +type OffsetPaginationInfo { + offset: Int! + limit: Int +} + input OffsetPaginationInput { offset: Int! = 0 - limit: Int! = -1 + limit: Int = null } type OperationInfo { @@ -610,6 +626,7 @@ type Query { ids: [GlobalID!]! ): [StaffType]! issueList: [IssueType!]! + issuesPaginated(pagination: OffsetPaginationInput): IssueTypeOffsetPaginated! milestoneList(filters: MilestoneFilter, order: MilestoneOrder, pagination: OffsetPaginationInput): [MilestoneType!]! projectList(filters: ProjectFilter): [ProjectType!]! tagList: [TagType!]! @@ -729,6 +746,7 @@ type Query { id: GlobalID! ): IssueType @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListPermRequired: [IssueType!]! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) + issuesPaginatedPermRequired(pagination: OffsetPaginationInput): IssueTypeOffsetPaginated! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueConnPermRequired( """Returns the items in the list that come before the specified cursor.""" before: String = null @@ -752,6 +770,7 @@ type Query { ): IssueType @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListObjPermRequired: [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListObjPermRequiredPaginated(pagination: OffsetPaginationInput): [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) + issuesPaginatedObjPermRequired(pagination: OffsetPaginationInput): IssueTypeOffsetPaginated! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueConnObjPermRequired( """Returns the items in the list that come before the specified cursor.""" before: String = null diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index 8c891a65..352c8e1b 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -136,6 +136,16 @@ type IssueTypeEdge { node: IssueType! } +type IssueTypeOffsetPaginated { + pageInfo: OffsetPaginationInfo! + + """Total count of existing results.""" + totalCount: Int! + + """List of paginated results.""" + results: [IssueType!]! +} + input MilestoneFilter { name: StrFilterLookup project: DjangoModelFilterInput @@ -163,6 +173,7 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! + issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter @@ -190,6 +201,7 @@ type MilestoneTypeSubclass implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! + issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter @@ -229,9 +241,14 @@ input NodeInput { id: GlobalID! } +type OffsetPaginationInfo { + offset: Int! + limit: Int +} + input OffsetPaginationInput { offset: Int! = 0 - limit: Int! = -1 + limit: Int = null } type OperationInfo { diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index a2046170..5ed3d957 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1425,3 +1425,183 @@ class Query: "issues": [{"pk": str(issue1.pk)}], }, } + + +@pytest.mark.django_db(transaction=True) +def test_query_paginated(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery ($pagination: OffsetPaginationInput) { + issuesPaginated (pagination: $pagination) { + totalCount + results { + name + milestone { + name + } + } + } + } + """ + + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + issue1 = IssueFactory.create(milestone=milestone1) + issue2 = IssueFactory.create(milestone=milestone1) + issue3 = IssueFactory.create(milestone=milestone1) + issue4 = IssueFactory.create(milestone=milestone2) + issue5 = IssueFactory.create(milestone=milestone2) + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 7): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginated": { + "totalCount": 5, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + {"name": issue3.name, "milestone": {"name": milestone1.name}}, + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + {"name": issue5.name, "milestone": {"name": milestone2.name}}, + ], + } + } + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 4): + res = gql_client.query(query, variables={"pagination": {"limit": 2}}) + + assert res.data == { + "issuesPaginated": { + "totalCount": 5, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + ], + } + } + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 4): + res = gql_client.query( + query, variables={"pagination": {"limit": 2, "offset": 2}} + ) + + assert res.data == { + "issuesPaginated": { + "totalCount": 5, + "results": [ + {"name": issue3.name, "milestone": {"name": milestone1.name}}, + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_query_paginated_nested(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery ($pagination: OffsetPaginationInput) { + milestoneList { + name + issuesPaginated (pagination: $pagination) { + totalCount + results { + name + milestone { + name + } + } + } + } + } + """ + + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + issue1 = IssueFactory.create(milestone=milestone1) + issue2 = IssueFactory.create(milestone=milestone1) + issue3 = IssueFactory.create(milestone=milestone1) + issue4 = IssueFactory.create(milestone=milestone2) + issue5 = IssueFactory.create(milestone=milestone2) + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 5): + res = gql_client.query(query) + + assert res.data == { + "milestoneList": [ + { + "name": milestone1.name, + "issuesPaginated": { + "totalCount": 3, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + {"name": issue3.name, "milestone": {"name": milestone1.name}}, + ], + }, + }, + { + "name": milestone2.name, + "issuesPaginated": { + "totalCount": 2, + "results": [ + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + {"name": issue5.name, "milestone": {"name": milestone2.name}}, + ], + }, + }, + ] + } + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 5): + res = gql_client.query(query, variables={"pagination": {"limit": 1}}) + + assert res.data == { + "milestoneList": [ + { + "name": milestone1.name, + "issuesPaginated": { + "totalCount": 3, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + ], + }, + }, + { + "name": milestone2.name, + "issuesPaginated": { + "totalCount": 2, + "results": [ + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + ], + }, + }, + ] + } + + with assert_num_queries(3 if DjangoOptimizerExtension.enabled.get() else 5): + res = gql_client.query( + query, variables={"pagination": {"limit": 1, "offset": 2}} + ) + + assert res.data == { + "milestoneList": [ + { + "name": milestone1.name, + "issuesPaginated": { + "totalCount": 3, + "results": [ + {"name": issue3.name, "milestone": {"name": milestone1.name}}, + ], + }, + }, + { + "name": milestone2.name, + "issuesPaginated": { + "totalCount": 2, + "results": [], + }, + }, + ] + } diff --git a/tests/test_paginated_type.py b/tests/test_paginated_type.py new file mode 100644 index 00000000..74537727 --- /dev/null +++ b/tests/test_paginated_type.py @@ -0,0 +1,831 @@ +import textwrap + +import pytest +import strawberry +from django.db.models import QuerySet + +import strawberry_django +from strawberry_django.pagination import OffsetPaginated, OffsetPaginationInput +from tests import models + + +def test_paginated_schema(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry_django.type(models.Color) + class Color: + id: int + name: str + fruits: OffsetPaginated[Fruit] + + @strawberry.type + class Query: + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() + colors: OffsetPaginated[Color] = strawberry_django.offset_paginated() + + schema = strawberry.Schema(query=Query) + + expected = '''\ + type Color { + id: Int! + name: String! + fruits(pagination: OffsetPaginationInput): FruitOffsetPaginated! + } + + type ColorOffsetPaginated { + pageInfo: OffsetPaginationInfo! + + """Total count of existing results.""" + totalCount: Int! + + """List of paginated results.""" + results: [Color!]! + } + + type Fruit { + id: Int! + name: String! + } + + type FruitOffsetPaginated { + pageInfo: OffsetPaginationInfo! + + """Total count of existing results.""" + totalCount: Int! + + """List of paginated results.""" + results: [Fruit!]! + } + + type OffsetPaginationInfo { + offset: Int! + limit: Int + } + + input OffsetPaginationInput { + offset: Int! = 0 + limit: Int = null + } + + type Query { + fruits(pagination: OffsetPaginationInput): FruitOffsetPaginated! + colors(pagination: OffsetPaginationInput): ColorOffsetPaginated! + } + ''' + + assert textwrap.dedent(str(schema)) == textwrap.dedent(expected).strip() + + +@pytest.mark.django_db(transaction=True) +def test_pagination_query(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry.type + class Query: + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() + + models.Fruit.objects.create(name="Apple") + models.Fruit.objects.create(name="Banana") + models.Fruit.objects.create(name="Strawberry") + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetFruits ($pagination: OffsetPaginationInput) { + fruits (pagination: $pagination) { + totalCount + results { + name + } + } + } + """ + + res = schema.execute_sync(query) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [{"name": "Apple"}, {"name": "Banana"}, {"name": "Strawberry"}], + } + } + + res = schema.execute_sync(query, variable_values={"pagination": {"limit": 1}}) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [{"name": "Apple"}], + } + } + + res = schema.execute_sync( + query, variable_values={"pagination": {"limit": 1, "offset": 1}} + ) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [{"name": "Banana"}], + } + } + + +@pytest.mark.django_db(transaction=True) +async def test_pagination_query_async(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry.type + class Query: + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() + + await models.Fruit.objects.acreate(name="Apple") + await models.Fruit.objects.acreate(name="Banana") + await models.Fruit.objects.acreate(name="Strawberry") + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetFruits ($pagination: OffsetPaginationInput) { + fruits (pagination: $pagination) { + totalCount + results { + name + } + } + } + """ + + res = await schema.execute(query) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [{"name": "Apple"}, {"name": "Banana"}, {"name": "Strawberry"}], + } + } + + res = await schema.execute(query, variable_values={"pagination": {"limit": 1}}) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [{"name": "Apple"}], + } + } + + res = await schema.execute( + query, variable_values={"pagination": {"limit": 1, "offset": 1}} + ) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [{"name": "Banana"}], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_pagination_nested_query(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry_django.type(models.Color) + class Color: + id: int + name: str + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() + + @strawberry.type + class Query: + colors: OffsetPaginated[Color] = strawberry_django.offset_paginated() + + red = models.Color.objects.create(name="Red") + yellow = models.Color.objects.create(name="Yellow") + + models.Fruit.objects.create(name="Apple", color=red) + models.Fruit.objects.create(name="Banana", color=yellow) + models.Fruit.objects.create(name="Strawberry", color=red) + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetColors ($pagination: OffsetPaginationInput) { + colors { + totalCount + results { + fruits (pagination: $pagination) { + totalCount + results { + name + } + } + } + } + } + """ + + res = schema.execute_sync(query) + assert res.errors is None + assert res.data == { + "colors": { + "totalCount": 2, + "results": [ + { + "fruits": { + "totalCount": 2, + "results": [{"name": "Apple"}, {"name": "Strawberry"}], + } + }, + { + "fruits": { + "totalCount": 1, + "results": [{"name": "Banana"}], + } + }, + ], + } + } + + res = schema.execute_sync(query, variable_values={"pagination": {"limit": 1}}) + assert res.errors is None + assert res.data == { + "colors": { + "totalCount": 2, + "results": [ + { + "fruits": { + "totalCount": 2, + "results": [{"name": "Apple"}], + } + }, + { + "fruits": { + "totalCount": 1, + "results": [{"name": "Banana"}], + } + }, + ], + } + } + + res = schema.execute_sync( + query, variable_values={"pagination": {"limit": 1, "offset": 1}} + ) + assert res.errors is None + assert res.data == { + "colors": { + "totalCount": 2, + "results": [ + { + "fruits": { + "totalCount": 2, + "results": [{"name": "Strawberry"}], + } + }, + { + "fruits": { + "totalCount": 1, + "results": [], + } + }, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +async def test_pagination_nested_query_async(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry_django.type(models.Color) + class Color: + id: int + name: str + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() + + @strawberry.type + class Query: + colors: OffsetPaginated[Color] = strawberry_django.offset_paginated() + + red = await models.Color.objects.acreate(name="Red") + yellow = await models.Color.objects.acreate(name="Yellow") + + await models.Fruit.objects.acreate(name="Apple", color=red) + await models.Fruit.objects.acreate(name="Banana", color=yellow) + await models.Fruit.objects.acreate(name="Strawberry", color=red) + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetColors ($pagination: OffsetPaginationInput) { + colors { + totalCount + results { + fruits (pagination: $pagination) { + totalCount + results { + name + } + } + } + } + } + """ + + res = await schema.execute(query) + assert res.errors is None + assert res.data == { + "colors": { + "totalCount": 2, + "results": [ + { + "fruits": { + "totalCount": 2, + "results": [{"name": "Apple"}, {"name": "Strawberry"}], + } + }, + { + "fruits": { + "totalCount": 1, + "results": [{"name": "Banana"}], + } + }, + ], + } + } + + res = await schema.execute(query, variable_values={"pagination": {"limit": 1}}) + assert res.errors is None + assert res.data == { + "colors": { + "totalCount": 2, + "results": [ + { + "fruits": { + "totalCount": 2, + "results": [{"name": "Apple"}], + } + }, + { + "fruits": { + "totalCount": 1, + "results": [{"name": "Banana"}], + } + }, + ], + } + } + + res = await schema.execute( + query, variable_values={"pagination": {"limit": 1, "offset": 1}} + ) + assert res.errors is None + assert res.data == { + "colors": { + "totalCount": 2, + "results": [ + { + "fruits": { + "totalCount": 2, + "results": [{"name": "Strawberry"}], + } + }, + { + "fruits": { + "totalCount": 1, + "results": [], + } + }, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_pagination_query_with_subclass(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry.type + class FruitPaginated(OffsetPaginated[Fruit]): + _custom_field: strawberry.Private[str] + + @strawberry_django.field + def custom_field(self) -> str: + return self._custom_field + + @classmethod + def resolve_paginated(cls, queryset, *, info, pagination=None, **kwargs): + return cls( + queryset=queryset, + pagination=pagination or OffsetPaginationInput(), + _custom_field="pagination rocks", + ) + + @strawberry.type + class Query: + fruits: FruitPaginated = strawberry_django.offset_paginated() + + models.Fruit.objects.create(name="Apple") + models.Fruit.objects.create(name="Banana") + models.Fruit.objects.create(name="Strawberry") + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetFruits ($pagination: OffsetPaginationInput) { + fruits (pagination: $pagination) { + totalCount + customField + results { + name + } + } + } + """ + + res = schema.execute_sync(query) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "customField": "pagination rocks", + "results": [{"name": "Apple"}, {"name": "Banana"}, {"name": "Strawberry"}], + } + } + + res = schema.execute_sync(query, variable_values={"pagination": {"limit": 1}}) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "customField": "pagination rocks", + "results": [{"name": "Apple"}], + } + } + + res = schema.execute_sync( + query, variable_values={"pagination": {"limit": 1, "offset": 2}} + ) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "customField": "pagination rocks", + "results": [{"name": "Strawberry"}], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_pagination_query_with_resolver_schema(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry_django.filter(models.Fruit) + class FruitFilter: + name: str + + @strawberry_django.order(models.Fruit) + class FruitOrder: + name: str + + @strawberry.type + class Query: + @strawberry_django.offset_paginated(OffsetPaginated[Fruit]) + def fruits(self) -> QuerySet[models.Fruit]: ... + + @strawberry_django.offset_paginated( + OffsetPaginated[Fruit], + filters=FruitFilter, + order=FruitOrder, + ) + def fruits_with_order_and_filter(self) -> QuerySet[models.Fruit]: ... + + schema = strawberry.Schema(query=Query) + + expected = ''' + type Fruit { + id: Int! + name: String! + } + + input FruitFilter { + name: String! + AND: FruitFilter + OR: FruitFilter + NOT: FruitFilter + DISTINCT: Boolean + } + + type FruitOffsetPaginated { + pageInfo: OffsetPaginationInfo! + + """Total count of existing results.""" + totalCount: Int! + + """List of paginated results.""" + results: [Fruit!]! + } + + input FruitOrder { + name: String + } + + type OffsetPaginationInfo { + offset: Int! + limit: Int + } + + input OffsetPaginationInput { + offset: Int! = 0 + limit: Int = null + } + + type Query { + fruits(pagination: OffsetPaginationInput): FruitOffsetPaginated! + fruitsWithOrderAndFilter(filters: FruitFilter, order: FruitOrder, pagination: OffsetPaginationInput): FruitOffsetPaginated! + } + ''' + + assert textwrap.dedent(str(schema)) == textwrap.dedent(expected).strip() + + +@pytest.mark.django_db(transaction=True) +def test_pagination_query_with_resolver(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry_django.filter(models.Fruit) + class FruitFilter: + name: strawberry.auto + + @strawberry_django.order(models.Fruit) + class FruitOrder: + name: strawberry.auto + + @strawberry.type + class Query: + @strawberry_django.offset_paginated(OffsetPaginated[Fruit]) + def fruits(self) -> QuerySet[models.Fruit]: + return models.Fruit.objects.filter(name__startswith="S") + + @strawberry_django.offset_paginated( + OffsetPaginated[Fruit], + filters=FruitFilter, + order=FruitOrder, + ) + def fruits_with_order_and_filter(self) -> QuerySet[models.Fruit]: + return models.Fruit.objects.filter(name__startswith="S") + + models.Fruit.objects.create(name="Apple") + models.Fruit.objects.create(name="Strawberry") + models.Fruit.objects.create(name="Banana") + models.Fruit.objects.create(name="Sugar Apple") + models.Fruit.objects.create(name="Starfruit") + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetFruits ( + $pagination: OffsetPaginationInput + $filters: FruitFilter + $order: FruitOrder + ) { + fruits (pagination: $pagination) { + totalCount + results { + name + } + } + fruitsWithOrderAndFilter ( + pagination: $pagination + filters: $filters + order: $order + ) { + totalCount + results { + name + } + } + } + """ + + res = schema.execute_sync(query) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + {"name": "Starfruit"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + {"name": "Starfruit"}, + ], + }, + } + + res = schema.execute_sync(query, variable_values={"pagination": {"limit": 1}}) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + ], + }, + } + + res = schema.execute_sync( + query, + variable_values={ + "pagination": {"limit": 2}, + "order": {"name": "ASC"}, + "filters": {"name": "Strawberry"}, + }, + ) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 1, + "results": [ + {"name": "Strawberry"}, + ], + }, + } + + +@pytest.mark.django_db(transaction=True) +def test_pagination_query_with_resolver_arguments(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry_django.filter(models.Fruit) + class FruitFilter: + name: strawberry.auto + + @strawberry_django.order(models.Fruit) + class FruitOrder: + name: strawberry.auto + + @strawberry.type + class Query: + @strawberry_django.offset_paginated(OffsetPaginated[Fruit]) + def fruits(self, starts_with: str) -> QuerySet[models.Fruit]: + return models.Fruit.objects.filter(name__startswith=starts_with) + + @strawberry_django.offset_paginated( + OffsetPaginated[Fruit], + filters=FruitFilter, + order=FruitOrder, + ) + def fruits_with_order_and_filter( + self, starts_with: str + ) -> QuerySet[models.Fruit]: + return models.Fruit.objects.filter(name__startswith=starts_with) + + models.Fruit.objects.create(name="Apple") + models.Fruit.objects.create(name="Strawberry") + models.Fruit.objects.create(name="Banana") + models.Fruit.objects.create(name="Sugar Apple") + models.Fruit.objects.create(name="Starfruit") + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetFruits ( + $pagination: OffsetPaginationInput + $filters: FruitFilter + $order: FruitOrder + $startsWith: String! + ) { + fruits (startsWith: $startsWith, pagination: $pagination) { + totalCount + results { + name + } + } + fruitsWithOrderAndFilter ( + startsWith: $startsWith + pagination: $pagination + filters: $filters + order: $order + ) { + totalCount + results { + name + } + } + } + """ + + res = schema.execute_sync(query, variable_values={"startsWith": "S"}) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + {"name": "Starfruit"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + {"name": "Starfruit"}, + ], + }, + } + + res = schema.execute_sync( + query, + variable_values={"startsWith": "S", "pagination": {"limit": 1}}, + ) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + ], + }, + } + + res = schema.execute_sync( + query, + variable_values={ + "startsWith": "S", + "pagination": {"limit": 2}, + "order": {"name": "ASC"}, + "filters": {"name": "Strawberry"}, + }, + ) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 1, + "results": [ + {"name": "Strawberry"}, + ], + }, + } diff --git a/tests/test_permissions.py b/tests/test_permissions.py index c1c492ee..c6e79318 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -9,11 +9,12 @@ from .projects.faker import ( GroupFactory, IssueFactory, + MilestoneFactory, StaffUserFactory, SuperuserUserFactory, UserFactory, ) -from .utils import GraphQLTestClient +from .utils import GraphQLTestClient, assert_num_queries PermKind: TypeAlias = Literal["user", "group", "superuser"] perm_kinds: list[PermKind] = ["user", "group", "superuser"] @@ -934,3 +935,167 @@ def test_conn_obj_perm_required(db, gql_client: GraphQLTestClient, kind: PermKin "totalCount": 1, }, } + + +@pytest.mark.django_db(transaction=True) +def test_query_paginated_with_permissions(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery ($pagination: OffsetPaginationInput) { + issuesPaginatedPermRequired (pagination: $pagination) { + totalCount + results { + name + milestone { + name + } + } + } + } + """ + + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + issue1 = IssueFactory.create(milestone=milestone1) + issue2 = IssueFactory.create(milestone=milestone1) + issue3 = IssueFactory.create(milestone=milestone1) + issue4 = IssueFactory.create(milestone=milestone2) + issue5 = IssueFactory.create(milestone=milestone2) + + # No user logged in + with assert_num_queries(0): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedPermRequired": { + "totalCount": 0, + "results": [], + } + } + + user = UserFactory.create() + + # User logged in without permissions + with gql_client.login(user): + with assert_num_queries(4): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedPermRequired": { + "totalCount": 0, + "results": [], + } + } + + # User logged in with permissions + user.user_permissions.add(Permission.objects.get(codename="view_issue")) + with gql_client.login(user): + with assert_num_queries(6 if DjangoOptimizerExtension.enabled.get() else 11): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedPermRequired": { + "totalCount": 5, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + {"name": issue3.name, "milestone": {"name": milestone1.name}}, + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + {"name": issue5.name, "milestone": {"name": milestone2.name}}, + ], + } + } + + with assert_num_queries(6 if DjangoOptimizerExtension.enabled.get() else 8): + res = gql_client.query(query, variables={"pagination": {"limit": 2}}) + + assert res.data == { + "issuesPaginatedPermRequired": { + "totalCount": 5, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_query_paginated_with_obj_permissions(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery ($pagination: OffsetPaginationInput) { + issuesPaginatedObjPermRequired (pagination: $pagination) { + totalCount + results { + name + milestone { + name + } + } + } + } + """ + + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + IssueFactory.create(milestone=milestone1) + issue2 = IssueFactory.create(milestone=milestone1) + IssueFactory.create(milestone=milestone1) + issue4 = IssueFactory.create(milestone=milestone2) + IssueFactory.create(milestone=milestone2) + + # No user logged in + with assert_num_queries(0): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedObjPermRequired": { + "totalCount": 0, + "results": [], + } + } + + user = UserFactory.create() + + # User logged in without permissions + with gql_client.login(user): + with assert_num_queries(5): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedObjPermRequired": { + "totalCount": 0, + "results": [], + } + } + + assign_perm("view_issue", user, issue2) + assign_perm("view_issue", user, issue4) + + # User logged in with permissions + with gql_client.login(user): + with assert_num_queries(4 if DjangoOptimizerExtension.enabled.get() else 6): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedObjPermRequired": { + "totalCount": 2, + "results": [ + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + ], + } + } + + with assert_num_queries(4 if DjangoOptimizerExtension.enabled.get() else 5): + res = gql_client.query(query, variables={"pagination": {"limit": 1}}) + + assert res.data == { + "issuesPaginatedObjPermRequired": { + "totalCount": 2, + "results": [ + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + ], + } + } diff --git a/tests/test_settings.py b/tests/test_settings.py index 32f572a7..3c7cf1d6 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -30,6 +30,7 @@ def test_non_defaults(): MAP_AUTO_ID_AS_GLOBAL_ID=True, DEFAULT_PK_FIELD_NAME="id", USE_DEPRECATED_FILTERS=True, + PAGINATION_DEFAULT_LIMIT=250, ), ): assert ( @@ -43,5 +44,6 @@ def test_non_defaults(): MAP_AUTO_ID_AS_GLOBAL_ID=True, DEFAULT_PK_FIELD_NAME="id", USE_DEPRECATED_FILTERS=True, + PAGINATION_DEFAULT_LIMIT=250, ) )