From 1dd86ce71c73323e68201480d55bb74302c229cf Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 3 May 2024 18:04:16 +0200 Subject: [PATCH] fix: support passing context when nesting permissions --- strawberry/permission.py | 84 +++++++++++++++++++++------------ tests/schema/test_permission.py | 52 ++++++++++++++++++++ 2 files changed, 106 insertions(+), 30 deletions(-) diff --git a/strawberry/permission.py b/strawberry/permission.py index d285e855c8..1986df317e 100644 --- a/strawberry/permission.py +++ b/strawberry/permission.py @@ -13,6 +13,7 @@ Optional, Tuple, Type, + TypedDict, Union, ) from typing_extensions import deprecated @@ -37,6 +38,15 @@ from strawberry.types import Info +def unpack_maybe( + value: Union[object, Tuple[bool, object]], default: object = None +) -> Tuple[object, object]: + if isinstance(value, tuple) and len(value) == 2: + return value + else: + return value, default + + class BasePermission(abc.ABC): """ Base class for creating permissions @@ -52,7 +62,7 @@ class BasePermission(abc.ABC): @abc.abstractmethod def has_permission( - self, source: Any, info: Info, **kwargs: Any + self, source: Any, info: Info, **kwargs: object ) -> Union[ bool, Awaitable[bool], @@ -82,7 +92,7 @@ def has_permission( "Permission classes should override has_permission method" ) - def on_unauthorized(self, **kwargs: Any) -> None: + def on_unauthorized(self, **kwargs: object) -> None: """ Default error raising for permissions. This can be overridden to customize the behavior. @@ -128,14 +138,18 @@ def __or__(self, other: BasePermission): return OrPermission([self, other]) +class CompositePermissionContext(TypedDict): + failed_permissions: List[Tuple[BasePermission, dict]] + + class CompositePermission(BasePermission, abc.ABC): def __init__(self, child_permissions: List[BasePermission]): self.child_permissions = child_permissions - def on_unauthorized(self, **kwargs: Any) -> Any: + def on_unauthorized(self, **kwargs: object) -> Any: failed_permissions = kwargs.get("failed_permissions", []) - for permission in failed_permissions: - permission.on_unauthorized() + for permission, context in failed_permissions: + permission.on_unauthorized(**context) @cached_property def is_async(self) -> bool: @@ -144,27 +158,34 @@ def is_async(self) -> bool: class AndPermission(CompositePermission): def has_permission( - self, source: Any, info: Info, **kwargs: Any + self, source: Any, info: Info, **kwargs: object ) -> Union[ bool, Awaitable[bool], - Tuple[Literal[False], dict], - Awaitable[Tuple[Literal[False], dict]], + Tuple[Literal[False], CompositePermissionContext], + Awaitable[Tuple[Literal[False], CompositePermissionContext]], ]: if self.is_async: return self._has_permission_async(source, info, **kwargs) for permission in self.child_permissions: - if not permission.has_permission(source, info, **kwargs): - return False, {"failed_permissions": [permission]} + has_permission, context = unpack_maybe( + permission.has_permission(source, info, **kwargs), {} + ) + if not has_permission: + return False, {"failed_permissions": [(permission, context)]} return True async def _has_permission_async( - self, source: Any, info: Info, **kwargs: Any - ) -> Union[bool, Tuple[Literal[False], dict]]: + self, source: Any, info: Info, **kwargs: object + ) -> Union[bool, Tuple[Literal[False], CompositePermissionContext]]: for permission in self.child_permissions: - if not await await_maybe(permission.has_permission(source, info, **kwargs)): - return False, {"failed_permissions": [permission]} + permission_response = await await_maybe( + permission.has_permission(source, info, **kwargs) + ) + has_permission, context = unpack_maybe(permission_response, {}) + if not has_permission: + return False, {"failed_permissions": [(permission, context)]} return True def __and__(self, other: BasePermission): @@ -173,7 +194,7 @@ def __and__(self, other: BasePermission): class OrPermission(CompositePermission): def has_permission( - self, source: Any, info: Info, **kwargs: Any + self, source: Any, info: Info, **kwargs: object ) -> Union[ bool, Awaitable[bool], @@ -184,20 +205,27 @@ def has_permission( return self._has_permission_async(source, info, **kwargs) failed_permissions = [] for permission in self.child_permissions: - if permission.has_permission(source, info, **kwargs): + has_permission, context = unpack_maybe( + permission.has_permission(source, info, **kwargs), {} + ) + if has_permission: return True - failed_permissions.append(permission) + failed_permissions.append((permission, context)) return False, {"failed_permissions": failed_permissions} async def _has_permission_async( - self, source: Any, info: Info, **kwargs: Any + self, source: Any, info: Info, **kwargs: object ) -> Union[bool, Tuple[Literal[False], dict]]: failed_permissions = [] for permission in self.child_permissions: - if await await_maybe(permission.has_permission(source, info, **kwargs)): + permission_response = await await_maybe( + permission.has_permission(source, info, **kwargs) + ) + has_permission, context = unpack_maybe(permission_response, {}) + if has_permission: return True - failed_permissions.append(permission) + failed_permissions.append((permission, context)) return False, {"failed_permissions": failed_permissions} @@ -253,7 +281,7 @@ def apply(self, field: StrawberryField) -> None: else: raise PermissionFailSilentlyRequiresOptionalError(field) - def _on_unauthorized(self, permission: BasePermission, **kwargs: Any) -> Any: + def _on_unauthorized(self, permission: BasePermission, **kwargs: object) -> Any: if self.fail_silently: return [] if self.return_empty_list else None @@ -266,7 +294,7 @@ def resolve( next_: SyncExtensionResolver, source: Any, info: Info, - **kwargs: Any[str, Any], + **kwargs: object[str, Any], ) -> Any: """ Checks if the permission should be accepted and @@ -274,13 +302,9 @@ def resolve( """ for permission in self.permissions: - permission_response = permission.has_permission(source, info, **kwargs) - - context = {} - if isinstance(permission_response, tuple): - has_permission, context = permission_response - else: - has_permission = permission_response + has_permission, context = unpack_maybe( + permission.has_permission(source, info, **kwargs), {} + ) if not has_permission: return self._on_unauthorized(permission, **context) @@ -292,7 +316,7 @@ async def resolve_async( next_: AsyncExtensionResolver, source: Any, info: Info, - **kwargs: Any[str, Any], + **kwargs: object[str, Any], ) -> Any: for permission in self.permissions: permission_response = await await_maybe( diff --git a/tests/schema/test_permission.py b/tests/schema/test_permission.py index e1d1982ee6..3fe5b11638 100644 --- a/tests/schema/test_permission.py +++ b/tests/schema/test_permission.py @@ -164,6 +164,58 @@ def user_async(self) -> str: # pragma: no cover assert result.errors[0].message == "False Permission Failed" +@pytest.mark.asyncio +async def test_raises_graphql_error_when_nested(): + class FalsePermission(BasePermission): + message = "False Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return False + + class TruePermission(BasePermission): + message = "True Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return True + + class FalseAsyncPermission(BasePermission): + message = "False Permission Failed" + + async def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return False + + @strawberry.type + class Query: + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[ + ( + (TruePermission() & FalsePermission()) + | FalseAsyncPermission() + ) + & TruePermission() + ] + ) + ] + ) + def user(self) -> str: # pragma: no cover + return "patrick" + + schema = strawberry.Schema(query=Query) + + query = "{ user }" + + result = await schema.execute(query) + assert result.errors[0].message == "False Permission Failed" + + @pytest.mark.asyncio async def test_raises_graphql_error_when_left_and_permission_is_denied(): class FalsePermission(BasePermission):