From 253ccc7cb3a46e6605040d469f03bc21e2dd0b39 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Thu, 17 Oct 2024 00:09:35 -0300 Subject: [PATCH] Make optimizer work with Paginated results --- strawberry_django/fields/field.py | 4 +- strawberry_django/optimizer.py | 79 +++++++- strawberry_django/pagination.py | 48 +++-- strawberry_django/resolvers.py | 4 +- tests/projects/schema.py | 14 ++ tests/projects/snapshots/schema.gql | 15 ++ .../snapshots/schema_with_inheritance.gql | 13 ++ tests/test_optimizer.py | 180 ++++++++++++++++++ 8 files changed, 336 insertions(+), 21 deletions(-) diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index 859c64be..f622b8a1 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -248,8 +248,8 @@ def resolve(): kwargs["info"] = info result = django_resolver( - partial(self.get_wrapped_result, **kwargs), - qs_hook=self.get_queryset_hook(**kwargs), + self.get_queryset_hook(**kwargs), + qs_hook=partial(self.get_wrapped_result, **kwargs), )(result) return result diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index 468438cb..45f1ea2c 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -51,7 +51,7 @@ from typing_extensions import assert_never, assert_type from strawberry_django.fields.types import resolve_model_field_name -from strawberry_django.pagination import apply_window_pagination +from strawberry_django.pagination import Paginated, apply_window_pagination from strawberry_django.queryset import get_queryset_config, run_type_get_queryset from strawberry_django.relay import ListConnectionWithTotalCount from strawberry_django.resolvers import django_fetch @@ -528,7 +528,7 @@ def _optimize_prefetch_queryset( ) field_kwargs.pop("info", None) - # Disable the optimizer to avoid doint double optimization while running get_queryset + # Disable the optimizer to avoid doing double optimization while running get_queryset with DjangoOptimizerExtension.disabled(): qs = field.get_queryset( qs, @@ -574,6 +574,15 @@ def _optimize_prefetch_queryset( else: mark_optimized = False + if isinstance(field.type, type) and issubclass(field.type, Paginated): + pagination = field_kwargs.get("pagination") + qs = apply_window_pagination( + qs, + related_field_id=related_field_id, + offset=pagination.offset if pagination else 0, + limit=pagination.limit if pagination else -1, + ) + if mark_optimized: qs = mark_optimized_by_prefetching(qs) @@ -977,8 +986,7 @@ def _get_model_hints( ) -> OptimizerStore | None: cache = cache or {} - # In case this is a relay field, find the selected edges/nodes, the selected fields - # are actually inside edges -> node selection... + # In case this is a relay field, the selected fields are inside edges -> node selection if issubclass(object_definition.origin, relay.Connection): return _get_model_hints_from_connection( model, @@ -992,6 +1000,20 @@ def _get_model_hints( level=level, ) + # In case this is a relay field, the selected fields are inside results selection + if issubclass(object_definition.origin, Paginated): + return _get_model_hints_from_paginated( + model, + schema, + object_definition, + parent_type=parent_type, + info=info, + config=config, + prefix=prefix, + cache=cache, + level=level, + ) + store = OptimizerStore() config = config or OptimizerConfig() @@ -1156,6 +1178,55 @@ def _get_model_hints_from_connection( return store +def _get_model_hints_from_paginated( + model: type[models.Model], + schema: Schema, + object_definition: StrawberryObjectDefinition, + *, + parent_type: GraphQLObjectType | GraphQLInterfaceType, + info: GraphQLResolveInfo, + config: OptimizerConfig | None = None, + prefix: str = "", + cache: dict[type[models.Model], list[tuple[int, OptimizerStore]]] | None = None, + level: int = 0, +) -> OptimizerStore | None: + store = None + + n_type = object_definition.type_var_map.get("NodeType") + n_definition = get_object_definition(n_type, strict=True) + n_gql_definition = _get_gql_definition( + schema, + get_object_definition(n_type, strict=True), + ) + assert isinstance(n_gql_definition, (GraphQLObjectType, GraphQLInterfaceType)) + + for selections in _get_selections(info, parent_type).values(): + selection = selections[0] + if selection.name.value != "results": + continue + + n_info = _generate_selection_resolve_info( + info, + selections, + n_gql_definition, + n_gql_definition, + ) + + store = _get_model_hints( + model=model, + schema=schema, + object_definition=n_definition, + parent_type=n_gql_definition, + info=n_info, + config=config, + prefix=prefix, + cache=cache, + level=level, + ) + + return store + + def optimize( qs: QuerySet[_M] | BaseManager[_M], info: GraphQLResolveInfo | Info, diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index 446b1e6a..7eb5fdcf 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -45,7 +45,7 @@ def offset(self) -> int: @strawberry.field(description="Total count of existing results.") @django_resolver - def total_count(self) -> int: + def total_count(self, root) -> int: return get_total_count(self.queryset) @strawberry.field(description="List of paginated results.") @@ -103,6 +103,10 @@ def apply( return queryset +class _PaginationWindow(Window): + """Marker to be able to remove where clause at `get_total_count` if needed.""" + + def apply_window_pagination( queryset: _QS, *, @@ -131,13 +135,14 @@ def apply_window_pagination( using=queryset._db or DEFAULT_DB_ALIAS # type: ignore ).get_order_by() ] + queryset = queryset.annotate( - _strawberry_row_number=Window( + _strawberry_row_number=_PaginationWindow( RowNumber(), partition_by=related_field_id, order_by=order_by, ), - _strawberry_total_count=Window( + _strawberry_total_count=_PaginationWindow( Count(1), partition_by=related_field_id, ), @@ -165,17 +170,28 @@ def get_total_count(queryset: QuerySet) -> int: if is_optimized_by_prefetching(queryset): results = queryset._result_cache # type: ignore - try: - return results[0]._strawberry_total_count if results else 0 - except AttributeError: - warnings.warn( - ( - "Pagination annotations not found, falling back to QuerySet resolution. " - "This might cause n+1 issues..." - ), - RuntimeWarning, - stacklevel=2, + if results: + try: + return results[0]._strawberry_total_count + except AttributeError: + warnings.warn( + ( + "Pagination annotations not found, falling back to QuerySet resolution. " + "This might cause n+1 issues..." + ), + RuntimeWarning, + stacklevel=2, + ) + + queryset = queryset._chain() # type: ignore + queryset.query.where.children = [ + child + for child in queryset.query.where.children + if ( + not hasattr(child, "lhs") + or not isinstance(child.lhs, _PaginationWindow) ) + ] return queryset.count() @@ -243,6 +259,12 @@ def get_queryset( **kwargs, ) -> _QS: queryset = super().get_queryset(queryset, info, **kwargs) + + # If this is `Paginated`, return the queryset as is as the pagination will + # be resolved when resolving its results. + if self.is_paginated: + return queryset + return self.apply_pagination( queryset, pagination, diff --git a/strawberry_django/resolvers.py b/strawberry_django/resolvers.py index c4e186b0..491b486e 100644 --- a/strawberry_django/resolvers.py +++ b/strawberry_django/resolvers.py @@ -193,7 +193,7 @@ def resolve_base_manager(manager: BaseManager) -> Any: # prevents us from importing and checking isinstance on them directly. try: # ManyRelatedManager - return list(prefetched_cache[manager.prefetch_cache_name]) # type: ignore + return prefetched_cache[manager.prefetch_cache_name] # type: ignore except (AttributeError, KeyError): try: # RelatedManager @@ -203,7 +203,7 @@ def resolve_base_manager(manager: BaseManager) -> Any: getattr(result_field.remote_field, "cache_name", None) or result_field.remote_field.get_cache_name() ) - return list(prefetched_cache[cache_name]) + return prefetched_cache[cache_name] except (AttributeError, KeyError): pass diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 83097832..4b62f5ed 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -29,6 +29,7 @@ from strawberry_django.fields.types import ListInput, NodeInput, NodeInputPartial from strawberry_django.mutations import resolvers from strawberry_django.optimizer import DjangoOptimizerExtension +from strawberry_django.pagination import Paginated from strawberry_django.permissions import ( HasPerm, HasRetvalPerm, @@ -158,6 +159,10 @@ class MilestoneType(relay.Node, Named): order=IssueOrder, pagination=True, ) + issues_paginated: Paginated["IssueType"] = strawberry_django.field( + field_name="issues", + order=IssueOrder, + ) issues_with_filters: ListConnectionWithTotalCount["IssueType"] = ( strawberry_django.connection( field_name="issues", @@ -375,6 +380,7 @@ class Query: staff_list: List[Optional[StaffType]] = strawberry_django.node() issue_list: List[IssueType] = strawberry_django.field() + issues_paginated: Paginated[IssueType] = strawberry_django.field() milestone_list: List[MilestoneType] = strawberry_django.field( order=MilestoneOrder, filters=MilestoneFilter, @@ -429,6 +435,9 @@ class Query: issue_list_perm_required: List[IssueType] = strawberry_django.field( extensions=[HasPerm(perms=["projects.view_issue"])], ) + issue_paginated_list_perm_required: Paginated[IssueType] = strawberry_django.field( + extensions=[HasPerm(perms=["projects.view_issue"])], + ) issue_conn_perm_required: ListConnectionWithTotalCount[IssueType] = ( strawberry_django.connection( extensions=[HasPerm(perms=["projects.view_issue"])], @@ -447,6 +456,11 @@ class Query: issue_list_obj_perm_required_paginated: List[IssueType] = strawberry_django.field( extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True ) + issue_paginated_list_obj_perm_required_paginated: Paginated[IssueType] = ( + strawberry_django.field( + extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True + ) + ) issue_conn_obj_perm_required: ListConnectionWithTotalCount[IssueType] = ( strawberry_django.connection( extensions=[HasRetvalPerm(perms=["projects.view_issue"])], diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index 16533658..212a75ed 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -336,6 +336,17 @@ type IssueTypeEdge { node: IssueType! } +type IssueTypePaginated { + limit: Int! + offset: Int! + + """Total count of existing results.""" + totalCount: Int! + + """List of paginated results.""" + results: [IssueType!]! +} + input MilestoneFilter { name: StrFilterLookup project: DjangoModelFilterInput @@ -373,6 +384,7 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! + issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! issuesWithFilters( filters: IssueFilter @@ -610,6 +622,7 @@ type Query { ids: [GlobalID!]! ): [StaffType]! issueList: [IssueType!]! + issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! milestoneList(filters: MilestoneFilter, order: MilestoneOrder, pagination: OffsetPaginationInput): [MilestoneType!]! projectList(filters: ProjectFilter): [ProjectType!]! tagList: [TagType!]! @@ -729,6 +742,7 @@ type Query { id: GlobalID! ): IssueType @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListPermRequired: [IssueType!]! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) + issuePaginatedListPermRequired(pagination: OffsetPaginationInput): IssueTypePaginated! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueConnPermRequired( """Returns the items in the list that come before the specified cursor.""" before: String = null @@ -752,6 +766,7 @@ type Query { ): IssueType @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListObjPermRequired: [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListObjPermRequiredPaginated(pagination: OffsetPaginationInput): [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) + issuePaginatedListObjPermRequiredPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueConnObjPermRequired( """Returns the items in the list that come before the specified cursor.""" before: String = null diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index 8c891a65..2dcf1e1a 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -136,6 +136,17 @@ type IssueTypeEdge { node: IssueType! } +type IssueTypePaginated { + limit: Int! + offset: Int! + + """Total count of existing results.""" + totalCount: Int! + + """List of paginated results.""" + results: [IssueType!]! +} + input MilestoneFilter { name: StrFilterLookup project: DjangoModelFilterInput @@ -163,6 +174,7 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! + issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! issuesWithFilters( filters: IssueFilter @@ -190,6 +202,7 @@ type MilestoneTypeSubclass implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! + issuesPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! issuesWithFilters( filters: IssueFilter diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 9e456e9d..b59ac51d 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1425,3 +1425,183 @@ class Query: "issues": [{"pk": str(issue1.pk)}], }, } + + +@pytest.mark.django_db(transaction=True) +def test_query_paginated(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery ($pagination: OffsetPaginationInput) { + issuesPaginated (pagination: $pagination) { + totalCount + results { + name + milestone { + name + } + } + } + } + """ + + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + issue1 = IssueFactory.create(milestone=milestone1) + issue2 = IssueFactory.create(milestone=milestone1) + issue3 = IssueFactory.create(milestone=milestone1) + issue4 = IssueFactory.create(milestone=milestone2) + issue5 = IssueFactory.create(milestone=milestone2) + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 7): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginated": { + "totalCount": 5, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + {"name": issue3.name, "milestone": {"name": milestone1.name}}, + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + {"name": issue5.name, "milestone": {"name": milestone2.name}}, + ], + } + } + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 4): + res = gql_client.query(query, variables={"pagination": {"limit": 2}}) + + assert res.data == { + "issuesPaginated": { + "totalCount": 5, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + ], + } + } + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 4): + res = gql_client.query( + query, variables={"pagination": {"limit": 2, "offset": 2}} + ) + + assert res.data == { + "issuesPaginated": { + "totalCount": 5, + "results": [ + {"name": issue3.name, "milestone": {"name": milestone1.name}}, + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_query_paginated_nested(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery ($pagination: OffsetPaginationInput) { + milestoneList { + name + issuesPaginated (pagination: $pagination) { + totalCount + results { + name + milestone { + name + } + } + } + } + } + """ + + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + issue1 = IssueFactory.create(milestone=milestone1) + issue2 = IssueFactory.create(milestone=milestone1) + issue3 = IssueFactory.create(milestone=milestone1) + issue4 = IssueFactory.create(milestone=milestone2) + issue5 = IssueFactory.create(milestone=milestone2) + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 5): + res = gql_client.query(query) + + assert res.data == { + "milestoneList": [ + { + "name": milestone1.name, + "issuesPaginated": { + "totalCount": 3, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + {"name": issue3.name, "milestone": {"name": milestone1.name}}, + ], + }, + }, + { + "name": milestone2.name, + "issuesPaginated": { + "totalCount": 2, + "results": [ + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + {"name": issue5.name, "milestone": {"name": milestone2.name}}, + ], + }, + }, + ] + } + + with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 5): + res = gql_client.query(query, variables={"pagination": {"limit": 1}}) + + assert res.data == { + "milestoneList": [ + { + "name": milestone1.name, + "issuesPaginated": { + "totalCount": 3, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + ], + }, + }, + { + "name": milestone2.name, + "issuesPaginated": { + "totalCount": 2, + "results": [ + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + ], + }, + }, + ] + } + + with assert_num_queries(3 if DjangoOptimizerExtension.enabled.get() else 5): + res = gql_client.query( + query, variables={"pagination": {"limit": 1, "offset": 2}} + ) + + assert res.data == { + "milestoneList": [ + { + "name": milestone1.name, + "issuesPaginated": { + "totalCount": 3, + "results": [ + {"name": issue3.name, "milestone": {"name": milestone1.name}}, + ], + }, + }, + { + "name": milestone2.name, + "issuesPaginated": { + "totalCount": 2, + "results": [], + }, + }, + ] + }