Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add metadata (the graphql.core Node) to GraphQLField and GraphQLObject. #3182

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Release type: minor

Augment the codegen `GraphQLObjectType` and `GraphQLField` with the `graphql.language.ast.Node` that caused the
respective object to be created. This node can be introspected for additional metadata for codegen plugins to use
for specialization of type creation.
78 changes: 70 additions & 8 deletions strawberry/codegen/query_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ListTypeNode,
ListValueNode,
NamedTypeNode,
Node,
NonNullTypeNode,
NullValueNode,
ObjectValueNode,
Expand Down Expand Up @@ -362,6 +363,8 @@ def _populate_fragment_types(self, ast: DocumentNode) -> None:
GraphQLFragmentType,
on=typename,
graphql_typename=typename,
graphql_node=fd,
graphql_type=query_type,
)

self._collect_types(
Expand Down Expand Up @@ -487,6 +490,28 @@ def _convert_operation(
variables, variables_type = self._convert_variable_definitions(
operation_definition.variable_definitions, operation_name=operation_name
)
if variables_type:
# There are multiple nodes in the `operation_definition.variable_definitions`
# list and collectively they are used to make the `variables_type`, so the
# closest node to the variable definition is the full operation definition node.
if variables_type.graphql_node is not None: # pragma: no cover
raise ValueError(
"Internal codegen error. graphql_node should NOT be set."
)
variables_type.graphql_node = operation_definition

# There isn't actually a class in the strawberry schema that represents the
# variables in a query (it's _usually_ a wrapper-class around some
# `@strawberry.input` types that are referenced in the resolver function's
# arguments, however, the query variable could also be an input to some
# directive). To this end, we currently just use the containing operation type
# (e.g. the ``Query`` or ``Mutation`` class and then we'll let consumers sort
# out how they want to use this metadata).
if variables_type.graphql_type is not None: # pragma: no cover
raise ValueError(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(same q for the one above) when could this happen?

if it is just a defensive check, let's add a nocover comment 😊

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is defensive to make sure we didn't mess something up.

These are the {Type}Variables objects that I mentioned in the other comments. Basically, at the client level, it's just a dataclass that describes the inputs to a strawberry resolver (rather than being an actual type on the strawberry/python side).

In that case, I made the decision to just set the graphql node to the Operation node and I set the graphql type to the query/mutation type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I think that it's a little more complex ... It maps to the variables in the operation

query FooQuery($id: Int, $fetchBar: Boolean) {
  foo(id: $id) {
    bar @skip(if: $fetchBar) {
      x
      y
    }
  }
}

The client is going to generate an object:

class FooQueryVariables:
    id: int
    fetch_bar: bool

This object doesn't map to any class/field on the strawberry/server side.

"Internal codegen error. graphql_type should NOT be set."
)
variables_type.graphql_type = query_type

return GraphQLOperation(
operation_definition.name.value,
Expand Down Expand Up @@ -519,7 +544,27 @@ def _convert_variable_definitions(
variable_type,
)

type_.fields.append(GraphQLField(variable.name, None, variable_type))
# This field generally represents arguments to resolver functions. e.g.:
#
# class Query:
# @strawberry.field
# def some_resolver(self, id: int) -> ReturnType:
# ...
#
# In this case, there is no `strawberry.type` with `StrawberryField` members
# that represent the `id` argument to the query. Because of that, we do not
# set the `strawberry_field` member of the `GraphQLField`. We could potentially
# also allow the `GraphQLField` to have an additional field that would hold
# an `inspect.Parameter` field that could be used in this case if we feel like
# that would be valuable to support in the future.
type_.fields.append(
GraphQLField(
variable.name,
None,
variable_type,
graphql_node=variable_definition,
Comment on lines +562 to +565
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use named arguments here?

and could we find where the variable is used and attach that field here? Might not probably worth, but maybe we can leave a comment about that? (The current comment is a bit unclear to me)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've tried to update the comment to make it more clear. These are for the {Type}Variables objects which generally correspond to something like:

class Query:
    @strawberry.field
    def resolver_function(self, id: int, x: float) -> ReturnType:
        ...

The client is going to make a:

class QueryResolverFunctionVariables:
    id: int
    x: float

type of class (depending on the query generated). But in this case, there isn't a StrawberryField that maps 1-to-1 with the id parameter.

)
)
variables.append(variable)

return variables, type_
Expand Down Expand Up @@ -584,6 +629,7 @@ def _collect_type_from_strawberry_type(
type_ = GraphQLObjectType(
strawberry_type.name,
[],
graphql_type=strawberry_type,
)

for field in strawberry_type.fields:
Expand All @@ -592,7 +638,13 @@ def _collect_type_from_strawberry_type(
if field.default is not MISSING:
default = _py_to_graphql_value(field.default)
type_.fields.append(
GraphQLField(field.name, None, field_type, default_value=default)
GraphQLField(
name=field.name,
alias=None,
type=field_type,
default_value=default,
strawberry_field=field,
)
)

self._collect_type(type_)
Expand Down Expand Up @@ -639,7 +691,11 @@ def _field_from_selection(
field_type = self._get_field_type(field.type)

return GraphQLField(
field.name, selection.alias.value if selection.alias else None, field_type
field.name,
selection.alias.value if selection.alias else None,
field_type,
graphql_node=selection,
strawberry_field=field,
)

def _unwrap_type(
Expand Down Expand Up @@ -688,10 +744,7 @@ def _field_from_selection_set(
# should be pretty safe.
if parent_type.type_var_map:
parent_type_name = (
"".join(
c.__name__ # type: ignore[union-attr]
for c in parent_type.type_var_map.values()
)
"".join(c.__name__ for c in parent_type.type_var_map.values()) # type: ignore[union-attr]
+ parent_type.name
)

Expand Down Expand Up @@ -724,6 +777,8 @@ def _field_from_selection_set(
selected_field.name,
selection.alias.value if selection.alias else None,
field_type,
graphql_node=selection,
strawberry_field=selected_field,
)

def _get_field(
Expand Down Expand Up @@ -777,6 +832,11 @@ def _collect_types(
)

current_type = graph_ql_object_type_factory(class_name)
if isinstance(selection, Node):
current_type.graphql_node = selection
if current_type.graphql_type is None:
current_type.graphql_type = parent_type

fields: List[Union[GraphQLFragmentSpread, GraphQLField]] = []

for sub_selection in selection_set.selections:
Expand Down Expand Up @@ -842,11 +902,13 @@ def _collect_types_using_fragments(
for fragment in fragments:
type_condition_name = fragment.type_condition.name.value
fragment_class_name = class_name + type_condition_name

strawberry_type = self.schema.get_type_by_name(type_condition_name)
current_type = GraphQLObjectType(
fragment_class_name,
list(common_fields),
graphql_typename=type_condition_name,
graphql_type=strawberry_type,
graphql_node=fragment,
)
fields: List[Union[GraphQLFragmentSpread, GraphQLField]] = []

Expand Down
10 changes: 10 additions & 0 deletions strawberry/codegen/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from enum import EnumMeta
from typing_extensions import Literal

from graphql.language.ast import Node

from strawberry.field import StrawberryField
from strawberry.type import StrawberryType
from strawberry.unset import UnsetType


Expand All @@ -32,6 +36,8 @@ class GraphQLField:
alias: Optional[str]
type: GraphQLType
default_value: Optional[GraphQLArgumentValue] = None
graphql_node: Optional[Node] = None
strawberry_field: Optional[StrawberryField] = None


@dataclass
Expand All @@ -44,6 +50,8 @@ class GraphQLObjectType:
name: str
fields: List[GraphQLField] = field(default_factory=list)
graphql_typename: Optional[str] = None
graphql_node: Optional[Node] = None
graphql_type: Optional[StrawberryType] = None


# Subtype of GraphQLObjectType.
Expand All @@ -55,6 +63,8 @@ class GraphQLFragmentType(GraphQLObjectType):
fields: List[GraphQLField] = field(default_factory=list)
graphql_typename: Optional[str] = None
on: str = ""
graphql_node: Optional[Node] = None
graphql_type: Optional[StrawberryType] = None

def __post_init__(self) -> None:
if not self.on:
Expand Down
29 changes: 27 additions & 2 deletions tests/codegen/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
import decimal
import enum
import random
from typing import TYPE_CHECKING, Generic, List, NewType, Optional, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
NewType,
Optional,
TypeVar,
Union,
)
from typing_extensions import Annotated
from uuid import UUID

Expand All @@ -26,7 +36,7 @@ class Color(enum.Enum):
@strawberry.type
class Person:
name: str
age: int
age: int = 7


@strawberry.type
Expand Down Expand Up @@ -120,6 +130,16 @@ def get_person_or_animal(self) -> Union[Person, Animal]:
p_or_a.age = 7
return p_or_a

@strawberry.field
def get_person_with_inputs(
self, name: str, age: int = 72
) -> Person: # pragma: no cover
"""Get a person."""
p_or_a = Person()
p_or_a.name = name
p_or_a.age = age
return p_or_a

@strawberry.field
def list_life() -> LifeContainer[Person, Animal]:
"""Get lists of living things."""
Expand Down Expand Up @@ -166,3 +186,8 @@ def add_blog_posts(self, input: AddBlogPostsInput) -> AddBlogPostsOutput:
@pytest.fixture
def schema() -> strawberry.Schema:
return strawberry.Schema(query=Query, mutation=Mutation, types=[BlogPost, Image])


@pytest.fixture
def conftest_globals() -> Dict[str, Any]:
return globals()
Loading
Loading