Skip to content

Commit

Permalink
Add metadata (the graphql.core Node) to GraphQLField. Clients can use…
Browse files Browse the repository at this point in the history
… this metadata do specialize the creation (e.g. adding a `@defaultValue` directive to add a default value to the generated objects).
  • Loading branch information
Matt Gilson committed Mar 12, 2024
1 parent 1f156a9 commit 85bb731
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 12 deletions.
78 changes: 70 additions & 8 deletions strawberry/codegen/query_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ListTypeNode,
ListValueNode,
NamedTypeNode,
Node,
NonNullTypeNode,
NullValueNode,
ObjectValueNode,
Expand Down Expand Up @@ -362,6 +363,8 @@ def _populate_fragment_types(self, ast: DocumentNode) -> None:
GraphQLFragmentType,
on=typename,
graphql_typename=typename,
graphql_node=fd,
graphql_type=query_type,
)

self._collect_types(
Expand Down Expand Up @@ -487,6 +490,28 @@ def _convert_operation(
variables, variables_type = self._convert_variable_definitions(
operation_definition.variable_definitions, operation_name=operation_name
)
if variables_type:
# There are multiple nodes in the `operation_definition.variable_definitions`
# list and collectively they are used to make the `variables_type`, so the
# closest node to the variable definition is the full operation definition node.
if variables_type.graphql_node is not None: # pragma: no cover
raise ValueError(
"Internal codegen error. graphql_node should NOT be set."
)
variables_type.graphql_node = operation_definition

# There isn't actually a class in the strawberry schema that represents the
# variables in a query (it's _usually_ a wrapper-class around some
# `@strawberry.input` types that are referenced in the resolver function's
# arguments, however, the query variable could also be an input to some
# directive). To this end, we currently just use the containing operation type
# (e.g. the ``Query`` or ``Mutation`` class and then we'll let consumers sort
# out how they want to use this metadata).
if variables_type.graphql_type is not None: # pragma: no cover
raise ValueError(
"Internal codegen error. graphql_type should NOT be set."
)
variables_type.graphql_type = query_type

return GraphQLOperation(
operation_definition.name.value,
Expand Down Expand Up @@ -519,7 +544,27 @@ def _convert_variable_definitions(
variable_type,
)

type_.fields.append(GraphQLField(variable.name, None, variable_type))
# This field generally represents arguments to resolver functions. e.g.:
#
# class Query:
# @strawberry.field
# def some_resolver(self, id: int) -> ReturnType:
# ...
#
# In this case, there is no `strawberry.type` with `StrawberryField` members
# that represent the `id` argument to the query. Because of that, we do not
# set the `strawberry_field` member of the `GraphQLField`. We could potentially
# also allow the `GraphQLField` to have an additional field that would hold
# an `inspect.Parameter` field that could be used in this case if we feel like
# that would be valuable to support in the future.
type_.fields.append(
GraphQLField(
variable.name,
None,
variable_type,
graphql_node=variable_definition,
)
)
variables.append(variable)

return variables, type_
Expand Down Expand Up @@ -584,6 +629,7 @@ def _collect_type_from_strawberry_type(
type_ = GraphQLObjectType(
strawberry_type.name,
[],
graphql_type=strawberry_type,
)

for field in strawberry_type.fields:
Expand All @@ -592,7 +638,13 @@ def _collect_type_from_strawberry_type(
if field.default is not MISSING:
default = _py_to_graphql_value(field.default)
type_.fields.append(
GraphQLField(field.name, None, field_type, default_value=default)
GraphQLField(
name=field.name,
alias=None,
type=field_type,
default_value=default,
strawberry_field=field,
)
)

self._collect_type(type_)
Expand Down Expand Up @@ -639,7 +691,11 @@ def _field_from_selection(
field_type = self._get_field_type(field.type)

return GraphQLField(
field.name, selection.alias.value if selection.alias else None, field_type
field.name,
selection.alias.value if selection.alias else None,
field_type,
graphql_node=selection,
strawberry_field=field,
)

def _unwrap_type(
Expand Down Expand Up @@ -688,10 +744,7 @@ def _field_from_selection_set(
# should be pretty safe.
if parent_type.type_var_map:
parent_type_name = (
"".join(
c.__name__ # type: ignore[union-attr]
for c in parent_type.type_var_map.values()
)
"".join(c.__name__ for c in parent_type.type_var_map.values()) # type: ignore[union-attr]
+ parent_type.name
)

Expand Down Expand Up @@ -724,6 +777,8 @@ def _field_from_selection_set(
selected_field.name,
selection.alias.value if selection.alias else None,
field_type,
graphql_node=selection,
strawberry_field=selected_field,
)

def _get_field(
Expand Down Expand Up @@ -777,6 +832,11 @@ def _collect_types(
)

current_type = graph_ql_object_type_factory(class_name)
if isinstance(selection, Node):
current_type.graphql_node = selection
if current_type.graphql_type is None:
current_type.graphql_type = parent_type

fields: List[Union[GraphQLFragmentSpread, GraphQLField]] = []

for sub_selection in selection_set.selections:
Expand Down Expand Up @@ -842,11 +902,13 @@ def _collect_types_using_fragments(
for fragment in fragments:
type_condition_name = fragment.type_condition.name.value
fragment_class_name = class_name + type_condition_name

strawberry_type = self.schema.get_type_by_name(type_condition_name)
current_type = GraphQLObjectType(
fragment_class_name,
list(common_fields),
graphql_typename=type_condition_name,
graphql_type=strawberry_type,
graphql_node=fragment,
)
fields: List[Union[GraphQLFragmentSpread, GraphQLField]] = []

Expand Down
10 changes: 10 additions & 0 deletions strawberry/codegen/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from enum import EnumMeta
from typing_extensions import Literal

from graphql.language.ast import Node

from strawberry.field import StrawberryField
from strawberry.type import StrawberryType
from strawberry.unset import UnsetType


Expand All @@ -32,6 +36,8 @@ class GraphQLField:
alias: Optional[str]
type: GraphQLType
default_value: Optional[GraphQLArgumentValue] = None
graphql_node: Optional[Node] = None
strawberry_field: Optional[StrawberryField] = None


@dataclass
Expand All @@ -44,6 +50,8 @@ class GraphQLObjectType:
name: str
fields: List[GraphQLField] = field(default_factory=list)
graphql_typename: Optional[str] = None
graphql_node: Optional[Node] = None
graphql_type: Optional[StrawberryType] = None


# Subtype of GraphQLObjectType.
Expand All @@ -55,6 +63,8 @@ class GraphQLFragmentType(GraphQLObjectType):
fields: List[GraphQLField] = field(default_factory=list)
graphql_typename: Optional[str] = None
on: str = ""
graphql_node: Optional[Node] = None
graphql_type: Optional[StrawberryType] = None

def __post_init__(self) -> None:
if not self.on:
Expand Down
29 changes: 27 additions & 2 deletions tests/codegen/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
import decimal
import enum
import random
from typing import TYPE_CHECKING, Generic, List, NewType, Optional, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
NewType,
Optional,
TypeVar,
Union,
)
from typing_extensions import Annotated
from uuid import UUID

Expand All @@ -26,7 +36,7 @@ class Color(enum.Enum):
@strawberry.type
class Person:
name: str
age: int
age: int = 7


@strawberry.type
Expand Down Expand Up @@ -120,6 +130,16 @@ def get_person_or_animal(self) -> Union[Person, Animal]:
p_or_a.age = 7
return p_or_a

@strawberry.field
def get_person_with_inputs(
self, name: str, age: int = 72
) -> Person: # pragma: no cover
"""Get a person."""
p_or_a = Person()
p_or_a.name = name
p_or_a.age = age
return p_or_a

@strawberry.field
def list_life() -> LifeContainer[Person, Animal]:
"""Get lists of living things."""
Expand Down Expand Up @@ -166,3 +186,8 @@ def add_blog_posts(self, input: AddBlogPostsInput) -> AddBlogPostsOutput:
@pytest.fixture
def schema() -> strawberry.Schema:
return strawberry.Schema(query=Query, mutation=Mutation, types=[BlogPost, Image])


@pytest.fixture
def conftest_globals() -> Dict[str, Any]:
return globals()
Loading

0 comments on commit 85bb731

Please sign in to comment.