Skip to content

Commit

Permalink
fix(optimizer): Convert select_related into Prefetch when the type de…
Browse files Browse the repository at this point in the history
…fines a custom get_queryset (#583)
  • Loading branch information
bellini666 authored Jul 14, 2024
1 parent 3aa993a commit c7c50e9
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 22 deletions.
2 changes: 2 additions & 0 deletions strawberry_django/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ def default_resolver(

return cast(Iterable[Any], retval)

default_resolver.can_optimize = True # type: ignore

field.base_resolver = StrawberryResolver(default_resolver)

return super().apply(field)
Expand Down
68 changes: 58 additions & 10 deletions strawberry_django/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
PrefetchType,
TypeOrMapping,
TypeOrSequence,
WithStrawberryDjangoObjectDefinition,
get_django_definition,
has_django_definition,
)
Expand Down Expand Up @@ -422,6 +423,20 @@ def _apply_annotate(
return qs.annotate(**to_annotate)


def _get_django_type(
field: StrawberryField,
) -> type[WithStrawberryDjangoObjectDefinition] | None:
f_type = field.type
if isinstance(f_type, LazyType):
f_type = f_type.resolve_type()
if isinstance(f_type, StrawberryContainer):
f_type = f_type.of_type
if isinstance(f_type, LazyType):
f_type = f_type.resolve_type()

return f_type if has_django_definition(f_type) else None


def _get_prefetch_queryset(
remote_model: type[models.Model],
schema: Schema,
Expand All @@ -443,15 +458,7 @@ def _get_prefetch_queryset(
else:
qs = remote_model._base_manager.all() # type: ignore

f_type = field.type
if isinstance(f_type, LazyType):
f_type = f_type.resolve_type()
while isinstance(f_type, StrawberryContainer):
f_type = f_type.of_type
if isinstance(f_type, LazyType):
f_type = f_type.resolve_type()

if has_django_definition(f_type):
if f_type := _get_django_type(field):
qs = run_type_get_queryset(
qs,
f_type,
Expand Down Expand Up @@ -690,16 +697,39 @@ def _get_hints_from_model_property(
def _get_hints_from_django_foreign_key(
field: StrawberryField,
field_definition: GraphQLObjectType,
field_selection: FieldNode,
model_field: models.ForeignKey | OneToOneRel,
model_fieldname: str,
schema: Schema,
*,
config: OptimizerConfig,
parent_type: GraphQLObjectType | GraphQLInterfaceType,
field_info: GraphQLResolveInfo,
path: str,
cache: dict[type[models.Model], list[tuple[int, OptimizerStore]]],
level: int = 0,
) -> OptimizerStore:
f_type = _get_django_type(field)
if f_type and hasattr(f_type, "get_queryset"):
# If the field has a get_queryset method, change strategy to Prefetch
# so it will be respected
store = _get_hints_from_django_relation(
field,
field_definition=field_definition,
field_selection=field_selection,
model_field=model_field,
model_fieldname=model_fieldname,
schema=schema,
config=config,
parent_type=parent_type,
field_info=field_info,
path=path,
cache=cache,
level=level,
)
store.only.append(path)
return store

store = OptimizerStore.with_hints(
only=[path],
select_related=[path],
Expand Down Expand Up @@ -737,7 +767,12 @@ def _get_hints_from_django_relation(
field_definition: GraphQLObjectType,
field_selection: FieldNode,
model_field: (
models.ManyToManyField | ManyToManyRel | ManyToOneRel | GenericRelation
models.ManyToManyField
| ManyToManyRel
| ManyToOneRel
| GenericRelation
| OneToOneRel
| models.ForeignKey
),
model_fieldname: str,
schema: Schema,
Expand Down Expand Up @@ -861,6 +896,17 @@ def _get_hints_from_django_field(
GenericRelation,
)

# If the field has a base resolver, don't try to optimize it. The user should
# be defining custom hints in this case, which should already be in the store
# GlobalID and special cases setting `can_optimize` are ok though, as those resolvers
# are auto generated by us
if (
field.base_resolver is not None
and field.type != relay.GlobalID
and not getattr(field.base_resolver.wrapped_func, "can_optimize", False)
):
return None

model_fieldname: str = getattr(field, "django_name", None) or field.python_name
if (model_field := get_model_field(model, model_fieldname)) is None:
return None
Expand All @@ -871,10 +917,12 @@ def _get_hints_from_django_field(
store = _get_hints_from_django_foreign_key(
field,
field_definition=field_definition,
field_selection=field_selection,
model_field=model_field,
model_fieldname=model_fieldname,
schema=schema,
config=config,
parent_type=parent_type,
field_info=field_info,
path=path,
cache=cache,
Expand Down
14 changes: 6 additions & 8 deletions tests/projects/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ class MilestoneType(relay.Node, Named):
order=IssueOrder,
pagination=True,
)
issues_with_filters: ListConnectionWithTotalCount["IssueType"] = (
strawberry_django.connection(
field_name="issues",
filters=IssueFilter,
)
)

@strawberry_django.field(
prefetch_related=[
Expand All @@ -178,14 +184,6 @@ class MilestoneType(relay.Node, Named):
def my_issues(self) -> List["IssueType"]:
return self._my_issues # type: ignore

@strawberry_django.connection(
ListConnectionWithTotalCount["IssueType"],
field_name="issues",
filters=IssueFilter,
)
def issues_with_filters(self) -> List["IssueType"]:
return self.issues.all() # type: ignore

@strawberry_django.field(
annotate={
"_my_bugs_count": lambda info: Count(
Expand Down
2 changes: 1 addition & 1 deletion tests/projects/snapshots/schema.gql
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,6 @@ type MilestoneType implements Node & Named {
dueDate: Date
project: ProjectType!
issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]!
myIssues: [IssueType!]!
issuesWithFilters(
filters: IssueFilter

Expand All @@ -389,6 +388,7 @@ type MilestoneType implements Node & Named {
"""Returns the items in the list that come after the specified cursor."""
last: Int = null
): IssueTypeConnection!
myIssues: [IssueType!]!
myBugsCount: Int!
asyncField(value: String!): String!
}
Expand Down
4 changes: 2 additions & 2 deletions tests/projects/snapshots/schema_with_inheritance.gql
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ type MilestoneType implements Node & Named {
dueDate: Date
project: ProjectType!
issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]!
myIssues: [IssueType!]!
issuesWithFilters(
filters: IssueFilter

Expand All @@ -179,6 +178,7 @@ type MilestoneType implements Node & Named {
"""Returns the items in the list that come after the specified cursor."""
last: Int = null
): IssueTypeConnection!
myIssues: [IssueType!]!
myBugsCount: Int!
asyncField(value: String!): String!
}
Expand All @@ -190,7 +190,6 @@ type MilestoneTypeSubclass implements Node & Named {
dueDate: Date
project: ProjectType!
issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]!
myIssues: [IssueType!]!
issuesWithFilters(
filters: IssueFilter

Expand All @@ -206,6 +205,7 @@ type MilestoneTypeSubclass implements Node & Named {
"""Returns the items in the list that come after the specified cursor."""
last: Int = null
): IssueTypeConnection!
myIssues: [IssueType!]!
myBugsCount: Int!
asyncField(value: String!): String!
}
Expand Down
54 changes: 53 additions & 1 deletion tests/test_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime
from typing import Any, List, cast
from typing import Any, List, Optional, cast

import pytest
import strawberry
Expand Down Expand Up @@ -1365,3 +1365,55 @@ def test_nested_prefetch_with_get_queryset(
},
}
mock_get_queryset.assert_called_once()


@pytest.mark.django_db(transaction=True)
def test_select_related_fallsback_to_prefetch_when_type_defines_get_queryset(
db,
):
@strawberry_django.type(Milestone)
class MilestoneType:
pk: strawberry.ID

@classmethod
def get_queryset(cls, queryset, info, **kwargs):
return queryset.filter(name__startswith="Foo")

@strawberry_django.type(Issue)
class IssueType:
pk: strawberry.ID
milestone: Optional[MilestoneType]

@strawberry.type
class Query:
issues: List[IssueType] = strawberry_django.field()

schema = strawberry.Schema(query=Query, extensions=[DjangoOptimizerExtension])

milestone1 = MilestoneFactory.create(name="Foo")
milestone2 = MilestoneFactory.create(name="Bar")

issue1 = IssueFactory.create(milestone=milestone1)
issue2 = IssueFactory.create(milestone=milestone2)

query = """\
query TestQuery {
issues {
pk
milestone {
pk
}
}
}
"""

with assert_num_queries(2):
res = schema.execute_sync(query)

assert res.errors is None
assert res.data == {
"issues": [
{"pk": str(issue1.pk), "milestone": {"pk": str(milestone1.pk)}},
{"pk": str(issue2.pk), "milestone": None},
],
}

0 comments on commit c7c50e9

Please sign in to comment.