Skip to content

Commit

Permalink
Make optimizer work with Paginated results
Browse files Browse the repository at this point in the history
  • Loading branch information
bellini666 committed Oct 17, 2024
1 parent 7279deb commit 2ca008c
Show file tree
Hide file tree
Showing 8 changed files with 339 additions and 21 deletions.
4 changes: 2 additions & 2 deletions strawberry_django/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ def resolve():
kwargs["info"] = info

result = django_resolver(
partial(self.get_wrapped_result, **kwargs),
qs_hook=self.get_queryset_hook(**kwargs),
self.get_queryset_hook(**kwargs),
qs_hook=partial(self.get_wrapped_result, **kwargs),
)(result)

return result
Expand Down
79 changes: 75 additions & 4 deletions strawberry_django/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from typing_extensions import assert_never, assert_type

from strawberry_django.fields.types import resolve_model_field_name
from strawberry_django.pagination import apply_window_pagination
from strawberry_django.pagination import Paginated, apply_window_pagination
from strawberry_django.queryset import get_queryset_config, run_type_get_queryset
from strawberry_django.relay import ListConnectionWithTotalCount
from strawberry_django.resolvers import django_fetch
Expand Down Expand Up @@ -528,7 +528,7 @@ def _optimize_prefetch_queryset(
)
field_kwargs.pop("info", None)

# Disable the optimizer to avoid doint double optimization while running get_queryset
# Disable the optimizer to avoid doing double optimization while running get_queryset
with DjangoOptimizerExtension.disabled():
qs = field.get_queryset(
qs,
Expand Down Expand Up @@ -574,6 +574,15 @@ def _optimize_prefetch_queryset(
else:
mark_optimized = False

if isinstance(field.type, type) and issubclass(field.type, Paginated):
pagination = field_kwargs.get("pagination")
qs = apply_window_pagination(
qs,
related_field_id=related_field_id,
offset=pagination.offset if pagination else 0,
limit=pagination.limit if pagination else -1,
)

if mark_optimized:
qs = mark_optimized_by_prefetching(qs)

Expand Down Expand Up @@ -977,8 +986,7 @@ def _get_model_hints(
) -> OptimizerStore | None:
cache = cache or {}

# In case this is a relay field, find the selected edges/nodes, the selected fields
# are actually inside edges -> node selection...
# In case this is a relay field, the selected fields are inside edges -> node selection
if issubclass(object_definition.origin, relay.Connection):
return _get_model_hints_from_connection(
model,
Expand All @@ -992,6 +1000,20 @@ def _get_model_hints(
level=level,
)

# In case this is a relay field, the selected fields are inside results selection
if issubclass(object_definition.origin, Paginated):
return _get_model_hints_from_paginated(
model,
schema,
object_definition,
parent_type=parent_type,
info=info,
config=config,
prefix=prefix,
cache=cache,
level=level,
)

store = OptimizerStore()
config = config or OptimizerConfig()

Expand Down Expand Up @@ -1156,6 +1178,55 @@ def _get_model_hints_from_connection(
return store


def _get_model_hints_from_paginated(
model: type[models.Model],
schema: Schema,
object_definition: StrawberryObjectDefinition,
*,
parent_type: GraphQLObjectType | GraphQLInterfaceType,
info: GraphQLResolveInfo,
config: OptimizerConfig | None = None,
prefix: str = "",
cache: dict[type[models.Model], list[tuple[int, OptimizerStore]]] | None = None,
level: int = 0,
) -> OptimizerStore | None:
store = None

n_type = object_definition.type_var_map.get("NodeType")
n_definition = get_object_definition(n_type, strict=True)
n_gql_definition = _get_gql_definition(
schema,
get_object_definition(n_type, strict=True),
)
assert isinstance(n_gql_definition, (GraphQLObjectType, GraphQLInterfaceType))

for selections in _get_selections(info, parent_type).values():
selection = selections[0]
if selection.name.value != "results":
continue

n_info = _generate_selection_resolve_info(
info,
selections,
n_gql_definition,
n_gql_definition,
)

store = _get_model_hints(
model=model,
schema=schema,
object_definition=n_definition,
parent_type=n_gql_definition,
info=n_info,
config=config,
prefix=prefix,
cache=cache,
level=level,
)

return store


def optimize(
qs: QuerySet[_M] | BaseManager[_M],
info: GraphQLResolveInfo | Info,
Expand Down
51 changes: 38 additions & 13 deletions strawberry_django/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def offset(self) -> int:

@strawberry.field(description="Total count of existing results.")
@django_resolver
def total_count(self) -> int:
def total_count(self, root) -> int:
return get_total_count(self.queryset)

@strawberry.field(description="List of paginated results.")
Expand Down Expand Up @@ -103,6 +103,10 @@ def apply(
return queryset


class _PaginationWindow(Window):
"""Marker to be able to remove where clause at `get_total_count` if needed."""


def apply_window_pagination(
queryset: _QS,
*,
Expand Down Expand Up @@ -131,13 +135,14 @@ def apply_window_pagination(
using=queryset._db or DEFAULT_DB_ALIAS # type: ignore
).get_order_by()
]

queryset = queryset.annotate(
_strawberry_row_number=Window(
_strawberry_row_number=_PaginationWindow(
RowNumber(),
partition_by=related_field_id,
order_by=order_by,
),
_strawberry_total_count=Window(
_strawberry_total_count=_PaginationWindow(
Count(1),
partition_by=related_field_id,
),
Expand Down Expand Up @@ -165,17 +170,31 @@ def get_total_count(queryset: QuerySet) -> int:
if is_optimized_by_prefetching(queryset):
results = queryset._result_cache # type: ignore

try:
return results[0]._strawberry_total_count if results else 0
except AttributeError:
warnings.warn(
(
"Pagination annotations not found, falling back to QuerySet resolution. "
"This might cause n+1 issues..."
),
RuntimeWarning,
stacklevel=2,
if results:
try:
return results[0]._strawberry_total_count
except AttributeError:
warnings.warn(
(
"Pagination annotations not found, falling back to QuerySet resolution. "
"This might cause n+1 issues..."
),
RuntimeWarning,
stacklevel=2,
)

# If we have no results, we can't get the total count from the cache.
# In this case we will remove the pagination filter to be able to `.count()`
# the whole queryset with its original filters.
queryset = queryset._chain() # type: ignore
queryset.query.where.children = [
child
for child in queryset.query.where.children
if (
not hasattr(child, "lhs")
or not isinstance(child.lhs, _PaginationWindow)
)
]

return queryset.count()

Expand Down Expand Up @@ -243,6 +262,12 @@ def get_queryset(
**kwargs,
) -> _QS:
queryset = super().get_queryset(queryset, info, **kwargs)

# If this is `Paginated`, return the queryset as is as the pagination will
# be resolved when resolving its results.
if self.is_paginated:
return queryset

return self.apply_pagination(
queryset,
pagination,
Expand Down
4 changes: 2 additions & 2 deletions strawberry_django/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def resolve_base_manager(manager: BaseManager) -> Any:
# prevents us from importing and checking isinstance on them directly.
try:
# ManyRelatedManager
return list(prefetched_cache[manager.prefetch_cache_name]) # type: ignore
return prefetched_cache[manager.prefetch_cache_name] # type: ignore
except (AttributeError, KeyError):
try:
# RelatedManager
Expand All @@ -203,7 +203,7 @@ def resolve_base_manager(manager: BaseManager) -> Any:
getattr(result_field.remote_field, "cache_name", None)
or result_field.remote_field.get_cache_name()
)
return list(prefetched_cache[cache_name])
return prefetched_cache[cache_name]
except (AttributeError, KeyError):
pass

Expand Down
14 changes: 14 additions & 0 deletions tests/projects/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from strawberry_django.fields.types import ListInput, NodeInput, NodeInputPartial
from strawberry_django.mutations import resolvers
from strawberry_django.optimizer import DjangoOptimizerExtension
from strawberry_django.pagination import Paginated
from strawberry_django.permissions import (
HasPerm,
HasRetvalPerm,
Expand Down Expand Up @@ -158,6 +159,10 @@ class MilestoneType(relay.Node, Named):
order=IssueOrder,
pagination=True,
)
issues_paginated: Paginated["IssueType"] = strawberry_django.field(
field_name="issues",
order=IssueOrder,
)
issues_with_filters: ListConnectionWithTotalCount["IssueType"] = (
strawberry_django.connection(
field_name="issues",
Expand Down Expand Up @@ -375,6 +380,7 @@ class Query:
staff_list: List[Optional[StaffType]] = strawberry_django.node()

issue_list: List[IssueType] = strawberry_django.field()
issues_paginated: Paginated[IssueType] = strawberry_django.field()
milestone_list: List[MilestoneType] = strawberry_django.field(
order=MilestoneOrder,
filters=MilestoneFilter,
Expand Down Expand Up @@ -429,6 +435,9 @@ class Query:
issue_list_perm_required: List[IssueType] = strawberry_django.field(
extensions=[HasPerm(perms=["projects.view_issue"])],
)
issue_paginated_list_perm_required: Paginated[IssueType] = strawberry_django.field(
extensions=[HasPerm(perms=["projects.view_issue"])],
)
issue_conn_perm_required: ListConnectionWithTotalCount[IssueType] = (
strawberry_django.connection(
extensions=[HasPerm(perms=["projects.view_issue"])],
Expand All @@ -447,6 +456,11 @@ class Query:
issue_list_obj_perm_required_paginated: List[IssueType] = strawberry_django.field(
extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True
)
issue_paginated_list_obj_perm_required_paginated: Paginated[IssueType] = (
strawberry_django.field(
extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True
)
)
issue_conn_obj_perm_required: ListConnectionWithTotalCount[IssueType] = (
strawberry_django.connection(
extensions=[HasRetvalPerm(perms=["projects.view_issue"])],
Expand Down
15 changes: 15 additions & 0 deletions tests/projects/snapshots/schema.gql
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,17 @@ type IssueTypeEdge {
node: IssueType!
}

type IssueTypePaginated {
limit: Int!
offset: Int!

"""Total count of existing results."""
totalCount: Int!

"""List of paginated results."""
results: [IssueType!]!
}

input MilestoneFilter {
name: StrFilterLookup
project: DjangoModelFilterInput
Expand Down Expand Up @@ -373,6 +384,7 @@ type MilestoneType implements Node & Named {
dueDate: Date
project: ProjectType!
issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]!
issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated!
issuesWithFilters(
filters: IssueFilter

Expand Down Expand Up @@ -610,6 +622,7 @@ type Query {
ids: [GlobalID!]!
): [StaffType]!
issueList: [IssueType!]!
issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated!
milestoneList(filters: MilestoneFilter, order: MilestoneOrder, pagination: OffsetPaginationInput): [MilestoneType!]!
projectList(filters: ProjectFilter): [ProjectType!]!
tagList: [TagType!]!
Expand Down Expand Up @@ -729,6 +742,7 @@ type Query {
id: GlobalID!
): IssueType @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issueListPermRequired: [IssueType!]! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issuePaginatedListPermRequired(pagination: OffsetPaginationInput): IssueTypePaginated! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issueConnPermRequired(
"""Returns the items in the list that come before the specified cursor."""
before: String = null
Expand All @@ -752,6 +766,7 @@ type Query {
): IssueType @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issueListObjPermRequired: [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issueListObjPermRequiredPaginated(pagination: OffsetPaginationInput): [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issuePaginatedListObjPermRequiredPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issueConnObjPermRequired(
"""Returns the items in the list that come before the specified cursor."""
before: String = null
Expand Down
13 changes: 13 additions & 0 deletions tests/projects/snapshots/schema_with_inheritance.gql
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,17 @@ type IssueTypeEdge {
node: IssueType!
}

type IssueTypePaginated {
limit: Int!
offset: Int!

"""Total count of existing results."""
totalCount: Int!

"""List of paginated results."""
results: [IssueType!]!
}

input MilestoneFilter {
name: StrFilterLookup
project: DjangoModelFilterInput
Expand Down Expand Up @@ -163,6 +174,7 @@ type MilestoneType implements Node & Named {
dueDate: Date
project: ProjectType!
issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]!
issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated!
issuesWithFilters(
filters: IssueFilter

Expand Down Expand Up @@ -190,6 +202,7 @@ type MilestoneTypeSubclass implements Node & Named {
dueDate: Date
project: ProjectType!
issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]!
issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated!
issuesWithFilters(
filters: IssueFilter

Expand Down
Loading

0 comments on commit 2ca008c

Please sign in to comment.