Skip to content

Commit

Permalink
Support async context generators
Browse files Browse the repository at this point in the history
  • Loading branch information
hugobessa committed Dec 11, 2024
1 parent 5b7074f commit 9e7f16b
Showing 1 changed file with 73 additions and 13 deletions.
86 changes: 73 additions & 13 deletions vintasend/services/notification_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import uuid
from collections.abc import Callable, Iterable
from typing import Any, ClassVar, Generic, TypeGuard, TypeVar, cast
from typing import Any, ClassVar, Coroutine, Generic, TypeGuard, TypeVar, cast

from vintasend.app_settings import NotificationSettings
from vintasend.services.notification_backends.asyncio_base import AsyncIOBaseNotificationBackend
Expand Down Expand Up @@ -49,9 +49,20 @@


class Contexts(metaclass=SingletonMeta):
_contexts: ClassVar[dict[str, Callable[[Any], NotificationContextDict]]] = {}
_contexts: ClassVar[
dict[
str,
Callable[[Any], NotificationContextDict]
| Callable[[Any], Coroutine[Any, Any, NotificationContextDict]],
]
] = {}

def register_function(self, key: str, func: Callable[[Any], NotificationContextDict]):
def register_function(
self,
key: str,
func: Callable[[Any], NotificationContextDict]
| Callable[[Any], Coroutine[Any, Any, NotificationContextDict]],
):
self._contexts[key] = func

def get_function(self, key: str):
Expand All @@ -77,7 +88,9 @@ class NotificationService(Generic[A, B]):

def __init__(
self,
notification_adapters: Iterable[A] | Iterable[tuple[str, str | tuple[str, dict[str, Any]]]] | None = None,
notification_adapters: Iterable[A]
| Iterable[tuple[str, str | tuple[str, dict[str, Any]]]]
| None = None,
notification_backend: B | str | None = None,
notification_backend_kwargs: dict | None = None,
config: Any = None,
Expand Down Expand Up @@ -117,22 +130,33 @@ def __init__(
]

def _check_is_base_notification_adapter_iterable(
self, notification_adapters: Iterable[A] | Iterable[tuple[str, str | tuple[str, dict[str, Any]]]] | None
self,
notification_adapters: Iterable[A]
| Iterable[tuple[str, str | tuple[str, dict[str, Any]]]]
| None,
) -> TypeGuard[Iterable[A]]:
return notification_adapters is not None and all(
isinstance(adapter, BaseNotificationAdapter) for adapter in notification_adapters
)

def _check_is_adapters_tuple_iterable(
self, notification_adapters: Iterable[A] | Iterable[tuple[str, str | tuple[str, dict[str, Any]]]] | None
self,
notification_adapters: Iterable[A]
| Iterable[tuple[str, str | tuple[str, dict[str, Any]]]]
| None,
) -> TypeGuard[Iterable[tuple[str, str | tuple[str, dict[str, Any]]]]]:
return notification_adapters is not None and all(
(isinstance(adapter, tuple) or isinstance(adapter, list))
and len(adapter) == 2
and isinstance(adapter[0], str)
and (isinstance(adapter[1], str) or (
isinstance(adapter[1], tuple) and isinstance(adapter[1][0], str) and isinstance(adapter[1][1], dict)
))
and (
isinstance(adapter[1], str)
or (
isinstance(adapter[1], tuple)
and isinstance(adapter[1][0], str)
and isinstance(adapter[1][1], dict)
)
)
for adapter in notification_adapters
)

Expand Down Expand Up @@ -319,6 +343,20 @@ def get_future_notifications(self, page: int, page_size: int) -> Iterable[Notifi
"""
return self.notification_backend.get_future_notifications(page, page_size)

def _is_asyncio_context_function(
self,
context_function: Callable[[Any], NotificationContextDict]
| Callable[[Any], Coroutine[Any, Any, NotificationContextDict]],
) -> TypeGuard[Callable[[Any], Coroutine[Any, Any, NotificationContextDict]]]:
return asyncio.iscoroutinefunction(context_function)

def _is_sync_context_function(
self,
context_function: Callable[[Any], NotificationContextDict]
| Callable[[Any], Coroutine[Any, Any, NotificationContextDict]],
) -> TypeGuard[Callable[[Any], NotificationContextDict]]:
return not asyncio.iscoroutinefunction(context_function)

def get_notification_context(self, notification: Notification) -> NotificationContextDict:
"""
Generate the context for a notification. It uses the context_name and context_kwargs from the notification.
Expand All @@ -334,7 +372,11 @@ def get_notification_context(self, notification: Notification) -> NotificationCo
if context_function is None:
raise NotificationContextGenerationError("Context function not found")
try:
return context_function(*[], **notification.context_kwargs)
if self._is_asyncio_context_function(context_function):
return asyncio.run(context_function(*[], **notification.context_kwargs))
elif self._is_sync_context_function(context_function):
return context_function(*[], **notification.context_kwargs)
raise NotificationContextGenerationError("Invalid context function")
except Exception as e: # noqa: BLE001
raise NotificationContextGenerationError("Failed getting notification context") from e

Expand Down Expand Up @@ -573,7 +615,7 @@ async def send(self, notification: Notification, lock: asyncio.Lock | None = Non
notification: Notification - the notification to be sent
"""
try:
context = self.get_notification_context(notification)
context = await self.get_notification_context(notification)
except NotificationContextGenerationError:
logger.exception("Failed to generate context for notification %s", notification.id)
try:
Expand Down Expand Up @@ -744,7 +786,7 @@ async def get_future_notifications(self, page: int, page_size: int) -> Iterable[
"""
return await self.notification_backend.get_future_notifications(page, page_size)

def get_notification_context(self, notification: Notification) -> NotificationContextDict:
async def get_notification_context(self, notification: Notification) -> NotificationContextDict:
"""
Generate the context for a notification. It uses the context_name and context_kwargs from the notification.
Contexts are registered using the @register_context decorator.
Expand All @@ -759,10 +801,28 @@ def get_notification_context(self, notification: Notification) -> NotificationCo
if context_function is None:
raise NotificationContextGenerationError("Context function not found")
try:
return context_function(*[], **notification.context_kwargs)
if self._is_asyncio_context_function(context_function):
return await context_function(*[], **notification.context_kwargs)
elif self._is_sync_context_function(context_function):
return context_function(*[], **notification.context_kwargs)
raise NotificationContextGenerationError("Invalid context function")
except Exception as e: # noqa: BLE001
raise NotificationContextGenerationError("Failed getting notification context") from e

def _is_asyncio_context_function(
self,
context_function: Callable[[Any], NotificationContextDict]
| Callable[[Any], Coroutine[Any, Any, NotificationContextDict]],
) -> TypeGuard[Callable[[Any], Coroutine[Any, Any, NotificationContextDict]]]:
return asyncio.iscoroutinefunction(context_function)

def _is_sync_context_function(
self,
context_function: Callable[[Any], NotificationContextDict]
| Callable[[Any], Coroutine[Any, Any, NotificationContextDict]],
) -> TypeGuard[Callable[[Any], NotificationContextDict]]:
return not asyncio.iscoroutinefunction(context_function)

async def _send_notification_with_error_logging(
self, notification: "Notification", lock: asyncio.Lock | None = None
) -> None:
Expand Down

0 comments on commit 9e7f16b

Please sign in to comment.