diff --git a/strawberry/codegen/query_codegen.py b/strawberry/codegen/query_codegen.py index 728d4fc330..72ca70b721 100644 --- a/strawberry/codegen/query_codegen.py +++ b/strawberry/codegen/query_codegen.py @@ -33,6 +33,7 @@ ListTypeNode, ListValueNode, NamedTypeNode, + Node, NonNullTypeNode, NullValueNode, ObjectValueNode, @@ -362,6 +363,8 @@ def _populate_fragment_types(self, ast: DocumentNode) -> None: GraphQLFragmentType, on=typename, graphql_typename=typename, + graphql_node=fd, + graphql_type=query_type, ) self._collect_types( @@ -487,6 +490,28 @@ def _convert_operation( variables, variables_type = self._convert_variable_definitions( operation_definition.variable_definitions, operation_name=operation_name ) + if variables_type: + # There are multiple nodes in the `operation_definition.variable_definitions` + # list and collectively they are used to make the `variables_type`, so the + # closest node to the variable definition is the full operation definition node. + if variables_type.graphql_node is not None: # pragma: no cover + raise ValueError( + "Internal codegen error. graphql_node should NOT be set." + ) + variables_type.graphql_node = operation_definition + + # There isn't actually a class in the strawberry schema that represents the + # variables in a query (it's _usually_ a wrapper-class around some + # `@strawberry.input` types that are referenced in the resolver function's + # arguments, however, the query variable could also be an input to some + # directive). To this end, we currently just use the containing operation type + # (e.g. the ``Query`` or ``Mutation`` class and then we'll let consumers sort + # out how they want to use this metadata). + if variables_type.graphql_type is not None: # pragma: no cover + raise ValueError( + "Internal codegen error. graphql_type should NOT be set." + ) + variables_type.graphql_type = query_type return GraphQLOperation( operation_definition.name.value, @@ -519,7 +544,27 @@ def _convert_variable_definitions( variable_type, ) - type_.fields.append(GraphQLField(variable.name, None, variable_type)) + # This field generally represents arguments to resolver functions. e.g.: + # + # class Query: + # @strawberry.field + # def some_resolver(self, id: int) -> ReturnType: + # ... + # + # In this case, there is no `strawberry.type` with `StrawberryField` members + # that represent the `id` argument to the query. Because of that, we do not + # set the `strawberry_field` member of the `GraphQLField`. We could potentially + # also allow the `GraphQLField` to have an additional field that would hold + # an `inspect.Parameter` field that could be used in this case if we feel like + # that would be valuable to support in the future. + type_.fields.append( + GraphQLField( + variable.name, + None, + variable_type, + graphql_node=variable_definition, + ) + ) variables.append(variable) return variables, type_ @@ -584,6 +629,7 @@ def _collect_type_from_strawberry_type( type_ = GraphQLObjectType( strawberry_type.name, [], + graphql_type=strawberry_type, ) for field in strawberry_type.fields: @@ -592,7 +638,13 @@ def _collect_type_from_strawberry_type( if field.default is not MISSING: default = _py_to_graphql_value(field.default) type_.fields.append( - GraphQLField(field.name, None, field_type, default_value=default) + GraphQLField( + name=field.name, + alias=None, + type=field_type, + default_value=default, + strawberry_field=field, + ) ) self._collect_type(type_) @@ -639,7 +691,11 @@ def _field_from_selection( field_type = self._get_field_type(field.type) return GraphQLField( - field.name, selection.alias.value if selection.alias else None, field_type + field.name, + selection.alias.value if selection.alias else None, + field_type, + graphql_node=selection, + strawberry_field=field, ) def _unwrap_type( @@ -688,10 +744,7 @@ def _field_from_selection_set( # should be pretty safe. if parent_type.type_var_map: parent_type_name = ( - "".join( - c.__name__ # type: ignore[union-attr] - for c in parent_type.type_var_map.values() - ) + "".join(c.__name__ for c in parent_type.type_var_map.values()) # type: ignore[union-attr] + parent_type.name ) @@ -724,6 +777,8 @@ def _field_from_selection_set( selected_field.name, selection.alias.value if selection.alias else None, field_type, + graphql_node=selection, + strawberry_field=selected_field, ) def _get_field( @@ -777,6 +832,11 @@ def _collect_types( ) current_type = graph_ql_object_type_factory(class_name) + if isinstance(selection, Node): + current_type.graphql_node = selection + if current_type.graphql_type is None: + current_type.graphql_type = parent_type + fields: List[Union[GraphQLFragmentSpread, GraphQLField]] = [] for sub_selection in selection_set.selections: @@ -842,11 +902,13 @@ def _collect_types_using_fragments( for fragment in fragments: type_condition_name = fragment.type_condition.name.value fragment_class_name = class_name + type_condition_name - + strawberry_type = self.schema.get_type_by_name(type_condition_name) current_type = GraphQLObjectType( fragment_class_name, list(common_fields), graphql_typename=type_condition_name, + graphql_type=strawberry_type, + graphql_node=fragment, ) fields: List[Union[GraphQLFragmentSpread, GraphQLField]] = [] diff --git a/strawberry/codegen/types.py b/strawberry/codegen/types.py index 4d39fc4707..57531f2564 100644 --- a/strawberry/codegen/types.py +++ b/strawberry/codegen/types.py @@ -7,6 +7,10 @@ from enum import EnumMeta from typing_extensions import Literal + from graphql.language.ast import Node + + from strawberry.field import StrawberryField + from strawberry.type import StrawberryType from strawberry.unset import UnsetType @@ -32,6 +36,8 @@ class GraphQLField: alias: Optional[str] type: GraphQLType default_value: Optional[GraphQLArgumentValue] = None + graphql_node: Optional[Node] = None + strawberry_field: Optional[StrawberryField] = None @dataclass @@ -44,6 +50,8 @@ class GraphQLObjectType: name: str fields: List[GraphQLField] = field(default_factory=list) graphql_typename: Optional[str] = None + graphql_node: Optional[Node] = None + graphql_type: Optional[StrawberryType] = None # Subtype of GraphQLObjectType. @@ -55,6 +63,8 @@ class GraphQLFragmentType(GraphQLObjectType): fields: List[GraphQLField] = field(default_factory=list) graphql_typename: Optional[str] = None on: str = "" + graphql_node: Optional[Node] = None + graphql_type: Optional[StrawberryType] = None def __post_init__(self) -> None: if not self.on: diff --git a/tests/codegen/conftest.py b/tests/codegen/conftest.py index 1b7f80e467..8c3e35ab38 100644 --- a/tests/codegen/conftest.py +++ b/tests/codegen/conftest.py @@ -2,7 +2,17 @@ import decimal import enum import random -from typing import TYPE_CHECKING, Generic, List, NewType, Optional, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + NewType, + Optional, + TypeVar, + Union, +) from typing_extensions import Annotated from uuid import UUID @@ -26,7 +36,7 @@ class Color(enum.Enum): @strawberry.type class Person: name: str - age: int + age: int = 7 @strawberry.type @@ -120,6 +130,16 @@ def get_person_or_animal(self) -> Union[Person, Animal]: p_or_a.age = 7 return p_or_a + @strawberry.field + def get_person_with_inputs( + self, name: str, age: int = 72 + ) -> Person: # pragma: no cover + """Get a person.""" + p_or_a = Person() + p_or_a.name = name + p_or_a.age = age + return p_or_a + @strawberry.field def list_life() -> LifeContainer[Person, Animal]: """Get lists of living things.""" @@ -166,3 +186,8 @@ def add_blog_posts(self, input: AddBlogPostsInput) -> AddBlogPostsOutput: @pytest.fixture def schema() -> strawberry.Schema: return strawberry.Schema(query=Query, mutation=Mutation, types=[BlogPost, Image]) + + +@pytest.fixture +def conftest_globals() -> Dict[str, Any]: + return globals() diff --git a/tests/codegen/test_query_codegen.py b/tests/codegen/test_query_codegen.py index 78a3b71259..ae1bdca8b7 100644 --- a/tests/codegen/test_query_codegen.py +++ b/tests/codegen/test_query_codegen.py @@ -4,13 +4,16 @@ # - 4. test mutations (raise?) # - 5. test subscriptions (raise) +import dataclasses +import textwrap from pathlib import Path -from typing import Type +from typing import Dict, List, Optional, Type import pytest +from graphql.language.ast import FieldNode, InlineFragmentNode, OperationDefinitionNode from pytest_snapshot.plugin import Snapshot -from strawberry.codegen import QueryCodegen, QueryCodegenPlugin +from strawberry.codegen import CodegenFile, QueryCodegen, QueryCodegenPlugin from strawberry.codegen.exceptions import ( MultipleOperationsProvidedError, NoOperationNameProvidedError, @@ -18,6 +21,8 @@ ) from strawberry.codegen.plugins.python import PythonPlugin from strawberry.codegen.plugins.typescript import TypeScriptPlugin +from strawberry.codegen.types import GraphQLOperation, GraphQLType +from strawberry.types.types import StrawberryObjectDefinition HERE = Path(__file__).parent QUERIES = list(HERE.glob("queries/*.graphql")) @@ -84,3 +89,126 @@ def test_fails_with_multiple_operations(schema, tmp_path): with pytest.raises(MultipleOperationsProvidedError): generator.run(data) + + +def test_codegen_augments_class_and_fields_with_source_objects( + schema, conftest_globals, tmp_path +): + class CustomPythonPlugin(PythonPlugin): + types_by_name: Optional[Dict[str, GraphQLType]] = None + + def generate_code( + self, types: List[GraphQLType], operation: GraphQLOperation + ) -> List[CodegenFile]: + self.types_by_name = {t.name: t for t in types} + return super().generate_code(types, operation) + + query = tmp_path / "query.graphql" + data = textwrap.dedent( + """\ + query Operation { + getPersonOrAnimal { + ... on Person { + age + } + } + } + """ + ) + with query.open("w") as f: + f.write(data) + + plugin = CustomPythonPlugin(query) + generator = QueryCodegen(schema, plugins=[plugin]) + generator.run(data) + + assert plugin.types_by_name is not None + types_by_name = plugin.types_by_name + + assert set(types_by_name) == { + "Int", + "OperationResultGetPersonOrAnimalPerson", + "OperationResult", + } + + person_type = types_by_name["OperationResultGetPersonOrAnimalPerson"] + assert isinstance(person_type.graphql_type, StrawberryObjectDefinition) + assert person_type.graphql_type.origin is conftest_globals["Person"] + + assert isinstance(person_type.graphql_node, InlineFragmentNode) + + name_field = next((fld for fld in person_type.fields if fld.name == "age"), None) + + assert name_field is not None + # Check that we got the `dataclasses.Field` from the upstream ``Person`` type. + assert isinstance(name_field.strawberry_field, dataclasses.Field) + assert name_field.strawberry_field.default == 7 + # Check that we got the GraphQL AST node that defined this field in the graphql AST. + assert isinstance(name_field.graphql_node, FieldNode) + assert name_field.graphql_node.name.value == "age" + + result_type = types_by_name["OperationResult"] + assert result_type.graphql_type is not None + assert result_type.graphql_type.origin is conftest_globals["Query"] + assert isinstance(result_type.graphql_node, OperationDefinitionNode) + + +def test_codegen_augments_class_and_fields_with_source_objects_when_inputs( + schema, conftest_globals, tmp_path +): + class CustomPythonPlugin(PythonPlugin): + types_by_name: Optional[Dict[str, GraphQLType]] = None + + def generate_code( + self, types: List[GraphQLType], operation: GraphQLOperation + ) -> List[CodegenFile]: + self.types_by_name = {t.name: t for t in types} + return super().generate_code(types, operation) + + query = tmp_path / "query.graphql" + data = textwrap.dedent( + """\ + query Operation($name: String!, $age: Int!) { + getPersonWithInputs(name: $name, age: $age) { + name + age + } + } + """ + ) + with query.open("w") as f: + f.write(data) + + plugin = CustomPythonPlugin(query) + generator = QueryCodegen(schema, plugins=[plugin]) + generator.run(data) + + assert plugin.types_by_name is not None + types_by_name = plugin.types_by_name + + assert set(types_by_name) == { + "Int", + "String", + "OperationResult", + "OperationResultGetPersonWithInputs", + "OperationVariables", + } + + # The special "Result" and "Variables" types hold references to the + # overarching graphql Query/Mutation/Subscription types. + query_type = types_by_name["OperationResult"] + assert isinstance(query_type.graphql_type, StrawberryObjectDefinition) + assert isinstance(query_type.graphql_node, OperationDefinitionNode) + assert query_type.graphql_type is not None + assert query_type.graphql_type.origin is conftest_globals["Query"] + + variables_type = types_by_name["OperationVariables"] + assert isinstance(variables_type.graphql_type, StrawberryObjectDefinition) + assert isinstance(variables_type.graphql_node, OperationDefinitionNode) + assert variables_type.graphql_type is not None + assert variables_type.graphql_type.origin is conftest_globals["Query"] + + person_type = types_by_name["OperationResultGetPersonWithInputs"] + assert isinstance(person_type.graphql_type, StrawberryObjectDefinition) + assert person_type.graphql_type is not None + assert person_type.graphql_type.origin is conftest_globals["Person"]