From d0fcf3019926653fa12e3f9fbbbffe41705aff26 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Wed, 17 Jul 2024 19:03:33 -0300 Subject: [PATCH] fix: Fix typing issues --- strawberry_django/integrations/guardian.py | 6 ++++-- strawberry_django/mutations/resolvers.py | 23 +++++++++++++++++----- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/strawberry_django/integrations/guardian.py b/strawberry_django/integrations/guardian.py index d252f406..f15baf44 100644 --- a/strawberry_django/integrations/guardian.py +++ b/strawberry_django/integrations/guardian.py @@ -1,7 +1,7 @@ import contextlib import dataclasses import weakref -from typing import Optional, Union, cast +from typing import Optional, Type, Union, cast from django.contrib.auth import get_user_model from django.contrib.auth.models import Group @@ -24,7 +24,9 @@ class ObjectPermissionModels: group: GroupObjectPermissionBase -def get_object_permission_models(model: models.Model): +def get_object_permission_models( + model: Union[models.Model, Type[models.Model]], +) -> ObjectPermissionModels: return ObjectPermissionModels( user=cast(UserObjectPermissionBase, get_user_obj_perms_model(model)), group=cast(GroupObjectPermissionBase, get_group_obj_perms_model(model)), diff --git a/strawberry_django/mutations/resolvers.py b/strawberry_django/mutations/resolvers.py index 9c1bfe79..928804c6 100644 --- a/strawberry_django/mutations/resolvers.py +++ b/strawberry_django/mutations/resolvers.py @@ -8,6 +8,7 @@ Callable, Iterable, List, + Type, TypeVar, cast, overload, @@ -262,7 +263,10 @@ def prepare_create_update( (ParsedObject, str), ): value, value_data = _parse_data( # noqa: PLW2901 - info, field.related_model, value, key_attr=key_attr + info, + cast(Type[Model], field.related_model), + value, + key_attr=key_attr, ) if value is None and not value_data: value = None # noqa: PLW2901 @@ -508,7 +512,7 @@ def update_field(info: Info, instance: Model, field: models.Field, value: Any): and isinstance(field, models.ForeignObject) and not isinstance(value, Model) ): - value, data = _parse_pk(value, field.related_model) + value, data = _parse_pk(value, cast(Type[Model], field.related_model)) field.save_form_data(instance, value) # If data was passed to the foreign key, update it recursively @@ -574,7 +578,9 @@ def update_m2m( existing = set(manager.all()) need_remove_cache = need_remove_cache or bool(values) for v in values: - obj, data = _parse_data(info, manager.model, v, key_attr=key_attr) + obj, data = _parse_data( + info, cast(Type[Model], manager.model), v, key_attr=key_attr + ) if obj: data.pop(key_attr, None) through_defaults = data.pop("through_defaults", {}) @@ -632,7 +638,12 @@ def update_m2m( else: need_remove_cache = need_remove_cache or bool(value.add) for v in value.add or []: - obj, data = _parse_data(info, manager.model, v, key_attr=key_attr) + obj, data = _parse_data( + info, + cast(Type[Model], manager.model), + v, + key_attr=key_attr, + ) if obj and data: data.pop(key_attr, None) if full_clean: @@ -653,7 +664,9 @@ def update_m2m( need_remove_cache = need_remove_cache or bool(value.remove) for v in value.remove or []: - obj, data = _parse_data(info, manager.model, v, key_attr=key_attr) + obj, data = _parse_data( + info, cast(Type[Model], manager.model), v, key_attr=key_attr + ) data.pop(key_attr, None) assert not data to_remove.append(obj)