Skip to content

Commit

Permalink
FIX: Add support for generated enums in mutation input
Browse files Browse the repository at this point in the history
  • Loading branch information
cngai committed Mar 7, 2024
1 parent 67433e6 commit 11068f1
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 4 deletions.
4 changes: 4 additions & 0 deletions strawberry_django/mutations/resolvers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import dataclasses
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -173,6 +174,9 @@ def parse_input(info: Info, data: Any, *, key_attr: str | None = "pk"):
),
)

if isinstance(data, Enum):
return data.value

if dataclasses.is_dataclass(data):
return {
f.name: parse_input(info, getattr(data, f.name), key_attr=key_attr)
Expand Down
72 changes: 72 additions & 0 deletions tests/test_enums.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import textwrap
from typing import cast

import pytest
import strawberry
from django.db import models
from django.test import override_settings
Expand All @@ -9,6 +10,7 @@
from pytest_mock import MockerFixture

import strawberry_django
from strawberry_django import mutations
from strawberry_django.fields import types
from strawberry_django.fields.types import field_type_map
from strawberry_django.settings import strawberry_django_settings
Expand Down Expand Up @@ -444,3 +446,73 @@ def obj(self) -> ChoicesWithExtraFieldsType:
"extra2": 99,
},
}


@override_settings(
STRAWBERRY_DJANGO={
**strawberry_django_settings(),
"GENERATE_ENUMS_FROM_CHOICES": True,
},
)
@pytest.mark.django_db(transaction=True)
def test_create_mutation_with_generated_enum_input(db):
@strawberry_django.type(ChoicesModel)
class ChoicesType:
attr1: strawberry.auto
attr2: strawberry.auto
attr3: strawberry.auto
attr4: strawberry.auto
attr5: strawberry.auto
attr6: strawberry.auto

@strawberry_django.input(ChoicesModel)
class ChoicesInput:
attr1: strawberry.auto
attr2: strawberry.auto
attr3: strawberry.auto
attr4: strawberry.auto
attr5: strawberry.auto
attr6: strawberry.auto

@strawberry.type
class Query:
choice: ChoicesType = strawberry_django.field()

@strawberry.type
class Mutation:
create_choice: ChoicesType = mutations.create(
ChoicesInput, handle_django_errors=True, argument_name="input"
)

schema = strawberry.Schema(query=Query, mutation=Mutation)

variables = {
"input": {
"attr1": "A",
"attr2": "X",
"attr3": Choice.C,
"attr4": 4,
"attr5": "a",
"attr6": 1,
}
}
result = schema.execute_sync(
"""
mutation CreateChoice($input: ChoicesInput!) {
createChoice(input: $input) {
... on OperationInfo {
messages {
kind
field
message
}
}
... on ChoicesType {
attr3
}
}
}
""",
variables,
)
assert result.data == {"createChoice": {"attr3": "c"}}
8 changes: 4 additions & 4 deletions tests/test_input_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def test_input_create_with_m2m_mutation(db, gql_client: GraphQLTestClient):
@pytest.mark.django_db(transaction=True)
def test_input_update_mutation(db, gql_client: GraphQLTestClient):
query = """
mutation CreateIssue ($input: IssueInputPartial!) {
mutation UpdateIssue ($input: IssueInputPartial!) {
updateIssue (input: $input) {
__typename
... on OperationInfo {
Expand Down Expand Up @@ -473,7 +473,7 @@ def test_input_update_mutation(db, gql_client: GraphQLTestClient):
@pytest.mark.django_db(transaction=True)
def test_input_nested_update_mutation(db, gql_client: GraphQLTestClient):
query = """
mutation CreateIssue ($input: IssueInputPartial!) {
mutation UpdateIssue ($input: IssueInputPartial!) {
updateIssue (input: $input) {
__typename
... on OperationInfo {
Expand Down Expand Up @@ -576,7 +576,7 @@ def test_input_update_m2m_set_not_null_mutation(db, gql_client: GraphQLTestClien
@pytest.mark.django_db(transaction=True)
def test_input_update_m2m_set_mutation(db, gql_client: GraphQLTestClient):
query = """
mutation CreateIssue ($input: IssueInputPartial!) {
mutation UpdateIssue ($input: IssueInputPartial!) {
updateIssue (input: $input) {
__typename
... on OperationInfo {
Expand Down Expand Up @@ -701,7 +701,7 @@ def test_input_update_m2m_set_mutation(db, gql_client: GraphQLTestClient):
@pytest.mark.django_db(transaction=True)
def test_input_update_m2m_set_through_mutation(db, gql_client: GraphQLTestClient):
query = """
mutation CreateIssue ($input: IssueInputPartial!) {
mutation UpdateIssue ($input: IssueInputPartial!) {
updateIssue (input: $input) {
__typename
... on OperationInfo {
Expand Down

0 comments on commit 11068f1

Please sign in to comment.