diff --git a/strawberry_django/integrations/guardian.py b/strawberry_django/integrations/guardian.py index d252f4065..f15baf444 100644 --- a/strawberry_django/integrations/guardian.py +++ b/strawberry_django/integrations/guardian.py @@ -1,7 +1,7 @@ import contextlib import dataclasses import weakref -from typing import Optional, Union, cast +from typing import Optional, Type, Union, cast from django.contrib.auth import get_user_model from django.contrib.auth.models import Group @@ -24,7 +24,9 @@ class ObjectPermissionModels: group: GroupObjectPermissionBase -def get_object_permission_models(model: models.Model): +def get_object_permission_models( + model: Union[models.Model, Type[models.Model]], +) -> ObjectPermissionModels: return ObjectPermissionModels( user=cast(UserObjectPermissionBase, get_user_obj_perms_model(model)), group=cast(GroupObjectPermissionBase, get_group_obj_perms_model(model)), diff --git a/strawberry_django/mutations/resolvers.py b/strawberry_django/mutations/resolvers.py index 9c1bfe795..fd13a1f9b 100644 --- a/strawberry_django/mutations/resolvers.py +++ b/strawberry_django/mutations/resolvers.py @@ -262,7 +262,10 @@ def prepare_create_update( (ParsedObject, str), ): value, value_data = _parse_data( # noqa: PLW2901 - info, field.related_model, value, key_attr=key_attr + info, + cast(type[Model], field.related_model), + value, + key_attr=key_attr, ) if value is None and not value_data: value = None # noqa: PLW2901 @@ -508,7 +511,7 @@ def update_field(info: Info, instance: Model, field: models.Field, value: Any): and isinstance(field, models.ForeignObject) and not isinstance(value, Model) ): - value, data = _parse_pk(value, field.related_model) + value, data = _parse_pk(value, cast(type[Model], field.related_model)) field.save_form_data(instance, value) # If data was passed to the foreign key, update it recursively @@ -574,7 +577,9 @@ def update_m2m( existing = set(manager.all()) need_remove_cache = need_remove_cache or bool(values) for v in values: - obj, data = _parse_data(info, manager.model, v, key_attr=key_attr) + obj, data = _parse_data( + info, cast(type[Model], manager.model), v, key_attr=key_attr + ) if obj: data.pop(key_attr, None) through_defaults = data.pop("through_defaults", {}) @@ -632,7 +637,12 @@ def update_m2m( else: need_remove_cache = need_remove_cache or bool(value.add) for v in value.add or []: - obj, data = _parse_data(info, manager.model, v, key_attr=key_attr) + obj, data = _parse_data( + info, + cast(type[Model], manager.model), + v, + key_attr=key_attr, + ) if obj and data: data.pop(key_attr, None) if full_clean: @@ -653,7 +663,9 @@ def update_m2m( need_remove_cache = need_remove_cache or bool(value.remove) for v in value.remove or []: - obj, data = _parse_data(info, manager.model, v, key_attr=key_attr) + obj, data = _parse_data( + info, cast(type[Model], manager.model), v, key_attr=key_attr + ) data.pop(key_attr, None) assert not data to_remove.append(obj)