From 7c6b44ecf590d58184c99b7acc7ebffc18718ce5 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sun, 13 Oct 2024 13:47:47 -0300 Subject: [PATCH 01/10] feat: New Paginated generic to be used as a wrapped for paginated results --- strawberry_django/fields/base.py | 16 +- strawberry_django/fields/field.py | 17 +- strawberry_django/filters.py | 1 + strawberry_django/pagination.py | 139 ++++++++-- strawberry_django/relay.py | 36 +-- strawberry_django/type.py | 8 +- tests/test_paginated_type.py | 413 ++++++++++++++++++++++++++++++ 7 files changed, 562 insertions(+), 68 deletions(-) create mode 100644 tests/test_paginated_type.py diff --git a/strawberry_django/fields/base.py b/strawberry_django/fields/base.py index 07b11891..20e6864f 100644 --- a/strawberry_django/fields/base.py +++ b/strawberry_django/fields/base.py @@ -85,6 +85,8 @@ def is_async(self) -> bool: @functools.cached_property def django_type(self) -> type[WithStrawberryDjangoObjectDefinition] | None: + from strawberry_django.pagination import Paginated + origin = self.type if isinstance(origin, LazyType): @@ -92,7 +94,9 @@ def django_type(self) -> type[WithStrawberryDjangoObjectDefinition] | None: object_definition = get_object_definition(origin) - if object_definition and issubclass(object_definition.origin, relay.Connection): + if object_definition and issubclass( + object_definition.origin, (relay.Connection, Paginated) + ): origin_specialized_type_var_map = ( get_specialized_type_var_map(cast(type, origin)) or {} ) @@ -148,6 +152,16 @@ def is_list(self) -> bool: return isinstance(type_, StrawberryList) + @functools.cached_property + def is_paginated(self) -> bool: + from strawberry_django.pagination import Paginated + + type_ = self.type + if isinstance(type_, StrawberryOptional): + type_ = type_.of_type + + return isinstance(type_, type) and issubclass(type_, Paginated) + @functools.cached_property def is_connection(self) -> bool: type_ = self.type diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index 20a512fc..7f9fe3df 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -3,7 +3,7 @@ import dataclasses import inspect from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property +from functools import cached_property, partial from typing import ( TYPE_CHECKING, Any, @@ -227,9 +227,12 @@ async def async_resolver(): if "info" not in kwargs: kwargs["info"] = info - resolved = await sync_to_async(self.get_queryset_hook(**kwargs))( - resolved, - ) + @sync_to_async + def resolve(): + inner_resolved = self.get_queryset_hook(**kwargs)(resolved) + return self.get_wrapped_result(inner_resolved, **kwargs) + + resolved = await resolve() return resolved @@ -243,15 +246,15 @@ async def async_resolver(): kwargs["info"] = info result = django_resolver( - lambda obj: obj, + partial(self.get_wrapped_result, **kwargs), qs_hook=self.get_queryset_hook(**kwargs), )(result) return result def get_queryset_hook(self, info: Info, **kwargs): - if self.is_connection: - # We don't want to fetch results yet, those will be done by the connection + if self.is_connection or self.is_paginated: + # We don't want to fetch results yet, those will be done by the connection/pagination def qs_hook(qs: models.QuerySet): # type: ignore return self.get_queryset(qs, info, **kwargs) diff --git a/strawberry_django/filters.py b/strawberry_django/filters.py index dd5d789a..9084a850 100644 --- a/strawberry_django/filters.py +++ b/strawberry_django/filters.py @@ -309,6 +309,7 @@ def arguments(self) -> list[StrawberryArgument]: and is_root_query and not self.is_list and not self.is_connection + and not self.is_paginated ): settings = strawberry_django_settings() arguments.append( diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index ef224db0..14bc53d2 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -1,9 +1,10 @@ import sys -from typing import TYPE_CHECKING, Optional, TypeVar, Union +import warnings +from typing import Generic, Optional, TypeVar, Union, cast import strawberry from django.db import DEFAULT_DB_ALIAS -from django.db.models import Count, Window +from django.db.models import Count, QuerySet, Window from django.db.models.functions import RowNumber from strawberry.types import Info from strawberry.types.arguments import StrawberryArgument @@ -11,19 +12,53 @@ from typing_extensions import Self from strawberry_django.fields.base import StrawberryDjangoFieldBase +from strawberry_django.resolvers import django_resolver from .arguments import argument -if TYPE_CHECKING: - from django.db.models import QuerySet +NodeType = TypeVar("NodeType") +_T = TypeVar("_T") +_QS = TypeVar("_QS", bound=QuerySet) -_QS = TypeVar("_QS", bound="QuerySet") +DEFAULT_OFFSET: int = 0 +DEFAULT_LIMIT: int = -1 @strawberry.input class OffsetPaginationInput: - offset: int = 0 - limit: int = -1 + offset: int = DEFAULT_OFFSET + limit: int = DEFAULT_LIMIT + + +@strawberry.type +class Paginated(Generic[NodeType]): + queryset: strawberry.Private[QuerySet] + pagination: strawberry.Private[OffsetPaginationInput] + + @strawberry.field + def limit(self) -> int: + return self.pagination.limit + + @strawberry.field + def offset(self) -> int: + return self.pagination.limit + + @strawberry.field(description="Total count of existing results.") + @django_resolver + def total_count(self) -> int: + return get_total_count(self.queryset) + + @strawberry.field(description="List of paginated results.") + @django_resolver + def results(self) -> list[NodeType]: + from strawberry_django.optimizer import is_optimized_by_prefetching + + if is_optimized_by_prefetching(self.queryset): + results = self.queryset._result_cache # type: ignore + else: + results = apply(self.pagination, self.queryset) + + return cast(list[NodeType], results) def apply( @@ -59,8 +94,11 @@ def apply( ) else: start = pagination.offset - stop = start + pagination.limit - queryset = queryset[start:stop] + if pagination.limit >= 0: + stop = start + pagination.limit + queryset = queryset[start:stop] + else: + queryset = queryset[start:] return queryset @@ -116,6 +154,32 @@ def apply_window_pagination( return queryset +def get_total_count(queryset: QuerySet) -> int: + """Get the total count of a queryset. + + Try to get the total count from the queryset cache, if it's optimized by + prefetching. Otherwise, fallback to the `QuerySet.count()` method. + """ + from strawberry_django.optimizer import is_optimized_by_prefetching + + if is_optimized_by_prefetching(queryset): + results = queryset._result_cache # type: ignore + + try: + return results[0]._strawberry_total_count if results else 0 + except AttributeError: + warnings.warn( + ( + "Pagination annotations not found, falling back to QuerySet resolution. " + "This might cause n+1 issues..." + ), + RuntimeWarning, + stacklevel=2, + ) + + return queryset.count() + + class StrawberryDjangoPagination(StrawberryDjangoFieldBase): def __init__(self, pagination: Union[bool, UnsetType] = UNSET, **kwargs): self.pagination = pagination @@ -126,10 +190,25 @@ def __copy__(self) -> Self: new_field.pagination = self.pagination return new_field + def _has_pagination(self) -> bool: + if isinstance(self.pagination, bool): + return self.pagination + + if self.is_paginated: + return True + + django_type = self.django_type + if django_type is not None and not issubclass( + django_type, strawberry.relay.Node + ): + return django_type.__strawberry_django_definition__.pagination + + return False + @property def arguments(self) -> list[StrawberryArgument]: arguments = [] - if self.base_resolver is None and self.is_list: + if self.base_resolver is None and (self.is_list or self.is_paginated): pagination = self.get_pagination() if pagination is not None: arguments.append( @@ -143,20 +222,7 @@ def arguments(self, value: list[StrawberryArgument]): return args_prop.fset(self, value) # type: ignore def get_pagination(self) -> Optional[type]: - has_pagination = self.pagination - - if isinstance(has_pagination, UnsetType): - django_type = self.django_type - has_pagination = ( - django_type.__strawberry_django_definition__.pagination - if ( - django_type is not None - and not issubclass(django_type, strawberry.relay.Node) - ) - else False - ) - - return OffsetPaginationInput if has_pagination else None + return OffsetPaginationInput if self._has_pagination() else None def apply_pagination( self, @@ -182,3 +248,28 @@ def get_queryset( pagination, related_field_id=_strawberry_related_field_id, ) + + def get_wrapped_result( + self, + result: _T, + info: Info, + *, + pagination: Optional[object] = None, + **kwargs, + ) -> Union[_T, Paginated[_T]]: + if not self.is_paginated: + return result + + if not isinstance(result, QuerySet): + raise TypeError(f"Result expected to be a queryset, got {result!r}") + + if ( + pagination not in (None, UNSET) # noqa: PLR6201 + and not isinstance(pagination, OffsetPaginationInput) + ): + raise TypeError(f"Don't know how to resolve pagination {pagination!r}") + + return Paginated( + queryset=result, + pagination=pagination or OffsetPaginationInput(), + ) diff --git a/strawberry_django/relay.py b/strawberry_django/relay.py index b8db964d..17750e74 100644 --- a/strawberry_django/relay.py +++ b/strawberry_django/relay.py @@ -22,6 +22,7 @@ from strawberry.utils.await_maybe import AwaitableOrValue from typing_extensions import Literal, Self +from strawberry_django.pagination import get_total_count from strawberry_django.queryset import run_type_get_queryset from strawberry_django.resolvers import django_getattr, django_resolver from strawberry_django.utils.typing import ( @@ -48,41 +49,12 @@ class ListConnectionWithTotalCount(relay.ListConnection[relay.NodeType]): @strawberry.field(description="Total quantity of existing nodes.") @django_resolver def total_count(self) -> Optional[int]: - from .optimizer import is_optimized_by_prefetching - assert self.nodes is not None - if isinstance(self.nodes, models.QuerySet) and is_optimized_by_prefetching( - self.nodes - ): - result = cast(list[relay.NodeType], self.nodes._result_cache) # type: ignore - try: - return ( - result[0]._strawberry_total_count # type: ignore - if result - else 0 - ) - except AttributeError: - warnings.warn( - ( - "Pagination annotations not found, falling back to QuerySet resolution. " - "This might cause n+1 issues..." - ), - RuntimeWarning, - stacklevel=2, - ) + if isinstance(self.nodes, models.QuerySet): + return get_total_count(self.nodes) - total_count = None - try: - total_count = cast( - "models.QuerySet[models.Model]", - self.nodes, - ).count() - except (AttributeError, ValueError, TypeError): - if isinstance(self.nodes, Sized): - total_count = len(self.nodes) - - return total_count + return len(self.nodes) if isinstance(self.nodes, Sized) else None @classmethod def resolve_connection( diff --git a/strawberry_django/type.py b/strawberry_django/type.py index 3ca5c2a4..8d234a7e 100644 --- a/strawberry_django/type.py +++ b/strawberry_django/type.py @@ -55,12 +55,12 @@ from .fields.types import get_model_field, resolve_model_field_name from .settings import strawberry_django_settings as django_settings -__all = [ - "StrawberryDjangoType", - "type", - "interface", +__all__ = [ + "StrawberryDjangoDefinition", "input", + "interface", "partial", + "type", ] _T = TypeVar("_T", bound=type) diff --git a/tests/test_paginated_type.py b/tests/test_paginated_type.py new file mode 100644 index 00000000..08641485 --- /dev/null +++ b/tests/test_paginated_type.py @@ -0,0 +1,413 @@ +import textwrap + +import pytest +import strawberry + +import strawberry_django +from strawberry_django.pagination import Paginated +from tests import models + + +def test_pagination_schema(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry_django.type(models.Color) + class Color: + id: int + name: str + fruits: Paginated[Fruit] + + @strawberry.type + class Query: + fruits: Paginated[Fruit] = strawberry_django.field() + colors: Paginated[Color] = strawberry_django.field() + + schema = strawberry.Schema(query=Query) + + expected = '''\ + type Color { + id: Int! + name: String! + fruits(pagination: OffsetPaginationInput): FruitPaginated! + } + + type ColorPaginated { + limit: Int! + offset: Int! + + """Total count of existing results.""" + totalCount: Int! + + """List of paginated results.""" + results: [Color!]! + } + + type Fruit { + id: Int! + name: String! + } + + type FruitPaginated { + limit: Int! + offset: Int! + + """Total count of existing results.""" + totalCount: Int! + + """List of paginated results.""" + results: [Fruit!]! + } + + input OffsetPaginationInput { + offset: Int! = 0 + limit: Int! = -1 + } + + type Query { + fruits(pagination: OffsetPaginationInput): FruitPaginated! + colors(pagination: OffsetPaginationInput): ColorPaginated! + } + ''' + + assert textwrap.dedent(str(schema)) == textwrap.dedent(expected).strip() + + +@pytest.mark.django_db(transaction=True) +def test_pagination_query(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry.type + class Query: + fruits: Paginated[Fruit] = strawberry_django.field() + + models.Fruit.objects.create(name="Apple") + models.Fruit.objects.create(name="Banana") + models.Fruit.objects.create(name="Strawberry") + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetFruits ($pagination: OffsetPaginationInput) { + fruits (pagination: $pagination) { + totalCount + results { + name + } + } + } + """ + + res = schema.execute_sync(query) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [{"name": "Apple"}, {"name": "Banana"}, {"name": "Strawberry"}], + } + } + + res = schema.execute_sync(query, variable_values={"pagination": {"limit": 1}}) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [{"name": "Apple"}], + } + } + + res = schema.execute_sync( + query, variable_values={"pagination": {"limit": 1, "offset": 1}} + ) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [{"name": "Banana"}], + } + } + + +@pytest.mark.django_db(transaction=True) +async def test_pagination_query_async(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry.type + class Query: + fruits: Paginated[Fruit] = strawberry_django.field() + + await models.Fruit.objects.acreate(name="Apple") + await models.Fruit.objects.acreate(name="Banana") + await models.Fruit.objects.acreate(name="Strawberry") + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetFruits ($pagination: OffsetPaginationInput) { + fruits (pagination: $pagination) { + totalCount + results { + name + } + } + } + """ + + res = await schema.execute(query) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [{"name": "Apple"}, {"name": "Banana"}, {"name": "Strawberry"}], + } + } + + res = await schema.execute(query, variable_values={"pagination": {"limit": 1}}) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [{"name": "Apple"}], + } + } + + res = await schema.execute( + query, variable_values={"pagination": {"limit": 1, "offset": 1}} + ) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [{"name": "Banana"}], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_pagination_nested_query(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry_django.type(models.Color) + class Color: + id: int + name: str + fruits: Paginated[Fruit] + + @strawberry.type + class Query: + colors: Paginated[Color] = strawberry_django.field() + + red = models.Color.objects.create(name="Red") + yellow = models.Color.objects.create(name="Yellow") + + models.Fruit.objects.create(name="Apple", color=red) + models.Fruit.objects.create(name="Banana", color=yellow) + models.Fruit.objects.create(name="Strawberry", color=red) + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetColors ($pagination: OffsetPaginationInput) { + colors { + totalCount + results { + fruits (pagination: $pagination) { + totalCount + results { + name + } + } + } + } + } + """ + + res = schema.execute_sync(query) + assert res.errors is None + assert res.data == { + "colors": { + "totalCount": 2, + "results": [ + { + "fruits": { + "totalCount": 2, + "results": [{"name": "Apple"}, {"name": "Strawberry"}], + } + }, + { + "fruits": { + "totalCount": 1, + "results": [{"name": "Banana"}], + } + }, + ], + } + } + + res = schema.execute_sync(query, variable_values={"pagination": {"limit": 1}}) + assert res.errors is None + assert res.data == { + "colors": { + "totalCount": 2, + "results": [ + { + "fruits": { + "totalCount": 2, + "results": [{"name": "Apple"}], + } + }, + { + "fruits": { + "totalCount": 1, + "results": [{"name": "Banana"}], + } + }, + ], + } + } + + res = schema.execute_sync( + query, variable_values={"pagination": {"limit": 1, "offset": 1}} + ) + assert res.errors is None + assert res.data == { + "colors": { + "totalCount": 2, + "results": [ + { + "fruits": { + "totalCount": 2, + "results": [{"name": "Strawberry"}], + } + }, + { + "fruits": { + "totalCount": 1, + "results": [], + } + }, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +async def test_pagination_nested_query_async(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry_django.type(models.Color) + class Color: + id: int + name: str + fruits: Paginated[Fruit] + + @strawberry.type + class Query: + colors: Paginated[Color] = strawberry_django.field() + + red = await models.Color.objects.acreate(name="Red") + yellow = await models.Color.objects.acreate(name="Yellow") + + await models.Fruit.objects.acreate(name="Apple", color=red) + await models.Fruit.objects.acreate(name="Banana", color=yellow) + await models.Fruit.objects.acreate(name="Strawberry", color=red) + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetColors ($pagination: OffsetPaginationInput) { + colors { + totalCount + results { + fruits (pagination: $pagination) { + totalCount + results { + name + } + } + } + } + } + """ + + res = await schema.execute(query) + assert res.errors is None + assert res.data == { + "colors": { + "totalCount": 2, + "results": [ + { + "fruits": { + "totalCount": 2, + "results": [{"name": "Apple"}, {"name": "Strawberry"}], + } + }, + { + "fruits": { + "totalCount": 1, + "results": [{"name": "Banana"}], + } + }, + ], + } + } + + res = await schema.execute(query, variable_values={"pagination": {"limit": 1}}) + assert res.errors is None + assert res.data == { + "colors": { + "totalCount": 2, + "results": [ + { + "fruits": { + "totalCount": 2, + "results": [{"name": "Apple"}], + } + }, + { + "fruits": { + "totalCount": 1, + "results": [{"name": "Banana"}], + } + }, + ], + } + } + + res = await schema.execute( + query, variable_values={"pagination": {"limit": 1, "offset": 1}} + ) + assert res.errors is None + assert res.data == { + "colors": { + "totalCount": 2, + "results": [ + { + "fruits": { + "totalCount": 2, + "results": [{"name": "Strawberry"}], + } + }, + { + "fruits": { + "totalCount": 1, + "results": [], + } + }, + ], + } + } From ed5019b932bdf0efb227e11acf940ebf113137a9 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Thu, 17 Oct 2024 00:09:35 -0300 Subject: [PATCH 02/10] Make optimizer work with Paginated results --- strawberry_django/fields/field.py | 4 +- strawberry_django/optimizer.py | 79 +++++++- strawberry_django/pagination.py | 51 +++-- strawberry_django/resolvers.py | 4 +- tests/projects/schema.py | 14 ++ tests/projects/snapshots/schema.gql | 15 ++ .../snapshots/schema_with_inheritance.gql | 13 ++ tests/test_optimizer.py | 180 ++++++++++++++++++ 8 files changed, 339 insertions(+), 21 deletions(-) diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index 7f9fe3df..c2d52b98 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -246,8 +246,8 @@ def resolve(): kwargs["info"] = info result = django_resolver( - partial(self.get_wrapped_result, **kwargs), - qs_hook=self.get_queryset_hook(**kwargs), + self.get_queryset_hook(**kwargs), + qs_hook=partial(self.get_wrapped_result, **kwargs), )(result) return result diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index cf3b6d86..99c761d8 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -49,7 +49,7 @@ from typing_extensions import assert_never, assert_type from strawberry_django.fields.types import resolve_model_field_name -from strawberry_django.pagination import apply_window_pagination +from strawberry_django.pagination import Paginated, apply_window_pagination from strawberry_django.queryset import get_queryset_config, run_type_get_queryset from strawberry_django.relay import ListConnectionWithTotalCount from strawberry_django.resolvers import django_fetch @@ -528,7 +528,7 @@ def _optimize_prefetch_queryset( ) field_kwargs.pop("info", None) - # Disable the optimizer to avoid doint double optimization while running get_queryset + # Disable the optimizer to avoid doing double optimization while running get_queryset with DjangoOptimizerExtension.disabled(): qs = field.get_queryset( qs, @@ -574,6 +574,15 @@ def _optimize_prefetch_queryset( else: mark_optimized = False + if isinstance(field.type, type) and issubclass(field.type, Paginated): + pagination = field_kwargs.get("pagination") + qs = apply_window_pagination( + qs, + related_field_id=related_field_id, + offset=pagination.offset if pagination else 0, + limit=pagination.limit if pagination else -1, + ) + if mark_optimized: qs = mark_optimized_by_prefetching(qs) @@ -977,8 +986,7 @@ def _get_model_hints( ) -> OptimizerStore | None: cache = cache or {} - # In case this is a relay field, find the selected edges/nodes, the selected fields - # are actually inside edges -> node selection... + # In case this is a relay field, the selected fields are inside edges -> node selection if issubclass(object_definition.origin, relay.Connection): return _get_model_hints_from_connection( model, @@ -992,6 +1000,20 @@ def _get_model_hints( level=level, ) + # In case this is a relay field, the selected fields are inside results selection + if issubclass(object_definition.origin, Paginated): + return _get_model_hints_from_paginated( + model, + schema, + object_definition, + parent_type=parent_type, + info=info, + config=config, + prefix=prefix, + cache=cache, + level=level, + ) + store = OptimizerStore() config = config or OptimizerConfig() @@ -1156,6 +1178,55 @@ def _get_model_hints_from_connection( return store +def _get_model_hints_from_paginated( + model: type[models.Model], + schema: Schema, + object_definition: StrawberryObjectDefinition, + *, + parent_type: GraphQLObjectType | GraphQLInterfaceType, + info: GraphQLResolveInfo, + config: OptimizerConfig | None = None, + prefix: str = "", + cache: dict[type[models.Model], list[tuple[int, OptimizerStore]]] | None = None, + level: int = 0, +) -> OptimizerStore | None: + store = None + + n_type = object_definition.type_var_map.get("NodeType") + n_definition = get_object_definition(n_type, strict=True) + n_gql_definition = _get_gql_definition( + schema, + get_object_definition(n_type, strict=True), + ) + assert isinstance(n_gql_definition, (GraphQLObjectType, GraphQLInterfaceType)) + + for selections in _get_selections(info, parent_type).values(): + selection = selections[0] + if selection.name.value != "results": + continue + + n_info = _generate_selection_resolve_info( + info, + selections, + n_gql_definition, + n_gql_definition, + ) + + store = _get_model_hints( + model=model, + schema=schema, + object_definition=n_definition, + parent_type=n_gql_definition, + info=n_info, + config=config, + prefix=prefix, + cache=cache, + level=level, + ) + + return store + + def optimize( qs: QuerySet[_M] | BaseManager[_M], info: GraphQLResolveInfo | Info, diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index 14bc53d2..99982f07 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -45,7 +45,7 @@ def offset(self) -> int: @strawberry.field(description="Total count of existing results.") @django_resolver - def total_count(self) -> int: + def total_count(self, root) -> int: return get_total_count(self.queryset) @strawberry.field(description="List of paginated results.") @@ -103,6 +103,10 @@ def apply( return queryset +class _PaginationWindow(Window): + """Marker to be able to remove where clause at `get_total_count` if needed.""" + + def apply_window_pagination( queryset: _QS, *, @@ -131,13 +135,14 @@ def apply_window_pagination( using=queryset._db or DEFAULT_DB_ALIAS # type: ignore ).get_order_by() ] + queryset = queryset.annotate( - _strawberry_row_number=Window( + _strawberry_row_number=_PaginationWindow( RowNumber(), partition_by=related_field_id, order_by=order_by, ), - _strawberry_total_count=Window( + _strawberry_total_count=_PaginationWindow( Count(1), partition_by=related_field_id, ), @@ -165,17 +170,31 @@ def get_total_count(queryset: QuerySet) -> int: if is_optimized_by_prefetching(queryset): results = queryset._result_cache # type: ignore - try: - return results[0]._strawberry_total_count if results else 0 - except AttributeError: - warnings.warn( - ( - "Pagination annotations not found, falling back to QuerySet resolution. " - "This might cause n+1 issues..." - ), - RuntimeWarning, - stacklevel=2, + if results: + try: + return results[0]._strawberry_total_count + except AttributeError: + warnings.warn( + ( + "Pagination annotations not found, falling back to QuerySet resolution. " + "This might cause n+1 issues..." + ), + RuntimeWarning, + stacklevel=2, + ) + + # If we have no results, we can't get the total count from the cache. + # In this case we will remove the pagination filter to be able to `.count()` + # the whole queryset with its original filters. + queryset = queryset._chain() # type: ignore + queryset.query.where.children = [ + child + for child in queryset.query.where.children + if ( + not hasattr(child, "lhs") + or not isinstance(child.lhs, _PaginationWindow) ) + ] return queryset.count() @@ -243,6 +262,12 @@ def get_queryset( **kwargs, ) -> _QS: queryset = super().get_queryset(queryset, info, **kwargs) + + # If this is `Paginated`, return the queryset as is as the pagination will + # be resolved when resolving its results. + if self.is_paginated: + return queryset + return self.apply_pagination( queryset, pagination, diff --git a/strawberry_django/resolvers.py b/strawberry_django/resolvers.py index c4e186b0..491b486e 100644 --- a/strawberry_django/resolvers.py +++ b/strawberry_django/resolvers.py @@ -193,7 +193,7 @@ def resolve_base_manager(manager: BaseManager) -> Any: # prevents us from importing and checking isinstance on them directly. try: # ManyRelatedManager - return list(prefetched_cache[manager.prefetch_cache_name]) # type: ignore + return prefetched_cache[manager.prefetch_cache_name] # type: ignore except (AttributeError, KeyError): try: # RelatedManager @@ -203,7 +203,7 @@ def resolve_base_manager(manager: BaseManager) -> Any: getattr(result_field.remote_field, "cache_name", None) or result_field.remote_field.get_cache_name() ) - return list(prefetched_cache[cache_name]) + return prefetched_cache[cache_name] except (AttributeError, KeyError): pass diff --git a/tests/projects/schema.py b/tests/projects/schema.py index bedd5934..3e8c1fa7 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -29,6 +29,7 @@ from strawberry_django.fields.types import ListInput, NodeInput, NodeInputPartial from strawberry_django.mutations import resolvers from strawberry_django.optimizer import DjangoOptimizerExtension +from strawberry_django.pagination import Paginated from strawberry_django.permissions import ( HasPerm, HasRetvalPerm, @@ -158,6 +159,10 @@ class MilestoneType(relay.Node, Named): order=IssueOrder, pagination=True, ) + issues_paginated: Paginated["IssueType"] = strawberry_django.field( + field_name="issues", + order=IssueOrder, + ) issues_with_filters: ListConnectionWithTotalCount["IssueType"] = ( strawberry_django.connection( field_name="issues", @@ -375,6 +380,7 @@ class Query: staff_list: list[Optional[StaffType]] = strawberry_django.node() issue_list: list[IssueType] = strawberry_django.field() + issues_paginated: Paginated[IssueType] = strawberry_django.field() milestone_list: list[MilestoneType] = strawberry_django.field( order=MilestoneOrder, filters=MilestoneFilter, @@ -429,6 +435,9 @@ class Query: issue_list_perm_required: list[IssueType] = strawberry_django.field( extensions=[HasPerm(perms=["projects.view_issue"])], ) + issue_paginated_list_perm_required: Paginated[IssueType] = strawberry_django.field( + extensions=[HasPerm(perms=["projects.view_issue"])], + ) issue_conn_perm_required: ListConnectionWithTotalCount[IssueType] = ( strawberry_django.connection( extensions=[HasPerm(perms=["projects.view_issue"])], @@ -447,6 +456,11 @@ class Query: issue_list_obj_perm_required_paginated: list[IssueType] = strawberry_django.field( extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True ) + issue_paginated_list_obj_perm_required_paginated: Paginated[IssueType] = ( + strawberry_django.field( + extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True + ) + ) issue_conn_obj_perm_required: ListConnectionWithTotalCount[IssueType] = ( strawberry_django.connection( extensions=[HasRetvalPerm(perms=["projects.view_issue"])], diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index 16533658..212a75ed 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -336,6 +336,17 @@ type IssueTypeEdge { node: IssueType! } +type IssueTypePaginated { + limit: Int! + offset: Int! + + """Total count of existing results.""" + totalCount: Int! + + """List of paginated results.""" + results: [IssueType!]! +} + input MilestoneFilter { name: StrFilterLookup project: DjangoModelFilterInput @@ -373,6 +384,7 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! + issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! issuesWithFilters( filters: IssueFilter @@ -610,6 +622,7 @@ type Query { ids: [GlobalID!]! ): [StaffType]! issueList: [IssueType!]! + issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! milestoneList(filters: MilestoneFilter, order: MilestoneOrder, pagination: OffsetPaginationInput): [MilestoneType!]! projectList(filters: ProjectFilter): [ProjectType!]! tagList: [TagType!]! @@ -729,6 +742,7 @@ type Query { id: GlobalID! ): IssueType @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListPermRequired: [IssueType!]! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) + issuePaginatedListPermRequired(pagination: OffsetPaginationInput): IssueTypePaginated! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueConnPermRequired( """Returns the items in the list that come before the specified cursor.""" before: String = null @@ -752,6 +766,7 @@ type Query { ): IssueType @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListObjPermRequired: [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListObjPermRequiredPaginated(pagination: OffsetPaginationInput): [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) + issuePaginatedListObjPermRequiredPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueConnObjPermRequired( """Returns the items in the list that come before the specified cursor.""" before: String = null diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index 8c891a65..2dcf1e1a 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -136,6 +136,17 @@ type IssueTypeEdge { node: IssueType! } +type IssueTypePaginated { + limit: Int! + offset: Int! + + """Total count of existing results.""" + totalCount: Int! + + """List of paginated results.""" + results: [IssueType!]! +} + input MilestoneFilter { name: StrFilterLookup project: DjangoModelFilterInput @@ -163,6 +174,7 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! + issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! issuesWithFilters( filters: IssueFilter @@ -190,6 +202,7 @@ type MilestoneTypeSubclass implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! + issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! issuesWithFilters( filters: IssueFilter diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index a2046170..5ed3d957 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1425,3 +1425,183 @@ class Query: "issues": [{"pk": str(issue1.pk)}], }, } + + +@pytest.mark.django_db(transaction=True) +def test_query_paginated(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery ($pagination: OffsetPaginationInput) { + issuesPaginated (pagination: $pagination) { + totalCount + results { + name + milestone { + name + } + } + } + } + """ + + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + issue1 = IssueFactory.create(milestone=milestone1) + issue2 = IssueFactory.create(milestone=milestone1) + issue3 = IssueFactory.create(milestone=milestone1) + issue4 = IssueFactory.create(milestone=milestone2) + issue5 = IssueFactory.create(milestone=milestone2) + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 7): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginated": { + "totalCount": 5, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + {"name": issue3.name, "milestone": {"name": milestone1.name}}, + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + {"name": issue5.name, "milestone": {"name": milestone2.name}}, + ], + } + } + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 4): + res = gql_client.query(query, variables={"pagination": {"limit": 2}}) + + assert res.data == { + "issuesPaginated": { + "totalCount": 5, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + ], + } + } + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 4): + res = gql_client.query( + query, variables={"pagination": {"limit": 2, "offset": 2}} + ) + + assert res.data == { + "issuesPaginated": { + "totalCount": 5, + "results": [ + {"name": issue3.name, "milestone": {"name": milestone1.name}}, + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_query_paginated_nested(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery ($pagination: OffsetPaginationInput) { + milestoneList { + name + issuesPaginated (pagination: $pagination) { + totalCount + results { + name + milestone { + name + } + } + } + } + } + """ + + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + issue1 = IssueFactory.create(milestone=milestone1) + issue2 = IssueFactory.create(milestone=milestone1) + issue3 = IssueFactory.create(milestone=milestone1) + issue4 = IssueFactory.create(milestone=milestone2) + issue5 = IssueFactory.create(milestone=milestone2) + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 5): + res = gql_client.query(query) + + assert res.data == { + "milestoneList": [ + { + "name": milestone1.name, + "issuesPaginated": { + "totalCount": 3, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + {"name": issue3.name, "milestone": {"name": milestone1.name}}, + ], + }, + }, + { + "name": milestone2.name, + "issuesPaginated": { + "totalCount": 2, + "results": [ + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + {"name": issue5.name, "milestone": {"name": milestone2.name}}, + ], + }, + }, + ] + } + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 5): + res = gql_client.query(query, variables={"pagination": {"limit": 1}}) + + assert res.data == { + "milestoneList": [ + { + "name": milestone1.name, + "issuesPaginated": { + "totalCount": 3, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + ], + }, + }, + { + "name": milestone2.name, + "issuesPaginated": { + "totalCount": 2, + "results": [ + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + ], + }, + }, + ] + } + + with assert_num_queries(3 if DjangoOptimizerExtension.enabled.get() else 5): + res = gql_client.query( + query, variables={"pagination": {"limit": 1, "offset": 2}} + ) + + assert res.data == { + "milestoneList": [ + { + "name": milestone1.name, + "issuesPaginated": { + "totalCount": 3, + "results": [ + {"name": issue3.name, "milestone": {"name": milestone1.name}}, + ], + }, + }, + { + "name": milestone2.name, + "issuesPaginated": { + "totalCount": 2, + "results": [], + }, + }, + ] + } From eda46e7a1d21be98b3701b15488b32845e83beca Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Thu, 17 Oct 2024 10:34:41 -0300 Subject: [PATCH 03/10] Make permissions work with Paginated results --- strawberry_django/pagination.py | 8 +- strawberry_django/permissions.py | 4 + tests/projects/schema.py | 8 +- tests/projects/snapshots/schema.gql | 4 +- tests/test_permissions.py | 167 +++++++++++++++++++++++++++- 5 files changed, 182 insertions(+), 9 deletions(-) diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index 99982f07..34145afa 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -32,7 +32,7 @@ class OffsetPaginationInput: @strawberry.type class Paginated(Generic[NodeType]): - queryset: strawberry.Private[QuerySet] + queryset: strawberry.Private[Optional[QuerySet]] pagination: strawberry.Private[OffsetPaginationInput] @strawberry.field @@ -46,6 +46,9 @@ def offset(self) -> int: @strawberry.field(description="Total count of existing results.") @django_resolver def total_count(self, root) -> int: + if self.queryset is None: + return 0 + return get_total_count(self.queryset) @strawberry.field(description="List of paginated results.") @@ -53,6 +56,9 @@ def total_count(self, root) -> int: def results(self) -> list[NodeType]: from strawberry_django.optimizer import is_optimized_by_prefetching + if self.queryset is None: + return [] + if is_optimized_by_prefetching(self.queryset): results = self.queryset._result_cache # type: ignore else: diff --git a/strawberry_django/permissions.py b/strawberry_django/permissions.py index dab1995a..71ee7eb3 100644 --- a/strawberry_django/permissions.py +++ b/strawberry_django/permissions.py @@ -39,6 +39,7 @@ from strawberry_django.auth.utils import aget_current_user, get_current_user from strawberry_django.fields.types import OperationInfo, OperationMessage +from strawberry_django.pagination import OffsetPaginationInput, Paginated from strawberry_django.resolvers import django_resolver from .utils.query import filter_for_user @@ -405,6 +406,9 @@ def handle_no_permission(self, exception: BaseException, *, info: Info): if isinstance(ret_type, StrawberryList): return [] + if isinstance(ret_type, type) and issubclass(ret_type, Paginated): + return Paginated(queryset=None, pagination=OffsetPaginationInput()) + # If it is a Connection, try to return an empty connection, but only if # it is the only possibility available... for ret_possibility in ret_types: diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 3e8c1fa7..f9814c33 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -435,7 +435,7 @@ class Query: issue_list_perm_required: list[IssueType] = strawberry_django.field( extensions=[HasPerm(perms=["projects.view_issue"])], ) - issue_paginated_list_perm_required: Paginated[IssueType] = strawberry_django.field( + issues_paginated_perm_required: Paginated[IssueType] = strawberry_django.field( extensions=[HasPerm(perms=["projects.view_issue"])], ) issue_conn_perm_required: ListConnectionWithTotalCount[IssueType] = ( @@ -456,10 +456,8 @@ class Query: issue_list_obj_perm_required_paginated: list[IssueType] = strawberry_django.field( extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True ) - issue_paginated_list_obj_perm_required_paginated: Paginated[IssueType] = ( - strawberry_django.field( - extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True - ) + issues_paginated_obj_perm_required: Paginated[IssueType] = strawberry_django.field( + extensions=[HasRetvalPerm(perms=["projects.view_issue"])], ) issue_conn_obj_perm_required: ListConnectionWithTotalCount[IssueType] = ( strawberry_django.connection( diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index 212a75ed..2fc46b1c 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -742,7 +742,7 @@ type Query { id: GlobalID! ): IssueType @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListPermRequired: [IssueType!]! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) - issuePaginatedListPermRequired(pagination: OffsetPaginationInput): IssueTypePaginated! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) + issuesPaginatedPermRequired(pagination: OffsetPaginationInput): IssueTypePaginated! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueConnPermRequired( """Returns the items in the list that come before the specified cursor.""" before: String = null @@ -766,7 +766,7 @@ type Query { ): IssueType @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListObjPermRequired: [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListObjPermRequiredPaginated(pagination: OffsetPaginationInput): [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) - issuePaginatedListObjPermRequiredPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) + issuesPaginatedObjPermRequired(pagination: OffsetPaginationInput): IssueTypePaginated! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueConnObjPermRequired( """Returns the items in the list that come before the specified cursor.""" before: String = null diff --git a/tests/test_permissions.py b/tests/test_permissions.py index c1c492ee..c6e79318 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -9,11 +9,12 @@ from .projects.faker import ( GroupFactory, IssueFactory, + MilestoneFactory, StaffUserFactory, SuperuserUserFactory, UserFactory, ) -from .utils import GraphQLTestClient +from .utils import GraphQLTestClient, assert_num_queries PermKind: TypeAlias = Literal["user", "group", "superuser"] perm_kinds: list[PermKind] = ["user", "group", "superuser"] @@ -934,3 +935,167 @@ def test_conn_obj_perm_required(db, gql_client: GraphQLTestClient, kind: PermKin "totalCount": 1, }, } + + +@pytest.mark.django_db(transaction=True) +def test_query_paginated_with_permissions(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery ($pagination: OffsetPaginationInput) { + issuesPaginatedPermRequired (pagination: $pagination) { + totalCount + results { + name + milestone { + name + } + } + } + } + """ + + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + issue1 = IssueFactory.create(milestone=milestone1) + issue2 = IssueFactory.create(milestone=milestone1) + issue3 = IssueFactory.create(milestone=milestone1) + issue4 = IssueFactory.create(milestone=milestone2) + issue5 = IssueFactory.create(milestone=milestone2) + + # No user logged in + with assert_num_queries(0): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedPermRequired": { + "totalCount": 0, + "results": [], + } + } + + user = UserFactory.create() + + # User logged in without permissions + with gql_client.login(user): + with assert_num_queries(4): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedPermRequired": { + "totalCount": 0, + "results": [], + } + } + + # User logged in with permissions + user.user_permissions.add(Permission.objects.get(codename="view_issue")) + with gql_client.login(user): + with assert_num_queries(6 if DjangoOptimizerExtension.enabled.get() else 11): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedPermRequired": { + "totalCount": 5, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + {"name": issue3.name, "milestone": {"name": milestone1.name}}, + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + {"name": issue5.name, "milestone": {"name": milestone2.name}}, + ], + } + } + + with assert_num_queries(6 if DjangoOptimizerExtension.enabled.get() else 8): + res = gql_client.query(query, variables={"pagination": {"limit": 2}}) + + assert res.data == { + "issuesPaginatedPermRequired": { + "totalCount": 5, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_query_paginated_with_obj_permissions(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery ($pagination: OffsetPaginationInput) { + issuesPaginatedObjPermRequired (pagination: $pagination) { + totalCount + results { + name + milestone { + name + } + } + } + } + """ + + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + IssueFactory.create(milestone=milestone1) + issue2 = IssueFactory.create(milestone=milestone1) + IssueFactory.create(milestone=milestone1) + issue4 = IssueFactory.create(milestone=milestone2) + IssueFactory.create(milestone=milestone2) + + # No user logged in + with assert_num_queries(0): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedObjPermRequired": { + "totalCount": 0, + "results": [], + } + } + + user = UserFactory.create() + + # User logged in without permissions + with gql_client.login(user): + with assert_num_queries(5): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedObjPermRequired": { + "totalCount": 0, + "results": [], + } + } + + assign_perm("view_issue", user, issue2) + assign_perm("view_issue", user, issue4) + + # User logged in with permissions + with gql_client.login(user): + with assert_num_queries(4 if DjangoOptimizerExtension.enabled.get() else 6): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedObjPermRequired": { + "totalCount": 2, + "results": [ + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + ], + } + } + + with assert_num_queries(4 if DjangoOptimizerExtension.enabled.get() else 5): + res = gql_client.query(query, variables={"pagination": {"limit": 1}}) + + assert res.data == { + "issuesPaginatedObjPermRequired": { + "totalCount": 2, + "results": [ + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + ], + } + } From 4c9c06849c1f0057fb9156f3ae5bbc0d8c903324 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sat, 19 Oct 2024 10:48:58 -0300 Subject: [PATCH 04/10] Document the new `Paginated` generic --- docs/guide/pagination.md | 180 +++++++++++++++++- strawberry_django/pagination.py | 53 ++++-- tests/projects/snapshots/schema.gql | 8 +- .../snapshots/schema_with_inheritance.gql | 8 +- tests/test_paginated_type.py | 13 +- 5 files changed, 230 insertions(+), 32 deletions(-) diff --git a/docs/guide/pagination.md b/docs/guide/pagination.md index 339c199a..7851d7a7 100644 --- a/docs/guide/pagination.md +++ b/docs/guide/pagination.md @@ -12,20 +12,190 @@ An interface for limit/offset pagination can be use for basic pagination needs: @strawberry_django.type(models.Fruit, pagination=True) class Fruit: name: auto + + +@strawberry.type +class Query: + fruits: list[Fruit] = strawberry_django.field() +``` + +Would produce the following schema: + +```graphql title="schema.graphql" +type Fruit { + name: String! +} + +type Query { + fruits(pagination: PaginationInput): [Fruit!]! +} ``` +And can be queried like: + ```graphql title="schema.graphql" query { fruits(pagination: { offset: 0, limit: 2 }) { name - color } } ``` -There is not default limit defined. All elements are returned if no pagination limit is defined. +The `pagination` argument can be given to the type, which will enforce the pagination +argument every time the field is annotated as a list, but you can also give it directly +to the field for more control, like: + +```python title="types.py" +@strawberry_django.type(models.Fruit) +class Fruit: + name: auto + + +@strawberry.type +class Query: + fruits: list[Fruit] = strawberry_django.field(pagination=True) +``` + +Which will produce the exact same schema. + +> [!NOTE] +> There is no default limit defined. All elements are returned if no pagination limit is defined. + +## Paginated Generic + +For more complex pagination needs, you can use the `Paginated` generic, which alongside +the `pagination` argument, will wrap the results in an object that contains the results +and the pagination information, together with the `totalCount` of elements excluding pagination. + +```python title="types.py" +from strawberry_django.pagination import Paginated + + +@strawberry_django.type(models.Fruit) +class Fruit: + name: auto + + +@strawberry.type +class Query: + fruits: Paginated[Fruit] = strawberry_django.field() +``` + +Would produce the following schema: + +```graphql title="schema.graphql" +type Fruit { + name: String! +} + +type PaginatedInfo { + limit: Int! + offset: Int! +} + +type FruitPaginated { + pageInfo: PaginatedInfo! + totalCount: Int! + results: [Fruit]! +} + +type Query { + fruits(pagination: PaginationInput): [FruitPaginated!]! +} +``` + +Which can be queried like: + +```graphql title="schema.graphql" +query { + fruits(pagination: { offset: 0, limit: 2 }) { + totalCount + pageInfo { + limit + offset + } + results { + name + } + } +} +``` + +### Customizing the pagination + +Like other generics, `Paginated` can be customized to modify its behavior or to +add extra functionality in it. For example, suppose we want to add the average +price of the fruits in the pagination: + +```python title="types.py" +from strawberry_django.pagination import Paginated + + +@strawberry_django.type(models.Fruit) +class Fruit: + name: auto + price: auto + + +@strawberry.type +class FruitPaginated(Paginated[Fruit]): + @strawberry_django.field + def average_price(self) -> Decimal: + if self.queryset is None: + return Decimal(0) + + return self.queryset.aggregate(Avg("price"))["price__avg"] + + @strawberry_django.field + def paginated_average_price(self) -> Decimal: + paginated_queryset = self.get_paginated_queryset() + if paginated_queryset is None: + return Decimal(0) + + return paginated_queryset.aggregate(Avg("price"))["price__avg"] + + +@strawberry.type +class Query: + fruits: FruitPaginated = strawberry_django.field() +``` + +Would produce the following schema: + +```graphql title="schema.graphql" +type Fruit { + name: String! +} + +type PaginatedInfo { + limit: Int! + offset: Int! +} + +type FruitPaginated { + pageInfo: PaginatedInfo! + totalCount: Int! + results: [Fruit]! + averagePrice: Decimal! + paginatedAveragePrice: Decimal! +} + +type Query { + fruits(pagination: PaginationInput): [FruitPaginated!]! +} +``` + +The following attributes/methods can be accessed in the `Paginated` class: + +- `queryset`: The queryset original queryset with any filters/ordering applied, + but not paginated yet +- `pagination`: The `OffsetPaginationInput` object, with the `offset` and `limit` for pagination +- `get_total_count()`: Returns the total count of elements in the queryset without pagination +- `get_paginated_queryset()`: Returns the queryset with pagination applied -## Relay pagination +## Cursor pagination (aka Relay style pagination) -For more complex scenarios, a cursor pagination would be better. For this, -use the [relay integration](./relay.md) to define those. +Another option for pagination is to use a +[relay style cursor pagination](https://graphql.org/learn/pagination). For this, +you can leverage the [relay integration](./relay.md) provided by strawberry +to create a relay connection. diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index 34145afa..2a8e9961 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -30,41 +30,58 @@ class OffsetPaginationInput: limit: int = DEFAULT_LIMIT +@strawberry.type +class PaginatedInfo: + limit: int + offset: int + + @strawberry.type class Paginated(Generic[NodeType]): queryset: strawberry.Private[Optional[QuerySet]] pagination: strawberry.Private[OffsetPaginationInput] @strawberry.field - def limit(self) -> int: - return self.pagination.limit - - @strawberry.field - def offset(self) -> int: - return self.pagination.limit + def page_info(self) -> PaginatedInfo: + return PaginatedInfo( + limit=self.pagination.limit, + offset=self.pagination.offset, + ) @strawberry.field(description="Total count of existing results.") @django_resolver - def total_count(self, root) -> int: - if self.queryset is None: - return 0 - - return get_total_count(self.queryset) + def total_count(self) -> int: + return self.get_total_count() @strawberry.field(description="List of paginated results.") @django_resolver def results(self) -> list[NodeType]: + paginated_queryset = self.get_paginated_queryset() + + return cast( + list[NodeType], paginated_queryset if paginated_queryset is not None else [] + ) + + def get_total_count(self) -> int: + """Retrieve tht total count of the queryset without pagination.""" + return get_total_count(self.queryset) if self.queryset is not None else 0 + + def get_paginated_queryset(self) -> Optional[QuerySet]: + """Retrieve the queryset with pagination applied. + + This will apply the paginated arguments to the queryset and return it. + To use the original queryset, access `.queryset` directly. + """ from strawberry_django.optimizer import is_optimized_by_prefetching if self.queryset is None: - return [] - - if is_optimized_by_prefetching(self.queryset): - results = self.queryset._result_cache # type: ignore - else: - results = apply(self.pagination, self.queryset) + return None - return cast(list[NodeType], results) + return ( + self.queryset._result_cache # type: ignore + if is_optimized_by_prefetching(self.queryset) + else apply(self.pagination, self.queryset) + ) def apply( diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index 2fc46b1c..e6c8f5d6 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -337,8 +337,7 @@ type IssueTypeEdge { } type IssueTypePaginated { - limit: Int! - offset: Int! + pageInfo: PaginatedInfo! """Total count of existing results.""" totalCount: Int! @@ -522,6 +521,11 @@ type PageInfo { endCursor: String } +type PaginatedInfo { + limit: Int! + offset: Int! +} + type ProjectConnection { """Pagination data for this connection""" pageInfo: PageInfo! diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index 2dcf1e1a..a0b0eca9 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -137,8 +137,7 @@ type IssueTypeEdge { } type IssueTypePaginated { - limit: Int! - offset: Int! + pageInfo: PaginatedInfo! """Total count of existing results.""" totalCount: Int! @@ -300,6 +299,11 @@ type PageInfo { endCursor: String } +type PaginatedInfo { + limit: Int! + offset: Int! +} + input ProjectOrder { id: Ordering name: Ordering diff --git a/tests/test_paginated_type.py b/tests/test_paginated_type.py index 08641485..c88d494f 100644 --- a/tests/test_paginated_type.py +++ b/tests/test_paginated_type.py @@ -8,7 +8,7 @@ from tests import models -def test_pagination_schema(): +def test_paginated_schema(): @strawberry_django.type(models.Fruit) class Fruit: id: int @@ -35,8 +35,7 @@ class Query: } type ColorPaginated { - limit: Int! - offset: Int! + pageInfo: PaginatedInfo! """Total count of existing results.""" totalCount: Int! @@ -51,8 +50,7 @@ class Query: } type FruitPaginated { - limit: Int! - offset: Int! + pageInfo: PaginatedInfo! """Total count of existing results.""" totalCount: Int! @@ -66,6 +64,11 @@ class Query: limit: Int! = -1 } + type PaginatedInfo { + limit: Int! + offset: Int! + } + type Query { fruits(pagination: OffsetPaginationInput): FruitPaginated! colors(pagination: OffsetPaginationInput): ColorPaginated! From 816e7d95265097ed7491f1ca3d757fb19d6a34bc Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sat, 19 Oct 2024 12:59:47 -0300 Subject: [PATCH 05/10] Codereview improvements and fixes --- docs/guide/pagination.md | 21 +++++++++++++--- strawberry_django/optimizer.py | 2 +- strawberry_django/pagination.py | 43 ++++++++++++++++++++++++--------- 3 files changed, 50 insertions(+), 16 deletions(-) diff --git a/docs/guide/pagination.md b/docs/guide/pagination.md index 7851d7a7..c5944f64 100644 --- a/docs/guide/pagination.md +++ b/docs/guide/pagination.md @@ -26,8 +26,13 @@ type Fruit { name: String! } +input OffsetPaginationInput { + offset: Int! = 0 + limit: Int! = -1 +} + type Query { - fruits(pagination: PaginationInput): [Fruit!]! + fruits(pagination: OffsetPaginationInput): [Fruit!]! } ``` @@ -99,8 +104,13 @@ type FruitPaginated { results: [Fruit]! } +input OffsetPaginationInput { + offset: Int! = 0 + limit: Int! = -1 +} + type Query { - fruits(pagination: PaginationInput): [FruitPaginated!]! + fruits(pagination: OffsetPaginationInput): [FruitPaginated!]! } ``` @@ -180,8 +190,13 @@ type FruitPaginated { paginatedAveragePrice: Decimal! } +input OffsetPaginationInput { + offset: Int! = 0 + limit: Int! = -1 +} + type Query { - fruits(pagination: PaginationInput): [FruitPaginated!]! + fruits(pagination: OffsetPaginationInput): [FruitPaginated!]! } ``` diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index 99c761d8..d5ec9a54 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -1000,7 +1000,7 @@ def _get_model_hints( level=level, ) - # In case this is a relay field, the selected fields are inside results selection + # In case this is a Paginated field, the selected fields are inside results selection if issubclass(object_definition.origin, Paginated): return _get_model_hints_from_paginated( model, diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index 2a8e9961..dabb3e4d 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -127,7 +127,12 @@ def apply( class _PaginationWindow(Window): - """Marker to be able to remove where clause at `get_total_count` if needed.""" + """Window function to be used for pagination. + + This is the same as django's `Window` function, but we can easily identify + it in case we need to remove it from the queryset, as there might be other + window functions in the queryset and no other way to identify ours. + """ def apply_window_pagination( @@ -182,6 +187,26 @@ def apply_window_pagination( return queryset +def remove_window_pagination(queryset: _QS) -> _QS: + """Remove pagination window functions from a queryset. + + Utility function to remove the pagination `WHERE` clause added by + the `apply_window_pagination` function. + + Args: + ---- + queryset: The queryset to apply pagination to. + + """ + queryset = queryset._chain() # type: ignore + queryset.query.where.children = [ + child + for child in queryset.query.where.children + if (not hasattr(child, "lhs") or not isinstance(child.lhs, _PaginationWindow)) + ] + return queryset + + def get_total_count(queryset: QuerySet) -> int: """Get the total count of a queryset. @@ -209,15 +234,7 @@ def get_total_count(queryset: QuerySet) -> int: # If we have no results, we can't get the total count from the cache. # In this case we will remove the pagination filter to be able to `.count()` # the whole queryset with its original filters. - queryset = queryset._chain() # type: ignore - queryset.query.where.children = [ - child - for child in queryset.query.where.children - if ( - not hasattr(child, "lhs") - or not isinstance(child.lhs, _PaginationWindow) - ) - ] + queryset = remove_window_pagination(queryset) return queryset.count() @@ -286,8 +303,10 @@ def get_queryset( ) -> _QS: queryset = super().get_queryset(queryset, info, **kwargs) - # If this is `Paginated`, return the queryset as is as the pagination will - # be resolved when resolving its results. + # This is counter intuitive, but in case we are returning a `Paginated` + # result, we want to set the original queryset _as is_ as it will apply + # the pagination later on when resolving its `.results` field. + # Check `get_wrapped_result` below for more details. if self.is_paginated: return queryset From 18259348947d4565906a7eb806dca1d91455c3d1 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sun, 20 Oct 2024 12:23:38 -0300 Subject: [PATCH 06/10] Default pagination limit --- docs/guide/pagination.md | 32 +++++++++++++++++-- docs/guide/settings.md | 7 +++- strawberry_django/fields/field.py | 2 +- strawberry_django/pagination.py | 17 ++++------ strawberry_django/settings.py | 7 +++- tests/projects/snapshots/schema.gql | 4 +-- .../snapshots/schema_with_inheritance.gql | 4 +-- tests/test_paginated_type.py | 4 +-- tests/test_settings.py | 2 ++ 9 files changed, 58 insertions(+), 21 deletions(-) diff --git a/docs/guide/pagination.md b/docs/guide/pagination.md index c5944f64..5ff491dd 100644 --- a/docs/guide/pagination.md +++ b/docs/guide/pagination.md @@ -63,8 +63,32 @@ class Query: Which will produce the exact same schema. -> [!NOTE] -> There is no default limit defined. All elements are returned if no pagination limit is defined. +### Default limit for pagination + +The default limit for pagination is set to `100`. This can be changed in the +[strawberry django settings](./settings.md) to increase or decrease that number, +or even set to `None` to set it to unlimited. + +To configure it on a per field basis, you can define your own `OffsetPaginationInput` +subclass and modify its default value, like: + +```python +@strawberry.input +def MyOffsetPaginationInput(OffsetPaginationInput): + limit: int = 250 + + +# Pass it to the pagination argument when defining the type +@strawberry_django.type(models.Fruit, pagination=MyOffsetPaginationInput) +class Fruit: + ... + + +@strawberry.type +class Query: + # Or pass it to the pagination argument when defining the field + fruits: list[Fruit] = strawberry_django.field(pagination=MyOffsetPaginationInput) +``` ## Paginated Generic @@ -131,6 +155,10 @@ query { } ``` +> [!NOTE] +> Paginated follow the same rules for the default pagination limit, and can be configured +> in the same way as explained above. + ### Customizing the pagination Like other generics, `Paginated` can be customized to modify its behavior or to diff --git a/docs/guide/settings.md b/docs/guide/settings.md index d8e41973..afaf8a75 100644 --- a/docs/guide/settings.md +++ b/docs/guide/settings.md @@ -62,7 +62,11 @@ A dictionary with the following optional keys: If True, [legacy filters](filters.md#legacy-filtering) are enabled. This is usefull for migrating from previous version. -These features can be enabled by adding this code to your `settings.py` file. +- **`PAGINATION_DEFAULT_LIMIT`** (default: `100`) + + Defualt limit for [pagination](pagination.md) when one is not provided by the client. Can be set to `None` to set it to unlimited. + +These features can be enabled by adding this code to your `settings.py` file, like: ```python title="settings.py" STRAWBERRY_DJANGO = { @@ -73,5 +77,6 @@ STRAWBERRY_DJANGO = { "GENERATE_ENUMS_FROM_CHOICES": False, "MAP_AUTO_ID_AS_GLOBAL_ID": True, "DEFAULT_PK_FIELD_NAME": "id", + "PAGINATION_DEFAULT_LIMIT": 250, } ``` diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index 0cd4a3c6..864d36fa 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -229,7 +229,7 @@ async def async_resolver(): kwargs["info"] = info @sync_to_async - def resolve(): + def resolve(resolved=resolved): inner_resolved = self.get_queryset_hook(**kwargs)(resolved) return self.get_wrapped_result(inner_resolved, **kwargs) diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index dabb3e4d..98a4db51 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -20,20 +20,17 @@ _T = TypeVar("_T") _QS = TypeVar("_QS", bound=QuerySet) -DEFAULT_OFFSET: int = 0 -DEFAULT_LIMIT: int = -1 - @strawberry.input class OffsetPaginationInput: - offset: int = DEFAULT_OFFSET - limit: int = DEFAULT_LIMIT + offset: int = 0 + limit: Optional[int] = None @strawberry.type class PaginatedInfo: - limit: int - offset: int + offset: int = 0 + limit: Optional[int] = None @strawberry.type @@ -117,7 +114,7 @@ def apply( ) else: start = pagination.offset - if pagination.limit >= 0: + if pagination.limit is not None and pagination.limit >= 0: stop = start + pagination.limit queryset = queryset[start:stop] else: @@ -140,7 +137,7 @@ def apply_window_pagination( *, related_field_id: str, offset: int = 0, - limit: int = -1, + limit: Optional[int] = None, ) -> _QS: """Apply pagination using window functions. @@ -181,7 +178,7 @@ def apply_window_pagination( # Limit == -1 means no limit. sys.maxsize is set by relay when paginating # from the end to as a way to mimic a "not limit" as well - if limit >= 0 and limit != sys.maxsize: + if limit is not None and limit >= 0 and limit != sys.maxsize: queryset = queryset.filter(_strawberry_row_number__lte=offset + limit) return queryset diff --git a/strawberry_django/settings.py b/strawberry_django/settings.py index e0a0606c..72a4099d 100644 --- a/strawberry_django/settings.py +++ b/strawberry_django/settings.py @@ -1,6 +1,6 @@ """Code for interacting with Django settings.""" -from typing import cast +from typing import Optional, cast from django.conf import settings from typing_extensions import TypedDict @@ -42,6 +42,10 @@ class StrawberryDjangoSettings(TypedDict): #: If True, deprecated way of using filters will be working USE_DEPRECATED_FILTERS: bool + #: The default limit for pagination when not provided. Can be set to `None` + #: to set it to unlimited. + PAGINATION_DEFAULT_LIMIT: Optional[int] + DEFAULT_DJANGO_SETTINGS = StrawberryDjangoSettings( FIELD_DESCRIPTION_FROM_HELP_TEXT=False, @@ -52,6 +56,7 @@ class StrawberryDjangoSettings(TypedDict): MAP_AUTO_ID_AS_GLOBAL_ID=False, DEFAULT_PK_FIELD_NAME="pk", USE_DEPRECATED_FILTERS=False, + PAGINATION_DEFAULT_LIMIT=100, ) diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index e6c8f5d6..a127814c 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -465,7 +465,7 @@ input NodeInputPartial { input OffsetPaginationInput { offset: Int! = 0 - limit: Int! = -1 + limit: Int = null } type OperationInfo { @@ -522,8 +522,8 @@ type PageInfo { } type PaginatedInfo { - limit: Int! offset: Int! + limit: Int } type ProjectConnection { diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index a0b0eca9..49f8aebf 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -243,7 +243,7 @@ input NodeInput { input OffsetPaginationInput { offset: Int! = 0 - limit: Int! = -1 + limit: Int = null } type OperationInfo { @@ -300,8 +300,8 @@ type PageInfo { } type PaginatedInfo { - limit: Int! offset: Int! + limit: Int } input ProjectOrder { diff --git a/tests/test_paginated_type.py b/tests/test_paginated_type.py index c88d494f..760fc727 100644 --- a/tests/test_paginated_type.py +++ b/tests/test_paginated_type.py @@ -61,12 +61,12 @@ class Query: input OffsetPaginationInput { offset: Int! = 0 - limit: Int! = -1 + limit: Int = null } type PaginatedInfo { - limit: Int! offset: Int! + limit: Int } type Query { diff --git a/tests/test_settings.py b/tests/test_settings.py index 32f572a7..3c7cf1d6 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -30,6 +30,7 @@ def test_non_defaults(): MAP_AUTO_ID_AS_GLOBAL_ID=True, DEFAULT_PK_FIELD_NAME="id", USE_DEPRECATED_FILTERS=True, + PAGINATION_DEFAULT_LIMIT=250, ), ): assert ( @@ -43,5 +44,6 @@ def test_non_defaults(): MAP_AUTO_ID_AS_GLOBAL_ID=True, DEFAULT_PK_FIELD_NAME="id", USE_DEPRECATED_FILTERS=True, + PAGINATION_DEFAULT_LIMIT=250, ) ) From a79eac064a9e95a0e0ef169010ee06d354588fc2 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sun, 20 Oct 2024 12:30:13 -0300 Subject: [PATCH 07/10] Rename Paginated to OffsetPaginated --- docs/guide/pagination.md | 36 +++++++-------- strawberry_django/fields/base.py | 8 ++-- strawberry_django/optimizer.py | 6 +-- strawberry_django/pagination.py | 12 ++--- strawberry_django/permissions.py | 6 +-- tests/projects/schema.py | 18 +++++--- tests/projects/snapshots/schema.gql | 22 +++++----- .../snapshots/schema_with_inheritance.gql | 18 ++++---- tests/test_paginated_type.py | 44 +++++++++---------- 9 files changed, 87 insertions(+), 83 deletions(-) diff --git a/docs/guide/pagination.md b/docs/guide/pagination.md index 5ff491dd..529397f9 100644 --- a/docs/guide/pagination.md +++ b/docs/guide/pagination.md @@ -90,14 +90,14 @@ class Query: fruits: list[Fruit] = strawberry_django.field(pagination=MyOffsetPaginationInput) ``` -## Paginated Generic +## OffsetPaginated Generic -For more complex pagination needs, you can use the `Paginated` generic, which alongside +For more complex pagination needs, you can use the `OffsetPaginated` generic, which alongside the `pagination` argument, will wrap the results in an object that contains the results and the pagination information, together with the `totalCount` of elements excluding pagination. ```python title="types.py" -from strawberry_django.pagination import Paginated +from strawberry_django.pagination import OffsetPaginated @strawberry_django.type(models.Fruit) @@ -107,7 +107,7 @@ class Fruit: @strawberry.type class Query: - fruits: Paginated[Fruit] = strawberry_django.field() + fruits: OffsetPaginated[Fruit] = strawberry_django.field() ``` Would produce the following schema: @@ -117,13 +117,13 @@ type Fruit { name: String! } -type PaginatedInfo { +type PaginationInfo { limit: Int! offset: Int! } -type FruitPaginated { - pageInfo: PaginatedInfo! +type FruitOffsetPaginated { + pageInfo: PaginationInfo! totalCount: Int! results: [Fruit]! } @@ -134,7 +134,7 @@ input OffsetPaginationInput { } type Query { - fruits(pagination: OffsetPaginationInput): [FruitPaginated!]! + fruits(pagination: OffsetPaginationInput): [FruitOffsetPaginated!]! } ``` @@ -156,17 +156,17 @@ query { ``` > [!NOTE] -> Paginated follow the same rules for the default pagination limit, and can be configured +> OffsetPaginated follow the same rules for the default pagination limit, and can be configured > in the same way as explained above. ### Customizing the pagination -Like other generics, `Paginated` can be customized to modify its behavior or to +Like other generics, `OffsetPaginated` can be customized to modify its behavior or to add extra functionality in it. For example, suppose we want to add the average price of the fruits in the pagination: ```python title="types.py" -from strawberry_django.pagination import Paginated +from strawberry_django.pagination import OffsetPaginated @strawberry_django.type(models.Fruit) @@ -176,7 +176,7 @@ class Fruit: @strawberry.type -class FruitPaginated(Paginated[Fruit]): +class FruitOffsetPaginated(OffsetPaginated[Fruit]): @strawberry_django.field def average_price(self) -> Decimal: if self.queryset is None: @@ -195,7 +195,7 @@ class FruitPaginated(Paginated[Fruit]): @strawberry.type class Query: - fruits: FruitPaginated = strawberry_django.field() + fruits: FruitOffsetPaginated = strawberry_django.field() ``` Would produce the following schema: @@ -205,13 +205,13 @@ type Fruit { name: String! } -type PaginatedInfo { +type PaginationInfo { limit: Int! offset: Int! } -type FruitPaginated { - pageInfo: PaginatedInfo! +type FruitOffsetPaginated { + pageInfo: PaginationInfo! totalCount: Int! results: [Fruit]! averagePrice: Decimal! @@ -224,11 +224,11 @@ input OffsetPaginationInput { } type Query { - fruits(pagination: OffsetPaginationInput): [FruitPaginated!]! + fruits(pagination: OffsetPaginationInput): [FruitOffsetPaginated!]! } ``` -The following attributes/methods can be accessed in the `Paginated` class: +The following attributes/methods can be accessed in the `OffsetPaginated` class: - `queryset`: The queryset original queryset with any filters/ordering applied, but not paginated yet diff --git a/strawberry_django/fields/base.py b/strawberry_django/fields/base.py index 20e6864f..c2e5e2e4 100644 --- a/strawberry_django/fields/base.py +++ b/strawberry_django/fields/base.py @@ -85,7 +85,7 @@ def is_async(self) -> bool: @functools.cached_property def django_type(self) -> type[WithStrawberryDjangoObjectDefinition] | None: - from strawberry_django.pagination import Paginated + from strawberry_django.pagination import OffsetPaginated origin = self.type @@ -95,7 +95,7 @@ def django_type(self) -> type[WithStrawberryDjangoObjectDefinition] | None: object_definition = get_object_definition(origin) if object_definition and issubclass( - object_definition.origin, (relay.Connection, Paginated) + object_definition.origin, (relay.Connection, OffsetPaginated) ): origin_specialized_type_var_map = ( get_specialized_type_var_map(cast(type, origin)) or {} @@ -154,13 +154,13 @@ def is_list(self) -> bool: @functools.cached_property def is_paginated(self) -> bool: - from strawberry_django.pagination import Paginated + from strawberry_django.pagination import OffsetPaginated type_ = self.type if isinstance(type_, StrawberryOptional): type_ = type_.of_type - return isinstance(type_, type) and issubclass(type_, Paginated) + return isinstance(type_, type) and issubclass(type_, OffsetPaginated) @functools.cached_property def is_connection(self) -> bool: diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index d5ec9a54..956627fa 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -49,7 +49,7 @@ from typing_extensions import assert_never, assert_type from strawberry_django.fields.types import resolve_model_field_name -from strawberry_django.pagination import Paginated, apply_window_pagination +from strawberry_django.pagination import OffsetPaginated, apply_window_pagination from strawberry_django.queryset import get_queryset_config, run_type_get_queryset from strawberry_django.relay import ListConnectionWithTotalCount from strawberry_django.resolvers import django_fetch @@ -574,7 +574,7 @@ def _optimize_prefetch_queryset( else: mark_optimized = False - if isinstance(field.type, type) and issubclass(field.type, Paginated): + if isinstance(field.type, type) and issubclass(field.type, OffsetPaginated): pagination = field_kwargs.get("pagination") qs = apply_window_pagination( qs, @@ -1001,7 +1001,7 @@ def _get_model_hints( ) # In case this is a Paginated field, the selected fields are inside results selection - if issubclass(object_definition.origin, Paginated): + if issubclass(object_definition.origin, OffsetPaginated): return _get_model_hints_from_paginated( model, schema, diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index 98a4db51..49282df1 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -28,19 +28,19 @@ class OffsetPaginationInput: @strawberry.type -class PaginatedInfo: +class OffsetPaginationInfo: offset: int = 0 limit: Optional[int] = None @strawberry.type -class Paginated(Generic[NodeType]): +class OffsetPaginated(Generic[NodeType]): queryset: strawberry.Private[Optional[QuerySet]] pagination: strawberry.Private[OffsetPaginationInput] @strawberry.field - def page_info(self) -> PaginatedInfo: - return PaginatedInfo( + def page_info(self) -> OffsetPaginationInfo: + return OffsetPaginationInfo( limit=self.pagination.limit, offset=self.pagination.offset, ) @@ -320,7 +320,7 @@ def get_wrapped_result( *, pagination: Optional[object] = None, **kwargs, - ) -> Union[_T, Paginated[_T]]: + ) -> Union[_T, OffsetPaginated[_T]]: if not self.is_paginated: return result @@ -333,7 +333,7 @@ def get_wrapped_result( ): raise TypeError(f"Don't know how to resolve pagination {pagination!r}") - return Paginated( + return OffsetPaginated( queryset=result, pagination=pagination or OffsetPaginationInput(), ) diff --git a/strawberry_django/permissions.py b/strawberry_django/permissions.py index 71ee7eb3..7b56b2d1 100644 --- a/strawberry_django/permissions.py +++ b/strawberry_django/permissions.py @@ -39,7 +39,7 @@ from strawberry_django.auth.utils import aget_current_user, get_current_user from strawberry_django.fields.types import OperationInfo, OperationMessage -from strawberry_django.pagination import OffsetPaginationInput, Paginated +from strawberry_django.pagination import OffsetPaginated, OffsetPaginationInput from strawberry_django.resolvers import django_resolver from .utils.query import filter_for_user @@ -406,8 +406,8 @@ def handle_no_permission(self, exception: BaseException, *, info: Info): if isinstance(ret_type, StrawberryList): return [] - if isinstance(ret_type, type) and issubclass(ret_type, Paginated): - return Paginated(queryset=None, pagination=OffsetPaginationInput()) + if isinstance(ret_type, type) and issubclass(ret_type, OffsetPaginated): + return OffsetPaginated(queryset=None, pagination=OffsetPaginationInput()) # If it is a Connection, try to return an empty connection, but only if # it is the only possibility available... diff --git a/tests/projects/schema.py b/tests/projects/schema.py index f9814c33..7502f482 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -29,7 +29,7 @@ from strawberry_django.fields.types import ListInput, NodeInput, NodeInputPartial from strawberry_django.mutations import resolvers from strawberry_django.optimizer import DjangoOptimizerExtension -from strawberry_django.pagination import Paginated +from strawberry_django.pagination import OffsetPaginated from strawberry_django.permissions import ( HasPerm, HasRetvalPerm, @@ -159,7 +159,7 @@ class MilestoneType(relay.Node, Named): order=IssueOrder, pagination=True, ) - issues_paginated: Paginated["IssueType"] = strawberry_django.field( + issues_paginated: OffsetPaginated["IssueType"] = strawberry_django.field( field_name="issues", order=IssueOrder, ) @@ -380,7 +380,7 @@ class Query: staff_list: list[Optional[StaffType]] = strawberry_django.node() issue_list: list[IssueType] = strawberry_django.field() - issues_paginated: Paginated[IssueType] = strawberry_django.field() + issues_paginated: OffsetPaginated[IssueType] = strawberry_django.field() milestone_list: list[MilestoneType] = strawberry_django.field( order=MilestoneOrder, filters=MilestoneFilter, @@ -435,8 +435,10 @@ class Query: issue_list_perm_required: list[IssueType] = strawberry_django.field( extensions=[HasPerm(perms=["projects.view_issue"])], ) - issues_paginated_perm_required: Paginated[IssueType] = strawberry_django.field( - extensions=[HasPerm(perms=["projects.view_issue"])], + issues_paginated_perm_required: OffsetPaginated[IssueType] = ( + strawberry_django.field( + extensions=[HasPerm(perms=["projects.view_issue"])], + ) ) issue_conn_perm_required: ListConnectionWithTotalCount[IssueType] = ( strawberry_django.connection( @@ -456,8 +458,10 @@ class Query: issue_list_obj_perm_required_paginated: list[IssueType] = strawberry_django.field( extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True ) - issues_paginated_obj_perm_required: Paginated[IssueType] = strawberry_django.field( - extensions=[HasRetvalPerm(perms=["projects.view_issue"])], + issues_paginated_obj_perm_required: OffsetPaginated[IssueType] = ( + strawberry_django.field( + extensions=[HasRetvalPerm(perms=["projects.view_issue"])], + ) ) issue_conn_obj_perm_required: ListConnectionWithTotalCount[IssueType] = ( strawberry_django.connection( diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index a127814c..68c7936f 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -336,8 +336,8 @@ type IssueTypeEdge { node: IssueType! } -type IssueTypePaginated { - pageInfo: PaginatedInfo! +type IssueTypeOffsetPaginated { + pageInfo: OffsetPaginationInfo! """Total count of existing results.""" totalCount: Int! @@ -383,7 +383,7 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! - issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! + issuesPaginated(pagination: OffsetPaginationInput): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter @@ -463,6 +463,11 @@ input NodeInputPartial { id: GlobalID } +type OffsetPaginationInfo { + offset: Int! + limit: Int +} + input OffsetPaginationInput { offset: Int! = 0 limit: Int = null @@ -521,11 +526,6 @@ type PageInfo { endCursor: String } -type PaginatedInfo { - offset: Int! - limit: Int -} - type ProjectConnection { """Pagination data for this connection""" pageInfo: PageInfo! @@ -626,7 +626,7 @@ type Query { ids: [GlobalID!]! ): [StaffType]! issueList: [IssueType!]! - issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! + issuesPaginated(pagination: OffsetPaginationInput): IssueTypeOffsetPaginated! milestoneList(filters: MilestoneFilter, order: MilestoneOrder, pagination: OffsetPaginationInput): [MilestoneType!]! projectList(filters: ProjectFilter): [ProjectType!]! tagList: [TagType!]! @@ -746,7 +746,7 @@ type Query { id: GlobalID! ): IssueType @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListPermRequired: [IssueType!]! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) - issuesPaginatedPermRequired(pagination: OffsetPaginationInput): IssueTypePaginated! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) + issuesPaginatedPermRequired(pagination: OffsetPaginationInput): IssueTypeOffsetPaginated! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueConnPermRequired( """Returns the items in the list that come before the specified cursor.""" before: String = null @@ -770,7 +770,7 @@ type Query { ): IssueType @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListObjPermRequired: [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListObjPermRequiredPaginated(pagination: OffsetPaginationInput): [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) - issuesPaginatedObjPermRequired(pagination: OffsetPaginationInput): IssueTypePaginated! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) + issuesPaginatedObjPermRequired(pagination: OffsetPaginationInput): IssueTypeOffsetPaginated! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueConnObjPermRequired( """Returns the items in the list that come before the specified cursor.""" before: String = null diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index 49f8aebf..f104ce92 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -136,8 +136,8 @@ type IssueTypeEdge { node: IssueType! } -type IssueTypePaginated { - pageInfo: PaginatedInfo! +type IssueTypeOffsetPaginated { + pageInfo: OffsetPaginationInfo! """Total count of existing results.""" totalCount: Int! @@ -173,7 +173,7 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! - issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! + issuesPaginated(pagination: OffsetPaginationInput): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter @@ -201,7 +201,7 @@ type MilestoneTypeSubclass implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! - issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! + issuesPaginated(pagination: OffsetPaginationInput): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter @@ -241,6 +241,11 @@ input NodeInput { id: GlobalID! } +type OffsetPaginationInfo { + offset: Int! + limit: Int +} + input OffsetPaginationInput { offset: Int! = 0 limit: Int = null @@ -299,11 +304,6 @@ type PageInfo { endCursor: String } -type PaginatedInfo { - offset: Int! - limit: Int -} - input ProjectOrder { id: Ordering name: Ordering diff --git a/tests/test_paginated_type.py b/tests/test_paginated_type.py index 760fc727..1e032f89 100644 --- a/tests/test_paginated_type.py +++ b/tests/test_paginated_type.py @@ -4,7 +4,7 @@ import strawberry import strawberry_django -from strawberry_django.pagination import Paginated +from strawberry_django.pagination import OffsetPaginated from tests import models @@ -18,12 +18,12 @@ class Fruit: class Color: id: int name: str - fruits: Paginated[Fruit] + fruits: OffsetPaginated[Fruit] @strawberry.type class Query: - fruits: Paginated[Fruit] = strawberry_django.field() - colors: Paginated[Color] = strawberry_django.field() + fruits: OffsetPaginated[Fruit] = strawberry_django.field() + colors: OffsetPaginated[Color] = strawberry_django.field() schema = strawberry.Schema(query=Query) @@ -31,11 +31,11 @@ class Query: type Color { id: Int! name: String! - fruits(pagination: OffsetPaginationInput): FruitPaginated! + fruits(pagination: OffsetPaginationInput): FruitOffsetPaginated! } - type ColorPaginated { - pageInfo: PaginatedInfo! + type ColorOffsetPaginated { + pageInfo: OffsetPaginationInfo! """Total count of existing results.""" totalCount: Int! @@ -49,8 +49,8 @@ class Query: name: String! } - type FruitPaginated { - pageInfo: PaginatedInfo! + type FruitOffsetPaginated { + pageInfo: OffsetPaginationInfo! """Total count of existing results.""" totalCount: Int! @@ -59,19 +59,19 @@ class Query: results: [Fruit!]! } + type OffsetPaginationInfo { + offset: Int! + limit: Int + } + input OffsetPaginationInput { offset: Int! = 0 limit: Int = null } - type PaginatedInfo { - offset: Int! - limit: Int - } - type Query { - fruits(pagination: OffsetPaginationInput): FruitPaginated! - colors(pagination: OffsetPaginationInput): ColorPaginated! + fruits(pagination: OffsetPaginationInput): FruitOffsetPaginated! + colors(pagination: OffsetPaginationInput): ColorOffsetPaginated! } ''' @@ -87,7 +87,7 @@ class Fruit: @strawberry.type class Query: - fruits: Paginated[Fruit] = strawberry_django.field() + fruits: OffsetPaginated[Fruit] = strawberry_django.field() models.Fruit.objects.create(name="Apple") models.Fruit.objects.create(name="Banana") @@ -145,7 +145,7 @@ class Fruit: @strawberry.type class Query: - fruits: Paginated[Fruit] = strawberry_django.field() + fruits: OffsetPaginated[Fruit] = strawberry_django.field() await models.Fruit.objects.acreate(name="Apple") await models.Fruit.objects.acreate(name="Banana") @@ -205,11 +205,11 @@ class Fruit: class Color: id: int name: str - fruits: Paginated[Fruit] + fruits: OffsetPaginated[Fruit] @strawberry.type class Query: - colors: Paginated[Color] = strawberry_django.field() + colors: OffsetPaginated[Color] = strawberry_django.field() red = models.Color.objects.create(name="Red") yellow = models.Color.objects.create(name="Yellow") @@ -316,11 +316,11 @@ class Fruit: class Color: id: int name: str - fruits: Paginated[Fruit] + fruits: OffsetPaginated[Fruit] @strawberry.type class Query: - colors: Paginated[Color] = strawberry_django.field() + colors: OffsetPaginated[Color] = strawberry_django.field() red = await models.Color.objects.acreate(name="Red") yellow = await models.Color.objects.acreate(name="Yellow") From 89074b13852d26f1de3b9a589a938da9e4dc8c83 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sun, 20 Oct 2024 12:54:31 -0300 Subject: [PATCH 08/10] Allow further customization by letting the type resolve itself --- docs/guide/pagination.md | 2 + strawberry_django/pagination.py | 42 ++++++++++++++--- tests/test_paginated_type.py | 80 ++++++++++++++++++++++++++++++++- 3 files changed, 118 insertions(+), 6 deletions(-) diff --git a/docs/guide/pagination.md b/docs/guide/pagination.md index 529397f9..0e8453ee 100644 --- a/docs/guide/pagination.md +++ b/docs/guide/pagination.md @@ -235,6 +235,8 @@ The following attributes/methods can be accessed in the `OffsetPaginated` class: - `pagination`: The `OffsetPaginationInput` object, with the `offset` and `limit` for pagination - `get_total_count()`: Returns the total count of elements in the queryset without pagination - `get_paginated_queryset()`: Returns the queryset with pagination applied +- `resolve_paginated(queryset, *, info, pagiantion, **kwargs)`: The classmethod that + strawberry-django calls to create an instance of the `OffsetPaginated` class/subclass. ## Cursor pagination (aka Relay style pagination) diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index 49282df1..0e9b72f2 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -59,6 +59,32 @@ def results(self) -> list[NodeType]: list[NodeType], paginated_queryset if paginated_queryset is not None else [] ) + @classmethod + def resolve_paginated( + cls, + queryset: QuerySet, + *, + info: Info, + pagination: Optional[OffsetPaginationInput] = None, + **kwargs, + ) -> Self: + """Resolve the paginated queryset. + + Args: + queryset: The queryset to be paginated. + info: The strawberry execution info resolve the type name from. + pagination: The pagination input to be applied. + kwargs: Additional arguments passed to the resolver. + + Returns: + The resolved `OffsetPaginated` + + """ + return cls( + queryset=queryset, + pagination=pagination or OffsetPaginationInput(), + ) + def get_total_count(self) -> int: """Retrieve tht total count of the queryset without pagination.""" return get_total_count(self.queryset) if self.queryset is not None else 0 @@ -294,7 +320,7 @@ def get_queryset( queryset: _QS, info: Info, *, - pagination: Optional[object] = None, + pagination: Optional[OffsetPaginationInput] = None, _strawberry_related_field_id: Optional[str] = None, **kwargs, ) -> _QS: @@ -318,7 +344,7 @@ def get_wrapped_result( result: _T, info: Info, *, - pagination: Optional[object] = None, + pagination: Optional[OffsetPaginationInput] = None, **kwargs, ) -> Union[_T, OffsetPaginated[_T]]: if not self.is_paginated: @@ -333,7 +359,13 @@ def get_wrapped_result( ): raise TypeError(f"Don't know how to resolve pagination {pagination!r}") - return OffsetPaginated( - queryset=result, - pagination=pagination or OffsetPaginationInput(), + paginated_type = self.type + assert isinstance(paginated_type, type) + assert issubclass(paginated_type, OffsetPaginated) + + return paginated_type.resolve_paginated( + result, + info=info, + pagination=pagination, + **kwargs, ) diff --git a/tests/test_paginated_type.py b/tests/test_paginated_type.py index 1e032f89..16813bb0 100644 --- a/tests/test_paginated_type.py +++ b/tests/test_paginated_type.py @@ -4,7 +4,7 @@ import strawberry import strawberry_django -from strawberry_django.pagination import OffsetPaginated +from strawberry_django.pagination import OffsetPaginated, OffsetPaginationInput from tests import models @@ -414,3 +414,81 @@ class Query: ], } } + + +@pytest.mark.django_db(transaction=True) +def test_pagination_query_with_subclass(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry.type + class FruitPaginated(OffsetPaginated[Fruit]): + _custom_field: strawberry.Private[str] + + @strawberry_django.field + def custom_field(self) -> str: + return self._custom_field + + @classmethod + def resolve_paginated(cls, queryset, *, info, pagination=None, **kwargs): + return cls( + queryset=queryset, + pagination=pagination or OffsetPaginationInput(), + _custom_field="pagination rocks", + ) + + @strawberry.type + class Query: + fruits: FruitPaginated = strawberry_django.field() + + models.Fruit.objects.create(name="Apple") + models.Fruit.objects.create(name="Banana") + models.Fruit.objects.create(name="Strawberry") + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetFruits ($pagination: OffsetPaginationInput) { + fruits (pagination: $pagination) { + totalCount + customField + results { + name + } + } + } + """ + + res = schema.execute_sync(query) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "customField": "pagination rocks", + "results": [{"name": "Apple"}, {"name": "Banana"}, {"name": "Strawberry"}], + } + } + + res = schema.execute_sync(query, variable_values={"pagination": {"limit": 1}}) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "customField": "pagination rocks", + "results": [{"name": "Apple"}], + } + } + + res = schema.execute_sync( + query, variable_values={"pagination": {"limit": 1, "offset": 2}} + ) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "customField": "pagination rocks", + "results": [{"name": "Strawberry"}], + } + } From 9705ab00e8c992b4d087729a8530459973838315 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sun, 20 Oct 2024 14:49:27 -0300 Subject: [PATCH 09/10] Refactor Paginated to be used similarly to how Connection is used --- docs/guide/pagination.md | 85 ++++- strawberry_django/__init__.py | 3 +- strawberry_django/fields/field.py | 343 +++++++++++++++-- strawberry_django/pagination.py | 34 +- strawberry_django/permissions.py | 8 +- tests/projects/schema.py | 8 +- tests/projects/snapshots/schema.gql | 2 +- .../snapshots/schema_with_inheritance.gql | 4 +- tests/test_paginated_type.py | 355 +++++++++++++++++- 9 files changed, 751 insertions(+), 91 deletions(-) diff --git a/docs/guide/pagination.md b/docs/guide/pagination.md index 0e8453ee..67beb920 100644 --- a/docs/guide/pagination.md +++ b/docs/guide/pagination.md @@ -28,7 +28,7 @@ type Fruit { input OffsetPaginationInput { offset: Int! = 0 - limit: Int! = -1 + limit: Int = null } type Query { @@ -107,7 +107,7 @@ class Fruit: @strawberry.type class Query: - fruits: OffsetPaginated[Fruit] = strawberry_django.field() + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() ``` Would produce the following schema: @@ -118,7 +118,7 @@ type Fruit { } type PaginationInfo { - limit: Int! + limit: Int = null offset: Int! } @@ -130,7 +130,7 @@ type FruitOffsetPaginated { input OffsetPaginationInput { offset: Int! = 0 - limit: Int! = -1 + limit: Int = null } type Query { @@ -159,6 +159,77 @@ query { > OffsetPaginated follow the same rules for the default pagination limit, and can be configured > in the same way as explained above. +### Customizing queryset resolver + +It is possible to define a custom resolver for the queryset to either provide a custom +queryset for it, or even to receive extra arguments alongside the pagination arguments. + +Suppose we want to pre-filter a queryset of fruits for only available ones, +while also adding [ordering](./ordering.md) to it. This can be achieved like: + +```python title="types.py" + +@strawberry_django.type(models.Fruit) +class Fruit: + name: auto + price: auto + + +@strawberry_django.order(models.Fruit) +class FruitOrder: + name: auto + price: auto + + +@strawberry.type +class Query: + @straberry.offset_paginated(OffsetPaginated[Fruit], order=order) + def fruits(self, only_available: bool = True) -> QuerySet[Fruit]: + queryset = models.Fruit.objects.all() + if only_available: + queryset = queryset.filter(available=True) + + return queryset +``` + +This would produce the following schema: + +```graphql title="schema.graphql" +type Fruit { + name: String! + price: Decimal! +} + +type FruitOrder { + name: Ordering + price: Ordering +} + +type PaginationInfo { + limit: Int! + offset: Int! +} + +type FruitOffsetPaginated { + pageInfo: PaginationInfo! + totalCount: Int! + results: [Fruit]! +} + +input OffsetPaginationInput { + offset: Int! = 0 + limit: Int = null +} + +type Query { + fruits( + onlyAvailable: Boolean! = true + pagination: OffsetPaginationInput + order: FruitOrder + ): [FruitOffsetPaginated!]! +} +``` + ### Customizing the pagination Like other generics, `OffsetPaginated` can be customized to modify its behavior or to @@ -195,7 +266,7 @@ class FruitOffsetPaginated(OffsetPaginated[Fruit]): @strawberry.type class Query: - fruits: FruitOffsetPaginated = strawberry_django.field() + fruits: FruitOffsetPaginated = strawberry_django.offset_paginated() ``` Would produce the following schema: @@ -206,7 +277,7 @@ type Fruit { } type PaginationInfo { - limit: Int! + limit: Int = null offset: Int! } @@ -220,7 +291,7 @@ type FruitOffsetPaginated { input OffsetPaginationInput { offset: Int! = 0 - limit: Int! = -1 + limit: Int = null } type Query { diff --git a/strawberry_django/__init__.py b/strawberry_django/__init__.py index 30e2ac69..74e0a557 100644 --- a/strawberry_django/__init__.py +++ b/strawberry_django/__init__.py @@ -1,5 +1,5 @@ from . import auth, filters, mutations, ordering, pagination, relay -from .fields.field import connection, field, node +from .fields.field import connection, field, node, offset_paginated from .fields.filter_order import filter_field, order_field from .fields.filter_types import ( BaseFilterLookup, @@ -60,6 +60,7 @@ "mutation", "mutations", "node", + "offset_paginated", "order", "order_field", "ordering", diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index 864d36fa..7806e77a 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -2,13 +2,21 @@ import dataclasses import inspect -from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property, partial +from collections.abc import ( + AsyncIterable, + AsyncIterator, + Iterable, + Iterator, + Mapping, + Sequence, +) +from functools import cached_property from typing import ( TYPE_CHECKING, Any, Callable, TypeVar, + Union, cast, overload, ) @@ -26,9 +34,12 @@ from django.db.models.query_utils import DeferredAttribute from strawberry import UNSET, relay from strawberry.annotation import StrawberryAnnotation +from strawberry.extensions.field_extension import FieldExtension +from strawberry.types.field import _RESOLVER_TYPE # noqa: PLC2701 from strawberry.types.fields.resolver import StrawberryResolver from strawberry.types.info import Info # noqa: TCH002 from strawberry.utils.await_maybe import await_maybe +from typing_extensions import TypeAlias from strawberry_django import optimizer from strawberry_django.arguments import argument @@ -37,7 +48,12 @@ from strawberry_django.filters import FILTERS_ARG, StrawberryDjangoFieldFilters from strawberry_django.optimizer import OptimizerStore, is_optimized_by_prefetching from strawberry_django.ordering import ORDER_ARG, StrawberryDjangoFieldOrdering -from strawberry_django.pagination import StrawberryDjangoPagination +from strawberry_django.pagination import ( + PAGINATION_ARG, + OffsetPaginated, + OffsetPaginationInput, + StrawberryDjangoPagination, +) from strawberry_django.permissions import filter_with_perms from strawberry_django.queryset import run_type_get_queryset from strawberry_django.relay import resolve_model_nodes @@ -51,13 +67,11 @@ if TYPE_CHECKING: from graphql.pyutils import AwaitableOrValue from strawberry import BasePermission - from strawberry.extensions.field_extension import ( - FieldExtension, - SyncExtensionResolver, - ) + from strawberry.extensions.field_extension import SyncExtensionResolver from strawberry.relay.types import NodeIterableType from strawberry.types.arguments import StrawberryArgument - from strawberry.types.field import _RESOLVER_TYPE, StrawberryField + from strawberry.types.base import WithStrawberryObjectDefinition + from strawberry.types.field import StrawberryField from strawberry.types.unset import UnsetType from typing_extensions import Literal, Self @@ -228,12 +242,9 @@ async def async_resolver(): if "info" not in kwargs: kwargs["info"] = info - @sync_to_async - def resolve(resolved=resolved): - inner_resolved = self.get_queryset_hook(**kwargs)(resolved) - return self.get_wrapped_result(inner_resolved, **kwargs) - - resolved = await resolve() + resolved = await sync_to_async(self.get_queryset_hook(**kwargs))( + resolved + ) return resolved @@ -248,7 +259,7 @@ def resolve(resolved=resolved): result = django_resolver( self.get_queryset_hook(**kwargs), - qs_hook=partial(self.get_wrapped_result, **kwargs), + qs_hook=lambda qs: qs, )(result) return result @@ -300,6 +311,44 @@ def get_queryset(self, queryset, info, **kwargs): return queryset +def _get_field_arguments_for_extensions( + field: StrawberryDjangoField, + *, + add_filters: bool = True, + add_order: bool = True, + add_pagination: bool = True, +) -> list[StrawberryArgument]: + """Get a list of arguments to be set to fields using extensions. + + Because we have a base_resolver defined in those, our parents will not add + order/filters/pagination resolvers in here, so we need to add them by hand (unless they + are somewhat in there). We are not adding pagination because it doesn't make + sense together with a Connection + """ + args: dict[str, StrawberryArgument] = {a.python_name: a for a in field.arguments} + + if add_filters and FILTERS_ARG not in args: + filters = field.get_filters() + if filters not in (None, UNSET): # noqa: PLR6201 + args[FILTERS_ARG] = argument(FILTERS_ARG, filters, is_optional=True) + + if add_order and ORDER_ARG not in args: + order = field.get_order() + if order not in (None, UNSET): # noqa: PLR6201 + args[ORDER_ARG] = argument(ORDER_ARG, order, is_optional=True) + + if add_pagination and PAGINATION_ARG not in args: + pagination = field.get_pagination() + if pagination not in (None, UNSET): # noqa: PLR6201 + args[PAGINATION_ARG] = argument( + PAGINATION_ARG, + pagination, + is_optional=True, + ) + + return list(args.values()) + + class StrawberryDjangoConnectionExtension(relay.ConnectionExtension): def apply(self, field: StrawberryField) -> None: if not isinstance(field, StrawberryDjangoField): @@ -307,24 +356,10 @@ def apply(self, field: StrawberryField) -> None: "The extension can only be applied to StrawberryDjangoField" ) - # NOTE: Because we have a base_resolver defined, our parents will not add - # order/filters resolvers in here, so we need to add them by hand (unless they - # are somewhat in there). We are not adding pagination because it doesn't make - # sense together with a Connection - args: dict[str, StrawberryArgument] = { - a.python_name: a for a in field.arguments - } - - if FILTERS_ARG not in args: - filters = field.get_filters() - if filters not in (None, UNSET): # noqa: PLR6201 - args[FILTERS_ARG] = argument(FILTERS_ARG, filters, is_optional=True) - if ORDER_ARG not in args: - order = field.get_order() - if order not in (None, UNSET): # noqa: PLR6201 - args[ORDER_ARG] = argument(ORDER_ARG, order, is_optional=True) - - field.arguments = list(args.values()) + field.arguments = _get_field_arguments_for_extensions( + field, + add_pagination=False, + ) if field.base_resolver is None: @@ -412,6 +447,67 @@ async def async_resolver(): ) +class StrawberryOffsetPaginatedExtension(FieldExtension): + paginated_type: type[OffsetPaginated] + + def apply(self, field: StrawberryField) -> None: + if not isinstance(field, StrawberryDjangoField): + raise TypeError( + "The extension can only be applied to StrawberryDjangoField" + ) + + field.arguments = _get_field_arguments_for_extensions(field) + self.paginated_type = cast(type[OffsetPaginated], field.type) + + def resolve( + self, + next_: SyncExtensionResolver, + source: Any, + info: Info, + *, + pagination: OffsetPaginationInput | None = None, + order: WithStrawberryObjectDefinition | None = None, + filters: WithStrawberryObjectDefinition | None = None, + **kwargs: Any, + ) -> Any: + assert self.paginated_type is not None + queryset = cast(models.QuerySet, next_(source, info, **kwargs)) + + def get_queryset(queryset): + return cast(StrawberryDjangoField, info._field).get_queryset( + queryset, + info, + pagination=pagination, + order=order, + filters=filters, + ) + + # We have a single resolver for both sync and async, so we need to check if + # nodes is awaitable or not and resolve it accordingly + if inspect.isawaitable(queryset): + + async def async_resolver(queryset=queryset): + resolved = self.paginated_type.resolve_paginated( + get_queryset(await queryset), + info=info, + pagination=pagination, + **kwargs, + ) + if inspect.isawaitable(resolved): + resolved = await resolved + + return resolved + + return async_resolver() + + return self.paginated_type.resolve_paginated( + get_queryset(queryset), + info=info, + pagination=pagination, + **kwargs, + ) + + @overload def field( *, @@ -805,3 +901,184 @@ def connection( f = f(resolver) return f + + +_OFFSET_PAGINATED_RESOLVER_TYPE: TypeAlias = _RESOLVER_TYPE[ + Union[ + Iterator[models.Model], + Iterable[models.Model], + AsyncIterator[models.Model], + AsyncIterable[models.Model], + ] +] + + +@overload +def offset_paginated( + graphql_type: type[OffsetPaginated] | None = None, + *, + field_cls: type[StrawberryDjangoField] = StrawberryDjangoField, + name: str | None = None, + field_name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, + default: Any = dataclasses.MISSING, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: Sequence[FieldExtension] = (), + filters: type | None = UNSET, + order: type | None = UNSET, + only: TypeOrSequence[str] | None = None, + select_related: TypeOrSequence[str] | None = None, + prefetch_related: TypeOrSequence[PrefetchType] | None = None, + annotate: TypeOrMapping[AnnotateType] | None = None, + disable_optimization: bool = False, +) -> Any: ... + + +@overload +def offset_paginated( + graphql_type: type[OffsetPaginated] | None = None, + *, + field_cls: type[StrawberryDjangoField] = StrawberryDjangoField, + resolver: _OFFSET_PAGINATED_RESOLVER_TYPE | None = None, + name: str | None = None, + field_name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + init: Literal[True] = True, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, + default: Any = dataclasses.MISSING, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: Sequence[FieldExtension] = (), + filters: type | None = UNSET, + order: type | None = UNSET, + only: TypeOrSequence[str] | None = None, + select_related: TypeOrSequence[str] | None = None, + prefetch_related: TypeOrSequence[PrefetchType] | None = None, + annotate: TypeOrMapping[AnnotateType] | None = None, + disable_optimization: bool = False, +) -> Any: ... + + +def offset_paginated( + graphql_type: type[OffsetPaginated] | None = None, + *, + field_cls: type[StrawberryDjangoField] = StrawberryDjangoField, + resolver: _OFFSET_PAGINATED_RESOLVER_TYPE | None = None, + name: str | None = None, + field_name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, + default: Any = dataclasses.MISSING, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: Sequence[FieldExtension] = (), + filters: type | None = UNSET, + order: type | None = UNSET, + only: TypeOrSequence[str] | None = None, + select_related: TypeOrSequence[str] | None = None, + prefetch_related: TypeOrSequence[PrefetchType] | None = None, + annotate: TypeOrMapping[AnnotateType] | None = None, + disable_optimization: bool = False, + # This init parameter is used by pyright to determine whether this field + # is added in the constructor or not. It is not used to change + # any behavior at the moment. + init: Literal[True, False, None] = None, +) -> Any: + """Annotate a property or a method to create a relay connection field. + + Relay connections_ are mostly used for pagination purposes. This decorator + helps creating a complete relay endpoint that provides default arguments + and has a default implementation for the connection slicing. + + Note that when setting a resolver to this field, it is expected for this + resolver to return an iterable of the expected node type, not the connection + itself. That iterable will then be paginated accordingly. So, the main use + case for this is to provide a filtered iterable of nodes by using some custom + filter arguments. + + Examples + -------- + Annotating something like this: + + >>> @strawberry.type + >>> class X: + ... some_node: relay.Connection[SomeType] = relay.connection( + ... description="ABC", + ... ) + ... + ... @relay.connection(description="ABC") + ... def get_some_nodes(self, age: int) -> Iterable[SomeType]: + ... ... + + Will produce a query like this: + + ``` + query { + someNode ( + before: String + after: String + first: String + after: String + age: Int + ) { + totalCount + pageInfo { + hasNextPage + hasPreviousPage + startCursor + endCursor + } + edges { + cursor + node { + id + ... + } + } + } + } + ``` + + .. _Relay connections: + https://relay.dev/graphql/connections.htm + + """ + extensions = [*extensions, StrawberryOffsetPaginatedExtension()] + f = field_cls( + python_name=None, + django_name=field_name, + graphql_name=name, + type_annotation=StrawberryAnnotation.from_annotation(graphql_type), + description=description, + is_subscription=is_subscription, + permission_classes=permission_classes or [], + deprecation_reason=deprecation_reason, + default=default, + default_factory=default_factory, + metadata=metadata, + directives=directives or (), + filters=filters, + order=order, + extensions=extensions, + only=only, + select_related=select_related, + prefetch_related=prefetch_related, + annotate=annotate, + disable_optimization=disable_optimization, + ) + + if resolver: + f = f(resolver) + + return f diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index 0e9b72f2..32a00fb4 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -17,9 +17,10 @@ from .arguments import argument NodeType = TypeVar("NodeType") -_T = TypeVar("_T") _QS = TypeVar("_QS", bound=QuerySet) +PAGINATION_ARG = "pagination" + @strawberry.input class OffsetPaginationInput: @@ -338,34 +339,3 @@ def get_queryset( pagination, related_field_id=_strawberry_related_field_id, ) - - def get_wrapped_result( - self, - result: _T, - info: Info, - *, - pagination: Optional[OffsetPaginationInput] = None, - **kwargs, - ) -> Union[_T, OffsetPaginated[_T]]: - if not self.is_paginated: - return result - - if not isinstance(result, QuerySet): - raise TypeError(f"Result expected to be a queryset, got {result!r}") - - if ( - pagination not in (None, UNSET) # noqa: PLR6201 - and not isinstance(pagination, OffsetPaginationInput) - ): - raise TypeError(f"Don't know how to resolve pagination {pagination!r}") - - paginated_type = self.type - assert isinstance(paginated_type, type) - assert issubclass(paginated_type, OffsetPaginated) - - return paginated_type.resolve_paginated( - result, - info=info, - pagination=pagination, - **kwargs, - ) diff --git a/strawberry_django/permissions.py b/strawberry_django/permissions.py index 7b56b2d1..2b1a60f8 100644 --- a/strawberry_django/permissions.py +++ b/strawberry_django/permissions.py @@ -39,7 +39,7 @@ from strawberry_django.auth.utils import aget_current_user, get_current_user from strawberry_django.fields.types import OperationInfo, OperationMessage -from strawberry_django.pagination import OffsetPaginated, OffsetPaginationInput +from strawberry_django.pagination import OffsetPaginated from strawberry_django.resolvers import django_resolver from .utils.query import filter_for_user @@ -48,6 +48,8 @@ if TYPE_CHECKING: from strawberry.django.context import StrawberryDjangoContext + from strawberry_django.fields.field import StrawberryDjangoField + _M = TypeVar("_M", bound=Model) @@ -407,7 +409,9 @@ def handle_no_permission(self, exception: BaseException, *, info: Info): return [] if isinstance(ret_type, type) and issubclass(ret_type, OffsetPaginated): - return OffsetPaginated(queryset=None, pagination=OffsetPaginationInput()) + django_model = cast("StrawberryDjangoField", info._field).django_model + assert django_model + return django_model._default_manager.none() # If it is a Connection, try to return an empty connection, but only if # it is the only possibility available... diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 7502f482..d87b3f35 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -159,7 +159,7 @@ class MilestoneType(relay.Node, Named): order=IssueOrder, pagination=True, ) - issues_paginated: OffsetPaginated["IssueType"] = strawberry_django.field( + issues_paginated: OffsetPaginated["IssueType"] = strawberry_django.offset_paginated( field_name="issues", order=IssueOrder, ) @@ -380,7 +380,7 @@ class Query: staff_list: list[Optional[StaffType]] = strawberry_django.node() issue_list: list[IssueType] = strawberry_django.field() - issues_paginated: OffsetPaginated[IssueType] = strawberry_django.field() + issues_paginated: OffsetPaginated[IssueType] = strawberry_django.offset_paginated() milestone_list: list[MilestoneType] = strawberry_django.field( order=MilestoneOrder, filters=MilestoneFilter, @@ -436,7 +436,7 @@ class Query: extensions=[HasPerm(perms=["projects.view_issue"])], ) issues_paginated_perm_required: OffsetPaginated[IssueType] = ( - strawberry_django.field( + strawberry_django.offset_paginated( extensions=[HasPerm(perms=["projects.view_issue"])], ) ) @@ -459,7 +459,7 @@ class Query: extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True ) issues_paginated_obj_perm_required: OffsetPaginated[IssueType] = ( - strawberry_django.field( + strawberry_django.offset_paginated( extensions=[HasRetvalPerm(perms=["projects.view_issue"])], ) ) diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index 68c7936f..b5f31d33 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -383,7 +383,7 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! - issuesPaginated(pagination: OffsetPaginationInput): IssueTypeOffsetPaginated! + issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index f104ce92..352c8e1b 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -173,7 +173,7 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! - issuesPaginated(pagination: OffsetPaginationInput): IssueTypeOffsetPaginated! + issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter @@ -201,7 +201,7 @@ type MilestoneTypeSubclass implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! - issuesPaginated(pagination: OffsetPaginationInput): IssueTypeOffsetPaginated! + issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter diff --git a/tests/test_paginated_type.py b/tests/test_paginated_type.py index 16813bb0..74537727 100644 --- a/tests/test_paginated_type.py +++ b/tests/test_paginated_type.py @@ -2,6 +2,7 @@ import pytest import strawberry +from django.db.models import QuerySet import strawberry_django from strawberry_django.pagination import OffsetPaginated, OffsetPaginationInput @@ -22,8 +23,8 @@ class Color: @strawberry.type class Query: - fruits: OffsetPaginated[Fruit] = strawberry_django.field() - colors: OffsetPaginated[Color] = strawberry_django.field() + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() + colors: OffsetPaginated[Color] = strawberry_django.offset_paginated() schema = strawberry.Schema(query=Query) @@ -87,7 +88,7 @@ class Fruit: @strawberry.type class Query: - fruits: OffsetPaginated[Fruit] = strawberry_django.field() + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() models.Fruit.objects.create(name="Apple") models.Fruit.objects.create(name="Banana") @@ -145,7 +146,7 @@ class Fruit: @strawberry.type class Query: - fruits: OffsetPaginated[Fruit] = strawberry_django.field() + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() await models.Fruit.objects.acreate(name="Apple") await models.Fruit.objects.acreate(name="Banana") @@ -205,11 +206,11 @@ class Fruit: class Color: id: int name: str - fruits: OffsetPaginated[Fruit] + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() @strawberry.type class Query: - colors: OffsetPaginated[Color] = strawberry_django.field() + colors: OffsetPaginated[Color] = strawberry_django.offset_paginated() red = models.Color.objects.create(name="Red") yellow = models.Color.objects.create(name="Yellow") @@ -316,11 +317,11 @@ class Fruit: class Color: id: int name: str - fruits: OffsetPaginated[Fruit] + fruits: OffsetPaginated[Fruit] = strawberry_django.offset_paginated() @strawberry.type class Query: - colors: OffsetPaginated[Color] = strawberry_django.field() + colors: OffsetPaginated[Color] = strawberry_django.offset_paginated() red = await models.Color.objects.acreate(name="Red") yellow = await models.Color.objects.acreate(name="Yellow") @@ -441,7 +442,7 @@ def resolve_paginated(cls, queryset, *, info, pagination=None, **kwargs): @strawberry.type class Query: - fruits: FruitPaginated = strawberry_django.field() + fruits: FruitPaginated = strawberry_django.offset_paginated() models.Fruit.objects.create(name="Apple") models.Fruit.objects.create(name="Banana") @@ -492,3 +493,339 @@ class Query: "results": [{"name": "Strawberry"}], } } + + +@pytest.mark.django_db(transaction=True) +def test_pagination_query_with_resolver_schema(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry_django.filter(models.Fruit) + class FruitFilter: + name: str + + @strawberry_django.order(models.Fruit) + class FruitOrder: + name: str + + @strawberry.type + class Query: + @strawberry_django.offset_paginated(OffsetPaginated[Fruit]) + def fruits(self) -> QuerySet[models.Fruit]: ... + + @strawberry_django.offset_paginated( + OffsetPaginated[Fruit], + filters=FruitFilter, + order=FruitOrder, + ) + def fruits_with_order_and_filter(self) -> QuerySet[models.Fruit]: ... + + schema = strawberry.Schema(query=Query) + + expected = ''' + type Fruit { + id: Int! + name: String! + } + + input FruitFilter { + name: String! + AND: FruitFilter + OR: FruitFilter + NOT: FruitFilter + DISTINCT: Boolean + } + + type FruitOffsetPaginated { + pageInfo: OffsetPaginationInfo! + + """Total count of existing results.""" + totalCount: Int! + + """List of paginated results.""" + results: [Fruit!]! + } + + input FruitOrder { + name: String + } + + type OffsetPaginationInfo { + offset: Int! + limit: Int + } + + input OffsetPaginationInput { + offset: Int! = 0 + limit: Int = null + } + + type Query { + fruits(pagination: OffsetPaginationInput): FruitOffsetPaginated! + fruitsWithOrderAndFilter(filters: FruitFilter, order: FruitOrder, pagination: OffsetPaginationInput): FruitOffsetPaginated! + } + ''' + + assert textwrap.dedent(str(schema)) == textwrap.dedent(expected).strip() + + +@pytest.mark.django_db(transaction=True) +def test_pagination_query_with_resolver(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry_django.filter(models.Fruit) + class FruitFilter: + name: strawberry.auto + + @strawberry_django.order(models.Fruit) + class FruitOrder: + name: strawberry.auto + + @strawberry.type + class Query: + @strawberry_django.offset_paginated(OffsetPaginated[Fruit]) + def fruits(self) -> QuerySet[models.Fruit]: + return models.Fruit.objects.filter(name__startswith="S") + + @strawberry_django.offset_paginated( + OffsetPaginated[Fruit], + filters=FruitFilter, + order=FruitOrder, + ) + def fruits_with_order_and_filter(self) -> QuerySet[models.Fruit]: + return models.Fruit.objects.filter(name__startswith="S") + + models.Fruit.objects.create(name="Apple") + models.Fruit.objects.create(name="Strawberry") + models.Fruit.objects.create(name="Banana") + models.Fruit.objects.create(name="Sugar Apple") + models.Fruit.objects.create(name="Starfruit") + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetFruits ( + $pagination: OffsetPaginationInput + $filters: FruitFilter + $order: FruitOrder + ) { + fruits (pagination: $pagination) { + totalCount + results { + name + } + } + fruitsWithOrderAndFilter ( + pagination: $pagination + filters: $filters + order: $order + ) { + totalCount + results { + name + } + } + } + """ + + res = schema.execute_sync(query) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + {"name": "Starfruit"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + {"name": "Starfruit"}, + ], + }, + } + + res = schema.execute_sync(query, variable_values={"pagination": {"limit": 1}}) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + ], + }, + } + + res = schema.execute_sync( + query, + variable_values={ + "pagination": {"limit": 2}, + "order": {"name": "ASC"}, + "filters": {"name": "Strawberry"}, + }, + ) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 1, + "results": [ + {"name": "Strawberry"}, + ], + }, + } + + +@pytest.mark.django_db(transaction=True) +def test_pagination_query_with_resolver_arguments(): + @strawberry_django.type(models.Fruit) + class Fruit: + id: int + name: str + + @strawberry_django.filter(models.Fruit) + class FruitFilter: + name: strawberry.auto + + @strawberry_django.order(models.Fruit) + class FruitOrder: + name: strawberry.auto + + @strawberry.type + class Query: + @strawberry_django.offset_paginated(OffsetPaginated[Fruit]) + def fruits(self, starts_with: str) -> QuerySet[models.Fruit]: + return models.Fruit.objects.filter(name__startswith=starts_with) + + @strawberry_django.offset_paginated( + OffsetPaginated[Fruit], + filters=FruitFilter, + order=FruitOrder, + ) + def fruits_with_order_and_filter( + self, starts_with: str + ) -> QuerySet[models.Fruit]: + return models.Fruit.objects.filter(name__startswith=starts_with) + + models.Fruit.objects.create(name="Apple") + models.Fruit.objects.create(name="Strawberry") + models.Fruit.objects.create(name="Banana") + models.Fruit.objects.create(name="Sugar Apple") + models.Fruit.objects.create(name="Starfruit") + + schema = strawberry.Schema(query=Query) + + query = """\ + query GetFruits ( + $pagination: OffsetPaginationInput + $filters: FruitFilter + $order: FruitOrder + $startsWith: String! + ) { + fruits (startsWith: $startsWith, pagination: $pagination) { + totalCount + results { + name + } + } + fruitsWithOrderAndFilter ( + startsWith: $startsWith + pagination: $pagination + filters: $filters + order: $order + ) { + totalCount + results { + name + } + } + } + """ + + res = schema.execute_sync(query, variable_values={"startsWith": "S"}) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + {"name": "Starfruit"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + {"name": "Starfruit"}, + ], + }, + } + + res = schema.execute_sync( + query, + variable_values={"startsWith": "S", "pagination": {"limit": 1}}, + ) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + ], + }, + } + + res = schema.execute_sync( + query, + variable_values={ + "startsWith": "S", + "pagination": {"limit": 2}, + "order": {"name": "ASC"}, + "filters": {"name": "Strawberry"}, + }, + ) + assert res.errors is None + assert res.data == { + "fruits": { + "totalCount": 3, + "results": [ + {"name": "Strawberry"}, + {"name": "Sugar Apple"}, + ], + }, + "fruitsWithOrderAndFilter": { + "totalCount": 1, + "results": [ + {"name": "Strawberry"}, + ], + }, + } From a997293b9392453046360b9019059b58793d252f Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Tue, 22 Oct 2024 09:08:44 -0300 Subject: [PATCH 10/10] Fix a typo in the docs --- docs/guide/pagination.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/guide/pagination.md b/docs/guide/pagination.md index 67beb920..89296d84 100644 --- a/docs/guide/pagination.md +++ b/docs/guide/pagination.md @@ -165,7 +165,7 @@ It is possible to define a custom resolver for the queryset to either provide a queryset for it, or even to receive extra arguments alongside the pagination arguments. Suppose we want to pre-filter a queryset of fruits for only available ones, -while also adding [ordering](./ordering.md) to it. This can be achieved like: +while also adding [ordering](./ordering.md) to it. This can be achieved with: ```python title="types.py" @@ -183,7 +183,7 @@ class FruitOrder: @strawberry.type class Query: - @straberry.offset_paginated(OffsetPaginated[Fruit], order=order) + @strawberry_django.offset_paginated(OffsetPaginated[Fruit], order=order) def fruits(self, only_available: bool = True) -> QuerySet[Fruit]: queryset = models.Fruit.objects.all() if only_available: