diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..e41e9220d4 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: minor + +Attempt to merge union types during schema conversion. diff --git a/strawberry/schema/name_converter.py b/strawberry/schema/name_converter.py index ec6d0edf2a..0743d72580 100644 --- a/strawberry/schema/name_converter.py +++ b/strawberry/schema/name_converter.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Union, cast +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast from typing_extensions import Protocol from strawberry.directive import StrawberryDirective @@ -107,8 +107,14 @@ def from_union(self, union: StrawberryUnion) -> str: return union.graphql_name name = "" + types: Tuple[StrawberryType, ...] = union.types - for type_ in union.types: + if union.concrete_of and union.concrete_of.graphql_name: + concrete_of_types = set(union.concrete_of.types) + + types = tuple(type_ for type_ in types if type_ not in concrete_of_types) + + for type_ in types: if isinstance(type_, LazyType): type_ = cast("StrawberryType", type_.resolve_type()) # noqa: PLW2901 @@ -121,6 +127,9 @@ def from_union(self, union: StrawberryUnion) -> str: name += type_name + if union.concrete_of and union.concrete_of.graphql_name: + name += union.concrete_of.graphql_name + return name def from_generic( @@ -133,12 +142,12 @@ def from_generic( names: List[str] = [] for type_ in types: - name = self.get_from_type(type_) + name = self.get_name_from_type(type_) names.append(name) return "".join(names) + generic_type_name - def get_from_type(self, type_: Union[StrawberryType, type]) -> str: + def get_name_from_type(self, type_: Union[StrawberryType, type]) -> str: type_ = eval_type(type_) if isinstance(type_, LazyType): @@ -148,9 +157,9 @@ def get_from_type(self, type_: Union[StrawberryType, type]) -> str: elif isinstance(type_, StrawberryUnion): name = type_.graphql_name if type_.graphql_name else self.from_union(type_) elif isinstance(type_, StrawberryList): - name = self.get_from_type(type_.of_type) + "List" + name = self.get_name_from_type(type_.of_type) + "List" elif isinstance(type_, StrawberryOptional): - name = self.get_from_type(type_.of_type) + "Optional" + name = self.get_name_from_type(type_.of_type) + "Optional" elif hasattr(type_, "_scalar_definition"): strawberry_type = type_._scalar_definition diff --git a/strawberry/schema/schema_converter.py b/strawberry/schema/schema_converter.py index 1083b46f9b..79ab5a2b96 100644 --- a/strawberry/schema/schema_converter.py +++ b/strawberry/schema/schema_converter.py @@ -865,14 +865,20 @@ def from_union(self, union: StrawberryUnion) -> GraphQLUnionType: return graphql_union graphql_types: List[GraphQLObjectType] = [] + for type_ in union.types: graphql_type = self.from_type(type_) if isinstance(graphql_type, GraphQLInputObjectType): raise InvalidTypeInputForUnion(graphql_type) - assert isinstance(graphql_type, GraphQLObjectType) + assert isinstance(graphql_type, (GraphQLObjectType, GraphQLUnionType)) - graphql_types.append(graphql_type) + # If the graphql_type is a GraphQLUnionType, merge its child types + if isinstance(graphql_type, GraphQLUnionType): + # Add the child types of the GraphQLUnionType to the list of graphql_types + graphql_types.extend(graphql_type.types) + else: + graphql_types.append(graphql_type) graphql_union = GraphQLUnionType( name=union_name, diff --git a/strawberry/types/union.py b/strawberry/types/union.py index f5d8c6210c..a5f04d7a55 100644 --- a/strawberry/types/union.py +++ b/strawberry/types/union.py @@ -67,6 +67,7 @@ def __init__( self.directives = directives self._source_file = None self._source_line = None + self.concrete_of: Optional[StrawberryUnion] = None def __eq__(self, other: object) -> bool: if isinstance(other, StrawberryType): @@ -139,6 +140,7 @@ def copy_with( return self new_types = [] + for type_ in self.types: new_type: Union[StrawberryType, type] @@ -154,10 +156,13 @@ def copy_with( new_types.append(new_type) - return StrawberryUnion( + new_union = StrawberryUnion( type_annotations=tuple(map(StrawberryAnnotation, new_types)), description=self.description, ) + new_union.concrete_of = self + + return new_union def __call__(self, *args: str, **kwargs: Any) -> NoReturn: """Do not use. diff --git a/tests/schema/test_union.py b/tests/schema/test_union.py index 42392b9584..9f540676e4 100644 --- a/tests/schema/test_union.py +++ b/tests/schema/test_union.py @@ -850,3 +850,178 @@ class Query: assert not result.errors assert result.data["something"] == {"__typename": "A", "a": 5} + + +def test_generic_union_with_annotated(): + @strawberry.type + class SomeType: + id: strawberry.ID + name: str + + @strawberry.type + class NotFoundError: + id: strawberry.ID + message: str + + T = TypeVar("T") + + @strawberry.type + class ObjectQueries(Generic[T]): + @strawberry.field + def by_id( + self, id: strawberry.ID + ) -> Annotated[Union[T, NotFoundError], strawberry.union("ByIdResult")]: ... + + @strawberry.type + class Query: + @strawberry.field + def some_type_queries(self, id: strawberry.ID) -> ObjectQueries[SomeType]: + raise NotImplementedError() + + schema = strawberry.Schema(Query) + + assert ( + str(schema) + == textwrap.dedent( + """ + type NotFoundError { + id: ID! + message: String! + } + + type Query { + someTypeQueries(id: ID!): SomeTypeObjectQueries! + } + + type SomeType { + id: ID! + name: String! + } + + union SomeTypeByIdResult = SomeType | NotFoundError + + type SomeTypeObjectQueries { + byId(id: ID!): SomeTypeByIdResult! + } + """ + ).strip() + ) + + +def test_generic_union_with_annotated_inside(): + @strawberry.type + class SomeType: + id: strawberry.ID + name: str + + @strawberry.type + class NotFoundError: + id: strawberry.ID + message: str + + T = TypeVar("T") + + @strawberry.type + class ObjectQueries(Generic[T]): + @strawberry.field + def by_id( + self, id: strawberry.ID + ) -> Union[T, Annotated[NotFoundError, strawberry.union("ByIdResult")]]: ... + + @strawberry.type + class Query: + @strawberry.field + def some_type_queries(self, id: strawberry.ID) -> ObjectQueries[SomeType]: ... + + schema = strawberry.Schema(Query) + + assert ( + str(schema) + == textwrap.dedent( + """ + type NotFoundError { + id: ID! + message: String! + } + + type Query { + someTypeQueries(id: ID!): SomeTypeObjectQueries! + } + + type SomeType { + id: ID! + name: String! + } + + union SomeTypeByIdResult = SomeType | NotFoundError + + type SomeTypeObjectQueries { + byId(id: ID!): SomeTypeByIdResult! + } + """ + ).strip() + ) + + +def test_annoted_union_with_two_generics(): + @strawberry.type + class SomeType: + a: str + + @strawberry.type + class OtherType: + b: str + + @strawberry.type + class NotFoundError: + message: str + + T = TypeVar("T") + U = TypeVar("U") + + @strawberry.type + class UnionObjectQueries(Generic[T, U]): + @strawberry.field + def by_id( + self, id: strawberry.ID + ) -> Union[ + T, Annotated[Union[U, NotFoundError], strawberry.union("ByIdResult")] + ]: ... + + @strawberry.type + class Query: + @strawberry.field + def some_type_queries( + self, id: strawberry.ID + ) -> UnionObjectQueries[SomeType, OtherType]: ... + + schema = strawberry.Schema(Query) + + assert ( + str(schema) + == textwrap.dedent( + """ + type NotFoundError { + message: String! + } + + type OtherType { + b: String! + } + + type Query { + someTypeQueries(id: ID!): SomeTypeOtherTypeUnionObjectQueries! + } + + type SomeType { + a: String! + } + + union SomeTypeOtherTypeByIdResult = SomeType | OtherType | NotFoundError + + type SomeTypeOtherTypeUnionObjectQueries { + byId(id: ID!): SomeTypeOtherTypeByIdResult! + } + """ + ).strip() + )