Skip to content

Commit

Permalink
Add metadata (the graphql.core Node) to GraphQLField. Clients can use…
Browse files Browse the repository at this point in the history
… this metadata do specialize the creation (e.g. adding a `@defaultValue` directive to add a default value to the generated objects).
  • Loading branch information
Matt Gilson committed Nov 3, 2023
1 parent ab3c6c4 commit c0921ac
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 12 deletions.
71 changes: 63 additions & 8 deletions strawberry/codegen/query_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ListTypeNode,
ListValueNode,
NamedTypeNode,
Node,
NonNullTypeNode,
NullValueNode,
ObjectValueNode,
Expand Down Expand Up @@ -362,6 +363,8 @@ def _populate_fragment_types(self, ast: DocumentNode) -> None:
GraphQLFragmentType,
on=typename,
graphql_typename=typename,
graphql_node=fd,
strawberry_type=query_type,
)

self._collect_types(
Expand Down Expand Up @@ -487,6 +490,28 @@ def _convert_operation(
variables, variables_type = self._convert_variable_definitions(
operation_definition.variable_definitions, operation_name=operation_name
)
if variables_type:
# There are multiple nodes in the `operation_definition.variable_definitions`
# list and collectively they are used to make the `variables_type`, so the
# closest node to the variable definition is the full operation definition node.
if variables_type.graphql_node is not None:
raise ValueError(

Check warning on line 498 in strawberry/codegen/query_codegen.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/query_codegen.py#L498

Added line #L498 was not covered by tests
"Internal codegen error. graphql_node should NOT be set."
)
variables_type.graphql_node = operation_definition

# There isn't actually a class in the strawberry schema that represents the
# variables in a query (it's _usually_ a wrapper-class around some
# `@strawberry.input` types that are referenced in the resolver function's
# arguments, however, the query variable could also be an input to some
# directive). To this end, we currently just use the containing operation type
# (e.g. the ``Query`` or ``Mutation`` class and then we'll let consumers sort
# out how they want to use this metadata).
if variables_type.graphql_type is not None:
raise ValueError(

Check warning on line 511 in strawberry/codegen/query_codegen.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/query_codegen.py#L511

Added line #L511 was not covered by tests
"Internal codegen error. graphql_type should NOT be set."
)
variables_type.graphql_type = query_type

return GraphQLOperation(
operation_definition.name.value,
Expand Down Expand Up @@ -519,7 +544,20 @@ def _convert_variable_definitions(
variable_type,
)

type_.fields.append(GraphQLField(variable.name, None, variable_type))
# Currently, we do NOT set a strawberry_field on these variables because it isn't clear
# what field would be the correct one to target. Most of the time, these variables
# probably prefer to target something like
# https://docs.python.org/3/library/inspect.html#inspect.Parameter for the resolver
# function associated with the operation (or potentially with the function associated
# with the directive that the variable gets passed to)?
type_.fields.append(
GraphQLField(
variable.name,
None,
variable_type,
graphql_node=variable_definition,
)
)
variables.append(variable)

return variables, type_
Expand Down Expand Up @@ -584,6 +622,7 @@ def _collect_type_from_strawberry_type(
type_ = GraphQLObjectType(
strawberry_type.name,
[],
graphql_type=strawberry_type,
)

for field in strawberry_type.fields:
Expand All @@ -592,7 +631,13 @@ def _collect_type_from_strawberry_type(
if field.default is not MISSING:
default = _py_to_graphql_value(field.default)
type_.fields.append(
GraphQLField(field.name, None, field_type, default_value=default)
GraphQLField(
field.name,
None,
field_type,
default_value=default,
strawberry_field=field,
)
)

self._collect_type(type_)
Expand Down Expand Up @@ -639,7 +684,11 @@ def _field_from_selection(
field_type = self._get_field_type(field.type)

return GraphQLField(
field.name, selection.alias.value if selection.alias else None, field_type
field.name,
selection.alias.value if selection.alias else None,
field_type,
graphql_node=selection,
strawberry_field=field,
)

def _unwrap_type(
Expand Down Expand Up @@ -688,10 +737,7 @@ def _field_from_selection_set(
# should be pretty safe.
if parent_type.type_var_map:
parent_type_name = (
"".join(
c.__name__ # type: ignore[union-attr]
for c in parent_type.type_var_map.values()
)
"".join(c.__name__ for c in parent_type.type_var_map.values()) # type: ignore[union-attr]
+ parent_type.name
)

Expand Down Expand Up @@ -724,6 +770,8 @@ def _field_from_selection_set(
selected_field.name,
selection.alias.value if selection.alias else None,
field_type,
graphql_node=selection,
strawberry_field=selected_field,
)

def _get_field(
Expand Down Expand Up @@ -777,6 +825,11 @@ def _collect_types(
)

current_type = graph_ql_object_type_factory(class_name)
if isinstance(selection, Node):
current_type.graphql_node = selection
if current_type.graphql_type is None:
current_type.graphql_type = parent_type

fields: List[Union[GraphQLFragmentSpread, GraphQLField]] = []

for sub_selection in selection_set.selections:
Expand Down Expand Up @@ -842,11 +895,13 @@ def _collect_types_using_fragments(
for fragment in fragments:
type_condition_name = fragment.type_condition.name.value
fragment_class_name = class_name + type_condition_name

strawberry_type = self.schema.get_type_by_name(type_condition_name)
current_type = GraphQLObjectType(
fragment_class_name,
list(common_fields),
graphql_typename=type_condition_name,
graphql_type=strawberry_type,
graphql_node=fragment,
)
fields: List[Union[GraphQLFragmentSpread, GraphQLField]] = []

Expand Down
10 changes: 10 additions & 0 deletions strawberry/codegen/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from enum import EnumMeta
from typing_extensions import Literal

from graphql.language.ast import Node

from strawberry.field import StrawberryField
from strawberry.type import StrawberryType
from strawberry.unset import UnsetType


Expand All @@ -32,6 +36,8 @@ class GraphQLField:
alias: Optional[str]
type: GraphQLType
default_value: Optional[GraphQLArgumentValue] = None
graphql_node: Optional[Node] = None
strawberry_field: Optional[StrawberryField] = None


@dataclass
Expand All @@ -44,6 +50,8 @@ class GraphQLObjectType:
name: str
fields: List[GraphQLField] = field(default_factory=list)
graphql_typename: Optional[str] = None
graphql_node: Optional[Node] = None
graphql_type: Optional[StrawberryType] = None


# Subtype of GraphQLObjectType.
Expand All @@ -55,6 +63,8 @@ class GraphQLFragmentType(GraphQLObjectType):
fields: List[GraphQLField] = field(default_factory=list)
graphql_typename: Optional[str] = None
on: str = ""
definition_node: Optional[Node] = None
strawberry_type: Optional[StrawberryType] = None

def __post_init__(self) -> None:
if not self.on:
Expand Down
27 changes: 25 additions & 2 deletions tests/codegen/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
import decimal
import enum
import random
from typing import TYPE_CHECKING, Generic, List, NewType, Optional, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
NewType,
Optional,
TypeVar,
Union,
)
from typing_extensions import Annotated
from uuid import UUID

Expand All @@ -26,7 +36,7 @@ class Color(enum.Enum):
@strawberry.type
class Person:
name: str
age: int
age: int = 7


@strawberry.type
Expand Down Expand Up @@ -120,6 +130,14 @@ def get_person_or_animal(self) -> Union[Person, Animal]:
p_or_a.age = 7
return p_or_a

@strawberry.field
def get_person_with_inputs(self, name: str, age: int = 72) -> Person:
"""Get a person."""
p_or_a = Person()
p_or_a.name = name
p_or_a.age = age
return p_or_a

Check warning on line 139 in tests/codegen/conftest.py

View check run for this annotation

Codecov / codecov/patch

tests/codegen/conftest.py#L136-L139

Added lines #L136 - L139 were not covered by tests

@strawberry.field
def list_life() -> LifeContainer[Person, Animal]:
"""Get lists of living things."""
Expand Down Expand Up @@ -166,3 +184,8 @@ def add_blog_posts(self, input: AddBlogPostsInput) -> AddBlogPostsOutput:
@pytest.fixture
def schema() -> strawberry.Schema:
return strawberry.Schema(query=Query, mutation=Mutation, types=[BlogPost, Image])


@pytest.fixture
def conftest_globals() -> Dict[str, Any]:
return globals()
132 changes: 130 additions & 2 deletions tests/codegen/test_query_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,25 @@
# - 4. test mutations (raise?)
# - 5. test subscriptions (raise)

import dataclasses
import textwrap
from pathlib import Path
from typing import Type
from typing import Dict, List, Optional, Type

import pytest
from graphql.language.ast import FieldNode, InlineFragmentNode, OperationDefinitionNode
from pytest_snapshot.plugin import Snapshot

from strawberry.codegen import QueryCodegen, QueryCodegenPlugin
from strawberry.codegen import CodegenFile, QueryCodegen, QueryCodegenPlugin
from strawberry.codegen.exceptions import (
MultipleOperationsProvidedError,
NoOperationNameProvidedError,
NoOperationProvidedError,
)
from strawberry.codegen.plugins.python import PythonPlugin
from strawberry.codegen.plugins.typescript import TypeScriptPlugin
from strawberry.codegen.types import GraphQLOperation, GraphQLType
from strawberry.types.types import StrawberryObjectDefinition

HERE = Path(__file__).parent
QUERIES = list(HERE.glob("queries/*.graphql"))
Expand Down Expand Up @@ -84,3 +89,126 @@ def test_fails_with_multiple_operations(schema, tmp_path):

with pytest.raises(MultipleOperationsProvidedError):
generator.run(data)


def test_codegen_augments_class_and_fields_with_source_objects(
schema, conftest_globals, tmp_path
):
class CustomPythonPlugin(PythonPlugin):
types_by_name: Optional[Dict[str, GraphQLType]] = None

def generate_code(
self, types: List[GraphQLType], operation: GraphQLOperation
) -> List[CodegenFile]:
self.types_by_name = {t.name: t for t in types}
return super().generate_code(types, operation)

query = tmp_path / "query.graphql"
data = textwrap.dedent(
"""\
query Operation {
getPersonOrAnimal {
... on Person {
age
}
}
}
"""
)
with query.open("w") as f:
f.write(data)

plugin = CustomPythonPlugin(query)
generator = QueryCodegen(schema, plugins=[plugin])
generator.run(data)

assert plugin.types_by_name is not None
types_by_name = plugin.types_by_name

assert set(types_by_name) == {
"Int",
"OperationResultGetPersonOrAnimalPerson",
"OperationResult",
}

person_type = types_by_name["OperationResultGetPersonOrAnimalPerson"]
assert isinstance(person_type.graphql_type, StrawberryObjectDefinition)
assert person_type.graphql_type.origin is conftest_globals["Person"]

assert isinstance(person_type.graphql_node, InlineFragmentNode)

name_field = next((fld for fld in person_type.fields if fld.name == "age"), None)

assert name_field is not None
# Check that we got the `dataclasses.Field` from the upstream ``Person`` type.
assert isinstance(name_field.strawberry_field, dataclasses.Field)
assert name_field.strawberry_field.default == 7
# Check that we got the GraphQL AST node that defined this field in the graphql AST.
assert isinstance(name_field.graphql_node, FieldNode)
assert name_field.graphql_node.name.value == "age"

result_type = types_by_name["OperationResult"]
assert result_type.graphql_type is not None
assert result_type.graphql_type.origin is conftest_globals["Query"]
assert isinstance(result_type.graphql_node, OperationDefinitionNode)


def test_codegen_augments_class_and_fields_with_source_objects_when_inputs(
schema, conftest_globals, tmp_path
):
class CustomPythonPlugin(PythonPlugin):
types_by_name: Optional[Dict[str, GraphQLType]] = None

def generate_code(
self, types: List[GraphQLType], operation: GraphQLOperation
) -> List[CodegenFile]:
self.types_by_name = {t.name: t for t in types}
return super().generate_code(types, operation)

query = tmp_path / "query.graphql"
data = textwrap.dedent(
"""\
query Operation($name: String!, $age: Int!) {
getPersonWithInputs(name: $name, age: $age) {
name
age
}
}
"""
)
with query.open("w") as f:
f.write(data)

plugin = CustomPythonPlugin(query)
generator = QueryCodegen(schema, plugins=[plugin])
generator.run(data)

assert plugin.types_by_name is not None
types_by_name = plugin.types_by_name

assert set(types_by_name) == {
"Int",
"String",
"OperationResult",
"OperationResultGetPersonWithInputs",
"OperationVariables",
}

# The special "Result" and "Variables" types hold references to the
# overarching graphql Query/Mutation/Subscription types.
query_type = types_by_name["OperationResult"]
assert isinstance(query_type.graphql_type, StrawberryObjectDefinition)
assert isinstance(query_type.graphql_node, OperationDefinitionNode)
assert query_type.graphql_type is not None
assert query_type.graphql_type.origin is conftest_globals["Query"]

variables_type = types_by_name["OperationVariables"]
assert isinstance(variables_type.graphql_type, StrawberryObjectDefinition)
assert isinstance(variables_type.graphql_node, OperationDefinitionNode)
assert variables_type.graphql_type is not None
assert variables_type.graphql_type.origin is conftest_globals["Query"]

person_type = types_by_name["OperationResultGetPersonWithInputs"]
assert isinstance(person_type.graphql_type, StrawberryObjectDefinition)
assert person_type.graphql_type is not None
assert person_type.graphql_type.origin is conftest_globals["Person"]

0 comments on commit c0921ac

Please sign in to comment.