Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(optimizer): Fix nested pagination optimization for m2m relations #681

Merged
merged 1 commit into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions strawberry_django/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,11 @@ def __init__(
self.enable_nested_relations_prefetch = enable_nested_relations_prefetch
self.prefetch_custom_queryset = prefetch_custom_queryset

if enable_nested_relations_prefetch:
from strawberry_django.utils.patches import apply_pagination_fix

apply_pagination_fix()

def on_execute(self) -> Generator[None]:
token = optimizer.set(self)
try:
Expand Down
78 changes: 78 additions & 0 deletions strawberry_django/utils/patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import django
from django.db import (
DEFAULT_DB_ALIAS,
NotSupportedError,
connections,
)
from django.db.models import Q, Window
from django.db.models.fields import related_descriptors
from django.db.models.functions import RowNumber
from django.db.models.lookups import GreaterThan, LessThanOrEqual
from django.db.models.sql import Query
from django.db.models.sql.constants import INNER
from django.db.models.sql.where import AND


def apply_pagination_fix():
"""Apply pagination fix for Django 5.1 or older.

This is based on the fix in this patch, which is going to be included in Django 5.2:
https://code.djangoproject.com/ticket/35677#comment:9

If can safely be removed when Django 5.2 is the minimum version we support
"""
if django.VERSION >= (5, 2):
return
Comment on lines +24 to +25
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!


# This is a copy of the function, exactly as it exists on Django 4.2, 5.0 and 5.1
# (there are no differences in this function between these versions)
def _filter_prefetch_queryset(queryset, field_name, instances):
predicate = Q(**{f"{field_name}__in": instances})
db = queryset._db or DEFAULT_DB_ALIAS
if queryset.query.is_sliced:
if not connections[db].features.supports_over_clause:
raise NotSupportedError(
"Prefetching from a limited queryset is only supported on backends "
"that support window functions."
)
low_mark, high_mark = queryset.query.low_mark, queryset.query.high_mark
order_by = [
expr for expr, _ in queryset.query.get_compiler(using=db).get_order_by()
]
window = Window(RowNumber(), partition_by=field_name, order_by=order_by)
predicate &= GreaterThan(window, low_mark)
if high_mark is not None:
predicate &= LessThanOrEqual(window, high_mark)
queryset.query.clear_limits()

# >> ORIGINAL CODE
# return queryset.filter(predicate) # noqa: ERA001
# << ORIGINAL CODE
# >> PATCHED CODE
queryset.query.add_q(predicate, reuse_all_aliases=True)
return queryset
# << PATCHED CODE

related_descriptors._filter_prefetch_queryset = _filter_prefetch_queryset # type: ignore

# This is a copy of the function, exactly as it exists on Django 4.2, 5.0 and 5.1
# (there are no differences in this function between these versions)
def add_q(self, q_object, reuse_all_aliases=False):
existing_inner = {
a for a in self.alias_map if self.alias_map[a].join_type == INNER
}
# >> ORIGINAL CODE
# clause, _ = self._add_q(q_object, self.used_aliases) # noqa: ERA001
# << ORIGINAL CODE
# >> PATCHED CODE
if reuse_all_aliases: # noqa: SIM108
can_reuse = set(self.alias_map)
else:
can_reuse = self.used_aliases
bellini666 marked this conversation as resolved.
Show resolved Hide resolved
clause, _ = self._add_q(q_object, can_reuse)
# << PATCHED CODE
if clause:
self.where.add(clause, AND)
self.demote_joins(existing_inner)

Query.add_q = add_q
10 changes: 8 additions & 2 deletions tests/projects/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,18 @@ class Meta:


class Issue(NamedModel):
comments: "RelatedManager[Issue]"
issue_assignees: "RelatedManager[Assignee]"
class Meta: # type: ignore
ordering = ("id",)

class Kind(models.TextChoices):
"""Issue kind options."""

BUG = "b", "Bug"
FEATURE = "f", "Feature"

comments: "RelatedManager[Issue]"
issue_assignees: "RelatedManager[Assignee]"

id = models.BigAutoField(
verbose_name="ID",
primary_key=True,
Expand Down Expand Up @@ -203,6 +206,9 @@ class Meta:


class Tag(NamedModel):
class Meta: # type: ignore
ordering = ("id",)

issues: "RelatedManager[Issue]"

id = models.BigAutoField(
Expand Down
76 changes: 75 additions & 1 deletion tests/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@
from strawberry.types import ExecutionResult

import strawberry_django
from strawberry_django.optimizer import DjangoOptimizerExtension
from strawberry_django.pagination import (
OffsetPaginationInput,
apply,
apply_window_pagination,
)
from tests import models, utils
from tests.projects.faker import MilestoneFactory, ProjectFactory
from tests.projects.faker import (
IssueFactory,
MilestoneFactory,
ProjectFactory,
TagFactory,
)
bellini666 marked this conversation as resolved.
Show resolved Hide resolved


@strawberry_django.type(models.Fruit, pagination=True)
Expand Down Expand Up @@ -145,3 +151,71 @@ def test_apply_window_pagination_with_no_limites(limit):
assert first_fruit.name == "fruit2"
assert first_fruit._strawberry_row_number == 3 # type: ignore
assert first_fruit._strawberry_total_count == 10 # type: ignore


@pytest.mark.django_db(transaction=True)
def test_nested_pagination_m2m(gql_client: utils.GraphQLTestClient):
# Create 2 tags and 3 issues
tags = [TagFactory(name=f"Tag {i + 1}") for i in range(2)]
issues = [IssueFactory(name=f"Issue {i + 1}") for i in range(3)]
# Assign issues 1 and 2 to the 1st tag
# Assign issues 2 and 3 to the 2nd tag
# This means that both tags share the 2nd issue
tags[0].issues.set(issues[:2])
tags[1].issues.set(issues[1:])
bellini666 marked this conversation as resolved.
Show resolved Hide resolved
with utils.assert_num_queries(3 if DjangoOptimizerExtension.enabled.get() else 6):
result = gql_client.query(
"""
query {
tagConn {
totalCount
edges {
node {
name
issues {
totalCount
edges {
node {
name
}
}
}
}
}
}
}
"""
)
# Check the results
assert not result.errors
assert result.data == {
"tagConn": {
"totalCount": 2,
"edges": [
{
"node": {
"name": "Tag 1",
"issues": {
"totalCount": 2,
"edges": [
{"node": {"name": "Issue 1"}},
{"node": {"name": "Issue 2"}},
],
},
}
},
{
"node": {
"name": "Tag 2",
"issues": {
"totalCount": 2,
"edges": [
{"node": {"name": "Issue 2"}},
{"node": {"name": "Issue 3"}},
],
},
}
},
],
}
}
Loading