Skip to content

Commit

Permalink
Merge branch 'refs/heads/boolean-expression-permissions-erik' into bo…
Browse files Browse the repository at this point in the history
…olean-expression-permissions
  • Loading branch information
erikwrede committed May 16, 2024
2 parents 901abf0 + 1dd86ce commit 64bcbad
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 30 deletions.
84 changes: 54 additions & 30 deletions strawberry/permission.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Optional,
Tuple,
Type,
TypedDict,
Union,
)
from typing_extensions import deprecated
Expand All @@ -37,6 +38,15 @@
from strawberry.types import Info


def unpack_maybe(
value: Union[object, Tuple[bool, object]], default: object = None
) -> Tuple[object, object]:
if isinstance(value, tuple) and len(value) == 2:
return value
else:
return value, default


class BasePermission(abc.ABC):
"""
Base class for creating permissions
Expand All @@ -52,7 +62,7 @@ class BasePermission(abc.ABC):

@abc.abstractmethod
def has_permission(
self, source: Any, info: Info, **kwargs: Any
self, source: Any, info: Info, **kwargs: object
) -> Union[
bool,
Awaitable[bool],
Expand Down Expand Up @@ -82,7 +92,7 @@ def has_permission(
"Permission classes should override has_permission method"
)

def on_unauthorized(self, **kwargs: Any) -> None:
def on_unauthorized(self, **kwargs: object) -> None:
"""
Default error raising for permissions.
This can be overridden to customize the behavior.
Expand Down Expand Up @@ -128,14 +138,18 @@ def __or__(self, other: BasePermission):
return OrPermission([self, other])


class CompositePermissionContext(TypedDict):
failed_permissions: List[Tuple[BasePermission, dict]]


class CompositePermission(BasePermission, abc.ABC):
def __init__(self, child_permissions: List[BasePermission]):
self.child_permissions = child_permissions

def on_unauthorized(self, **kwargs: Any) -> Any:
def on_unauthorized(self, **kwargs: object) -> Any:
failed_permissions = kwargs.get("failed_permissions", [])
for permission in failed_permissions:
permission.on_unauthorized()
for permission, context in failed_permissions:
permission.on_unauthorized(**context)

@cached_property
def is_async(self) -> bool:
Expand All @@ -144,27 +158,34 @@ def is_async(self) -> bool:

class AndPermission(CompositePermission):
def has_permission(
self, source: Any, info: Info, **kwargs: Any
self, source: Any, info: Info, **kwargs: object
) -> Union[
bool,
Awaitable[bool],
Tuple[Literal[False], dict],
Awaitable[Tuple[Literal[False], dict]],
Tuple[Literal[False], CompositePermissionContext],
Awaitable[Tuple[Literal[False], CompositePermissionContext]],
]:
if self.is_async:
return self._has_permission_async(source, info, **kwargs)

for permission in self.child_permissions:
if not permission.has_permission(source, info, **kwargs):
return False, {"failed_permissions": [permission]}
has_permission, context = unpack_maybe(
permission.has_permission(source, info, **kwargs), {}
)
if not has_permission:
return False, {"failed_permissions": [(permission, context)]}
return True

async def _has_permission_async(
self, source: Any, info: Info, **kwargs: Any
) -> Union[bool, Tuple[Literal[False], dict]]:
self, source: Any, info: Info, **kwargs: object
) -> Union[bool, Tuple[Literal[False], CompositePermissionContext]]:
for permission in self.child_permissions:
if not await await_maybe(permission.has_permission(source, info, **kwargs)):
return False, {"failed_permissions": [permission]}
permission_response = await await_maybe(
permission.has_permission(source, info, **kwargs)
)
has_permission, context = unpack_maybe(permission_response, {})
if not has_permission:
return False, {"failed_permissions": [(permission, context)]}
return True

def __and__(self, other: BasePermission):
Expand All @@ -173,7 +194,7 @@ def __and__(self, other: BasePermission):

class OrPermission(CompositePermission):
def has_permission(
self, source: Any, info: Info, **kwargs: Any
self, source: Any, info: Info, **kwargs: object
) -> Union[
bool,
Awaitable[bool],
Expand All @@ -184,20 +205,27 @@ def has_permission(
return self._has_permission_async(source, info, **kwargs)
failed_permissions = []
for permission in self.child_permissions:
if permission.has_permission(source, info, **kwargs):
has_permission, context = unpack_maybe(
permission.has_permission(source, info, **kwargs), {}
)
if has_permission:
return True
failed_permissions.append(permission)
failed_permissions.append((permission, context))

return False, {"failed_permissions": failed_permissions}

async def _has_permission_async(
self, source: Any, info: Info, **kwargs: Any
self, source: Any, info: Info, **kwargs: object
) -> Union[bool, Tuple[Literal[False], dict]]:
failed_permissions = []
for permission in self.child_permissions:
if await await_maybe(permission.has_permission(source, info, **kwargs)):
permission_response = await await_maybe(
permission.has_permission(source, info, **kwargs)
)
has_permission, context = unpack_maybe(permission_response, {})
if has_permission:
return True
failed_permissions.append(permission)
failed_permissions.append((permission, context))

return False, {"failed_permissions": failed_permissions}

Expand Down Expand Up @@ -253,7 +281,7 @@ def apply(self, field: StrawberryField) -> None:
else:
raise PermissionFailSilentlyRequiresOptionalError(field)

def _on_unauthorized(self, permission: BasePermission, **kwargs: Any) -> Any:
def _on_unauthorized(self, permission: BasePermission, **kwargs: object) -> Any:
if self.fail_silently:
return [] if self.return_empty_list else None

Expand All @@ -266,21 +294,17 @@ def resolve(
next_: SyncExtensionResolver,
source: Any,
info: Info,
**kwargs: Any[str, Any],
**kwargs: object[str, Any],
) -> Any:
"""
Checks if the permission should be accepted and
raises an exception if not
"""

for permission in self.permissions:
permission_response = permission.has_permission(source, info, **kwargs)

context = {}
if isinstance(permission_response, tuple):
has_permission, context = permission_response
else:
has_permission = permission_response
has_permission, context = unpack_maybe(
permission.has_permission(source, info, **kwargs), {}
)

if not has_permission:
return self._on_unauthorized(permission, **context)
Expand All @@ -292,7 +316,7 @@ async def resolve_async(
next_: AsyncExtensionResolver,
source: Any,
info: Info,
**kwargs: Any[str, Any],
**kwargs: object[str, Any],
) -> Any:
for permission in self.permissions:
permission_response = await await_maybe(
Expand Down
52 changes: 52 additions & 0 deletions tests/schema/test_permission.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,58 @@ def user_async(self) -> str: # pragma: no cover
assert result.errors[0].message == "False Permission Failed"


@pytest.mark.asyncio
async def test_raises_graphql_error_when_nested():
class FalsePermission(BasePermission):
message = "False Permission Failed"

def has_permission(
self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any
) -> bool:
return False

class TruePermission(BasePermission):
message = "True Permission Failed"

def has_permission(
self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any
) -> bool:
return True

class FalseAsyncPermission(BasePermission):
message = "False Permission Failed"

async def has_permission(
self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any
) -> bool:
return False

@strawberry.type
class Query:
@strawberry.field(
extensions=[
PermissionExtension(
permissions=[
(
(TruePermission() & FalsePermission())
| FalseAsyncPermission()
)
& TruePermission()
]
)
]
)
def user(self) -> str: # pragma: no cover
return "patrick"

schema = strawberry.Schema(query=Query)

query = "{ user }"

result = await schema.execute(query)
assert result.errors[0].message == "False Permission Failed"


@pytest.mark.asyncio
async def test_raises_graphql_error_when_left_and_permission_is_denied():
class FalsePermission(BasePermission):
Expand Down

0 comments on commit 64bcbad

Please sign in to comment.