Skip to content

Commit

Permalink
Generic typing for Coordinator
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 19, 2024
1 parent 4c6210a commit f46b75f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
18 changes: 12 additions & 6 deletions src/plumpy/rmq/communications.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import asyncio
import functools
from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional
from typing import TYPE_CHECKING, Any, Callable, Generic, Hashable, Optional, TypeVar, final

import kiwipy

Expand Down Expand Up @@ -78,10 +78,11 @@ def converted(communicator: kiwipy.Communicator, *args: Any, **kwargs: Any) -> k

return converted

T = TypeVar('T', bound=kiwipy.Communicator)

def wrap_communicator(
communicator: kiwipy.Communicator, loop: Optional[asyncio.AbstractEventLoop] = None
) -> 'LoopCommunicator':
communicator: T, loop: Optional[asyncio.AbstractEventLoop] = None
) -> 'LoopCommunicator[T]':
"""
Wrap a communicator such that all callbacks made to any subscribers are scheduled on the
given event loop.
Expand All @@ -101,10 +102,11 @@ def wrap_communicator(
return LoopCommunicator(communicator, loop)


class LoopCommunicator(kiwipy.Communicator): # type: ignore
@final
class LoopCommunicator(Generic[T], kiwipy.Communicator): # type: ignore
"""Wrapper around a `kiwipy.Communicator` that schedules any subscriber messages on a given event loop."""

def __init__(self, communicator: kiwipy.Communicator, loop: Optional[asyncio.AbstractEventLoop] = None):
def __init__(self, communicator: T, loop: Optional[asyncio.AbstractEventLoop] = None):
"""
:param communicator: The kiwipy communicator
:param loop: The event loop to schedule callbacks on
Expand All @@ -115,6 +117,10 @@ def __init__(self, communicator: kiwipy.Communicator, loop: Optional[asyncio.Abs
self._communicator = communicator
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()

@property
def inner(self) -> T:
return self._communicator

def loop(self) -> asyncio.AbstractEventLoop:
return self._loop

Expand Down Expand Up @@ -153,7 +159,7 @@ def broadcast_send(
sender: Optional[str] = None,
subject: Optional[str] = None,
correlation_id: Optional['ID_TYPE'] = None,
) -> futures.Future:
) -> kiwipy.Future:
return self._communicator.broadcast_send(body, sender, subject, correlation_id)

def is_closed(self) -> bool:
Expand Down
13 changes: 11 additions & 2 deletions src/plumpy/rmq/coordinator.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
# -*- coding: utf-8 -*-
from typing import Generic, TypeVar, final
import kiwipy
import concurrent.futures

from plumpy.exceptions import CoordinatorConnectionError

__all__ = ['RmqCoordinator']

class RmqCoordinator:
def __init__(self, comm: kiwipy.Communicator):
U = TypeVar("U", bound=kiwipy.Communicator)

@final
class RmqCoordinator(Generic[U]):
def __init__(self, comm: U):
self._comm = comm

@property
def communicator(self) -> U:
"""The inner communicator."""
return self._comm

# XXX: naming - `add_receiver_rpc`
def add_rpc_subscriber(self, subscriber, identifier=None):
return self._comm.add_rpc_subscriber(subscriber, identifier)
Expand Down

0 comments on commit f46b75f

Please sign in to comment.