Skip to content

Commit

Permalink
test: add test case for optimizer using interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
bellini666 committed Jun 12, 2024
1 parent 3481e8d commit da61de8
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 35 deletions.
30 changes: 13 additions & 17 deletions tests/projects/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,16 @@
User = get_user_model()


class Project(models.Model):
class NamedModel(models.Model):
class Meta: # type: ignore
abstract = True

name = models.CharField(
max_length=255,
)


class Project(NamedModel):
class Status(models.TextChoices):
"""Project status options."""

Expand All @@ -33,10 +42,6 @@ class Status(models.TextChoices):
choices_enum=Status,
default=Status.ACTIVE,
)
name = models.CharField(
help_text="The name of the project",
max_length=255,
)
due_date = models.DateField(
null=True,
blank=True,
Expand All @@ -55,16 +60,13 @@ def is_small(self) -> bool:
return self._milestone_count < 3 # type: ignore


class Milestone(models.Model):
class Milestone(NamedModel):
issues: "RelatedManager[Issue]"

id = models.BigAutoField(
verbose_name="ID",
primary_key=True,
)
name = models.CharField(
max_length=255,
)
due_date = models.DateField(
null=True,
blank=True,
Expand Down Expand Up @@ -112,7 +114,7 @@ class Meta: # type: ignore
objects = FavoriteQuerySet.as_manager()


class Issue(models.Model):
class Issue(NamedModel):
comments: "RelatedManager[Issue]"
issue_assignees: "RelatedManager[Assignee]"

Expand All @@ -126,9 +128,6 @@ class Kind(models.TextChoices):
verbose_name="ID",
primary_key=True,
)
name = models.CharField(
max_length=255,
)
kind = models.CharField(
verbose_name="kind",
help_text="the kind of the issue",
Expand Down Expand Up @@ -203,16 +202,13 @@ class Meta: # type: ignore
)


class Tag(models.Model):
class Tag(NamedModel):
issues: "RelatedManager[Issue]"

id = models.BigAutoField(
verbose_name="ID",
primary_key=True,
)
name = models.CharField(
max_length=255,
)


class Quiz(models.Model):
Expand Down
18 changes: 10 additions & 8 deletions tests/projects/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
FavoriteQuerySet,
Issue,
Milestone,
NamedModel,
Project,
Quiz,
Tag,
Expand All @@ -53,6 +54,11 @@
UserModel = cast(Type[AbstractUser], get_user_model())


@strawberry_django.interface(NamedModel)
class Named:
name: strawberry.auto


@strawberry_django.type(UserModel)
class UserType(relay.Node):
username: relay.NodeID[str]
Expand Down Expand Up @@ -91,8 +97,7 @@ class ProjectFilter:


@strawberry_django.type(Project, filters=ProjectFilter)
class ProjectType(relay.Node):
name: strawberry.auto
class ProjectType(relay.Node, Named):
due_date: strawberry.auto
milestones: List["MilestoneType"]
milestones_count: int = strawberry_django.field(annotate=Count("milestone"))
Expand Down Expand Up @@ -143,8 +148,7 @@ class IssueOrder:


@strawberry_django.type(Milestone, filters=MilestoneFilter, order=MilestoneOrder)
class MilestoneType(relay.Node):
name: strawberry.auto
class MilestoneType(relay.Node, Named):
due_date: strawberry.auto
project: ProjectType
issues: List["IssueType"] = strawberry_django.field(
Expand Down Expand Up @@ -212,8 +216,7 @@ def get_queryset(cls, queryset: FavoriteQuerySet, info: Info, **kwargs) -> Query


@strawberry_django.type(Issue)
class IssueType(relay.Node):
name: strawberry.auto
class IssueType(relay.Node, Named):
milestone: MilestoneType
priority: strawberry.auto
kind: strawberry.auto
Expand Down Expand Up @@ -251,8 +254,7 @@ def private_name(self, root: Issue) -> Optional[str]:


@strawberry_django.type(Tag)
class TagType(relay.Node):
name: strawberry.auto
class TagType(relay.Node, Named):
issues: ListConnectionWithTotalCount[IssueType] = strawberry_django.connection()

@strawberry_django.field
Expand Down
12 changes: 8 additions & 4 deletions tests/projects/snapshots/schema.gql
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ input IssueOrder {
name: Ordering
}

type IssueType implements Node {
type IssueType implements Node & Named {
"""The Globally Unique ID of this object"""
id: GlobalID!
name: String!
Expand Down Expand Up @@ -365,7 +365,7 @@ input MilestoneOrder {
project: ProjectOrder
}

type MilestoneType implements Node {
type MilestoneType implements Node & Named {
"""The Globally Unique ID of this object"""
id: GlobalID!
name: String!
Expand Down Expand Up @@ -431,6 +431,10 @@ type Mutation {
): CreateQuizPayload!
}

interface Named {
name: String!
}

"""An object with a Globally Unique ID"""
interface Node {
"""The Globally Unique ID of this object"""
Expand Down Expand Up @@ -536,7 +540,7 @@ input ProjectOrder {
name: Ordering
}

type ProjectType implements Node {
type ProjectType implements Node & Named {
"""The Globally Unique ID of this object"""
id: GlobalID!
name: String!
Expand Down Expand Up @@ -877,7 +881,7 @@ input TagInputPartialListInput {
remove: [TagInputPartial!]
}

type TagType implements Node {
type TagType implements Node & Named {
"""The Globally Unique ID of this object"""
id: GlobalID!
name: String!
Expand Down
16 changes: 10 additions & 6 deletions tests/projects/snapshots/schema_with_inheritance.gql
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ input IssueOrder {
name: Ordering
}

type IssueType implements Node {
type IssueType implements Node & Named {
"""The Globally Unique ID of this object"""
id: GlobalID!
name: String!
Expand Down Expand Up @@ -155,7 +155,7 @@ input MilestoneOrder {
project: ProjectOrder
}

type MilestoneType implements Node {
type MilestoneType implements Node & Named {
"""The Globally Unique ID of this object"""
id: GlobalID!
name: String!
Expand All @@ -182,7 +182,7 @@ type MilestoneType implements Node {
asyncField(value: String!): String!
}

type MilestoneTypeSubclass implements Node {
type MilestoneTypeSubclass implements Node & Named {
"""The Globally Unique ID of this object"""
id: GlobalID!
name: String!
Expand Down Expand Up @@ -213,6 +213,10 @@ type Mutation {
createIssue(input: IssueInputSubclass!): CreateIssuePayload!
}

interface Named {
name: String!
}

"""An object with a Globally Unique ID"""
interface Node {
"""The Globally Unique ID of this object"""
Expand Down Expand Up @@ -287,7 +291,7 @@ input ProjectOrder {
name: Ordering
}

type ProjectType implements Node {
type ProjectType implements Node & Named {
"""The Globally Unique ID of this object"""
id: GlobalID!
name: String!
Expand All @@ -299,7 +303,7 @@ type ProjectType implements Node {
isSmall: Boolean!
}

type ProjectTypeSubclass implements Node {
type ProjectTypeSubclass implements Node & Named {
"""The Globally Unique ID of this object"""
id: GlobalID!
name: String!
Expand Down Expand Up @@ -370,7 +374,7 @@ input StrFilterLookup {
iRegex: String
}

type TagType implements Node {
type TagType implements Node & Named {
"""The Globally Unique ID of this object"""
id: GlobalID!
name: String!
Expand Down
62 changes: 62 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,68 @@ def test_query_forward(db, gql_client: GraphQLTestClient):
}


@pytest.mark.django_db(transaction=True)
def test_query_forward_with_interfaces(db, gql_client: GraphQLTestClient):
query = """
query TestQuery ($isAsync: Boolean!) {
issueConn {
totalCount
edges {
node {
id
... on Named {
name
}
milestone {
id
... on Named {
name
}
asyncField(value: "foo") @include (if: $isAsync)
project {
id
... on Named {
name
}
}
}
}
}
}
}
"""

expected = []
for p in ProjectFactory.create_batch(2):
for m in MilestoneFactory.create_batch(2, project=p):
for i in IssueFactory.create_batch(2, milestone=m):
r: dict[str, Any] = {
"id": to_base64("IssueType", i.id),
"name": i.name,
"milestone": {
"id": to_base64("MilestoneType", m.id),
"name": m.name,
"project": {
"id": to_base64("ProjectType", p.id),
"name": p.name,
},
},
}
if gql_client.is_async:
r["milestone"]["asyncField"] = "value: foo"
expected.append(r)

with assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 18):
res = gql_client.query(query, {"isAsync": gql_client.is_async})

assert res.data == {
"issueConn": {
"totalCount": 8,
"edges": [{"node": r} for r in expected],
},
}


@pytest.mark.django_db(transaction=True)
def test_query_forward_with_fragments(db, gql_client: GraphQLTestClient):
query = """
Expand Down

0 comments on commit da61de8

Please sign in to comment.