diff --git a/strawberry_django/mutations/resolvers.py b/strawberry_django/mutations/resolvers.py index 4c39d1153..d18e295ee 100644 --- a/strawberry_django/mutations/resolvers.py +++ b/strawberry_django/mutations/resolvers.py @@ -1,6 +1,7 @@ from __future__ import annotations import dataclasses +from enum import Enum from typing import ( TYPE_CHECKING, Any, @@ -173,6 +174,9 @@ def parse_input(info: Info, data: Any, *, key_attr: str | None = "pk"): ), ) + if isinstance(data, Enum): + return data.value + if dataclasses.is_dataclass(data): return { f.name: parse_input(info, getattr(data, f.name), key_attr=key_attr) diff --git a/tests/test_enums.py b/tests/test_enums.py index 2e07e317e..336131e02 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -1,6 +1,7 @@ import textwrap from typing import cast +import pytest import strawberry from django.db import models from django.test import override_settings @@ -9,6 +10,7 @@ from pytest_mock import MockerFixture import strawberry_django +from strawberry_django import mutations from strawberry_django.fields import types from strawberry_django.fields.types import field_type_map from strawberry_django.settings import strawberry_django_settings @@ -444,3 +446,73 @@ def obj(self) -> ChoicesWithExtraFieldsType: "extra2": 99, }, } + + +@override_settings( + STRAWBERRY_DJANGO={ + **strawberry_django_settings(), + "GENERATE_ENUMS_FROM_CHOICES": True, + }, +) +@pytest.mark.django_db(transaction=True) +def test_create_mutation_with_generated_enum_input(db): + @strawberry_django.type(ChoicesModel) + class ChoicesType: + attr1: strawberry.auto + attr2: strawberry.auto + attr3: strawberry.auto + attr4: strawberry.auto + attr5: strawberry.auto + attr6: strawberry.auto + + @strawberry_django.input(ChoicesModel) + class ChoicesInput: + attr1: strawberry.auto + attr2: strawberry.auto + attr3: strawberry.auto + attr4: strawberry.auto + attr5: strawberry.auto + attr6: strawberry.auto + + @strawberry.type + class Query: + choice: ChoicesType = strawberry_django.field() + + @strawberry.type + class Mutation: + create_choice: ChoicesType = mutations.create( + ChoicesInput, handle_django_errors=True, argument_name="input" + ) + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + variables = { + "input": { + "attr1": "A", + "attr2": "X", + "attr3": Choice.C, + "attr4": 4, + "attr5": "a", + "attr6": 1, + } + } + result = schema.execute_sync( + """ + mutation CreateChoice($input: ChoicesInput!) { + createChoice(input: $input) { + ... on OperationInfo { + messages { + kind + field + message + } + } + ... on ChoicesType { + attr3 + } + } + } + """, + variables, + ) + assert result.data == {"createChoice": {"attr3": "c"}} diff --git a/tests/test_input_mutations.py b/tests/test_input_mutations.py index 4b156982f..f48d54e33 100644 --- a/tests/test_input_mutations.py +++ b/tests/test_input_mutations.py @@ -380,7 +380,7 @@ def test_input_create_with_m2m_mutation(db, gql_client: GraphQLTestClient): @pytest.mark.django_db(transaction=True) def test_input_update_mutation(db, gql_client: GraphQLTestClient): query = """ - mutation CreateIssue ($input: IssueInputPartial!) { + mutation UpdateIssue ($input: IssueInputPartial!) { updateIssue (input: $input) { __typename ... on OperationInfo { @@ -473,7 +473,7 @@ def test_input_update_mutation(db, gql_client: GraphQLTestClient): @pytest.mark.django_db(transaction=True) def test_input_nested_update_mutation(db, gql_client: GraphQLTestClient): query = """ - mutation CreateIssue ($input: IssueInputPartial!) { + mutation UpdateIssue ($input: IssueInputPartial!) { updateIssue (input: $input) { __typename ... on OperationInfo { @@ -576,7 +576,7 @@ def test_input_update_m2m_set_not_null_mutation(db, gql_client: GraphQLTestClien @pytest.mark.django_db(transaction=True) def test_input_update_m2m_set_mutation(db, gql_client: GraphQLTestClient): query = """ - mutation CreateIssue ($input: IssueInputPartial!) { + mutation UpdateIssue ($input: IssueInputPartial!) { updateIssue (input: $input) { __typename ... on OperationInfo { @@ -701,7 +701,7 @@ def test_input_update_m2m_set_mutation(db, gql_client: GraphQLTestClient): @pytest.mark.django_db(transaction=True) def test_input_update_m2m_set_through_mutation(db, gql_client: GraphQLTestClient): query = """ - mutation CreateIssue ($input: IssueInputPartial!) { + mutation UpdateIssue ($input: IssueInputPartial!) { updateIssue (input: $input) { __typename ... on OperationInfo {