Skip to content

Commit

Permalink
Allow further customization by letting the type resolve itself
Browse files Browse the repository at this point in the history
  • Loading branch information
bellini666 committed Oct 20, 2024
1 parent a79eac0 commit 89074b1
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 6 deletions.
2 changes: 2 additions & 0 deletions docs/guide/pagination.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ The following attributes/methods can be accessed in the `OffsetPaginated` class:
- `pagination`: The `OffsetPaginationInput` object, with the `offset` and `limit` for pagination
- `get_total_count()`: Returns the total count of elements in the queryset without pagination
- `get_paginated_queryset()`: Returns the queryset with pagination applied
- `resolve_paginated(queryset, *, info, pagiantion, **kwargs)`: The classmethod that
strawberry-django calls to create an instance of the `OffsetPaginated` class/subclass.

## Cursor pagination (aka Relay style pagination)

Expand Down
42 changes: 37 additions & 5 deletions strawberry_django/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,32 @@ def results(self) -> list[NodeType]:
list[NodeType], paginated_queryset if paginated_queryset is not None else []
)

@classmethod
def resolve_paginated(
cls,
queryset: QuerySet,
*,
info: Info,
pagination: Optional[OffsetPaginationInput] = None,
**kwargs,
) -> Self:
"""Resolve the paginated queryset.
Args:
queryset: The queryset to be paginated.
info: The strawberry execution info resolve the type name from.
pagination: The pagination input to be applied.
kwargs: Additional arguments passed to the resolver.
Returns:
The resolved `OffsetPaginated`
"""
return cls(
queryset=queryset,
pagination=pagination or OffsetPaginationInput(),
)

def get_total_count(self) -> int:
"""Retrieve tht total count of the queryset without pagination."""
return get_total_count(self.queryset) if self.queryset is not None else 0
Expand Down Expand Up @@ -294,7 +320,7 @@ def get_queryset(
queryset: _QS,
info: Info,
*,
pagination: Optional[object] = None,
pagination: Optional[OffsetPaginationInput] = None,
_strawberry_related_field_id: Optional[str] = None,
**kwargs,
) -> _QS:
Expand All @@ -318,7 +344,7 @@ def get_wrapped_result(
result: _T,
info: Info,
*,
pagination: Optional[object] = None,
pagination: Optional[OffsetPaginationInput] = None,
**kwargs,
) -> Union[_T, OffsetPaginated[_T]]:
if not self.is_paginated:
Expand All @@ -333,7 +359,13 @@ def get_wrapped_result(
):
raise TypeError(f"Don't know how to resolve pagination {pagination!r}")

return OffsetPaginated(
queryset=result,
pagination=pagination or OffsetPaginationInput(),
paginated_type = self.type
assert isinstance(paginated_type, type)
assert issubclass(paginated_type, OffsetPaginated)

return paginated_type.resolve_paginated(
result,
info=info,
pagination=pagination,
**kwargs,
)
80 changes: 79 additions & 1 deletion tests/test_paginated_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import strawberry

import strawberry_django
from strawberry_django.pagination import OffsetPaginated
from strawberry_django.pagination import OffsetPaginated, OffsetPaginationInput
from tests import models


Expand Down Expand Up @@ -414,3 +414,81 @@ class Query:
],
}
}


@pytest.mark.django_db(transaction=True)
def test_pagination_query_with_subclass():
@strawberry_django.type(models.Fruit)
class Fruit:
id: int
name: str

@strawberry.type
class FruitPaginated(OffsetPaginated[Fruit]):
_custom_field: strawberry.Private[str]

@strawberry_django.field
def custom_field(self) -> str:
return self._custom_field

@classmethod
def resolve_paginated(cls, queryset, *, info, pagination=None, **kwargs):
return cls(
queryset=queryset,
pagination=pagination or OffsetPaginationInput(),
_custom_field="pagination rocks",
)

@strawberry.type
class Query:
fruits: FruitPaginated = strawberry_django.field()

models.Fruit.objects.create(name="Apple")
models.Fruit.objects.create(name="Banana")
models.Fruit.objects.create(name="Strawberry")

schema = strawberry.Schema(query=Query)

query = """\
query GetFruits ($pagination: OffsetPaginationInput) {
fruits (pagination: $pagination) {
totalCount
customField
results {
name
}
}
}
"""

res = schema.execute_sync(query)
assert res.errors is None
assert res.data == {
"fruits": {
"totalCount": 3,
"customField": "pagination rocks",
"results": [{"name": "Apple"}, {"name": "Banana"}, {"name": "Strawberry"}],
}
}

res = schema.execute_sync(query, variable_values={"pagination": {"limit": 1}})
assert res.errors is None
assert res.data == {
"fruits": {
"totalCount": 3,
"customField": "pagination rocks",
"results": [{"name": "Apple"}],
}
}

res = schema.execute_sync(
query, variable_values={"pagination": {"limit": 1, "offset": 2}}
)
assert res.errors is None
assert res.data == {
"fruits": {
"totalCount": 3,
"customField": "pagination rocks",
"results": [{"name": "Strawberry"}],
}
}

0 comments on commit 89074b1

Please sign in to comment.