Skip to content

Commit

Permalink
Runner use coordinator interface
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 21, 2024
1 parent 36f1a86 commit 28cdb1c
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 31 deletions.
4 changes: 4 additions & 0 deletions src/aiida/brokers/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def __init__(self, profile: 'Profile') -> None:
def get_communicator(self):
"""Return an instance of :class:`kiwipy.Communicator`."""

@abc.abstractmethod
def get_coordinator(self):
"""Return an instance of coordinator."""

@abc.abstractmethod
def iterate_tasks(self):
"""Return an iterator over the tasks in the launch queue."""
Expand Down
8 changes: 7 additions & 1 deletion src/aiida/brokers/rabbitmq/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import functools
import typing as t

from plumpy.rmq import RmqCoordinator

from aiida.brokers.broker import Broker
from aiida.common.log import AIIDA_LOGGER
from aiida.manage.configuration import get_config_option
Expand All @@ -13,7 +15,6 @@

if t.TYPE_CHECKING:
from kiwipy.rmq import RmqThreadCommunicator

from aiida.manage.configuration.profile import Profile

LOGGER = AIIDA_LOGGER.getChild('broker.rabbitmq')
Expand Down Expand Up @@ -58,6 +59,11 @@ def get_communicator(self) -> 'RmqThreadCommunicator':

return self._communicator

def get_coordinator(self):
coordinator = RmqCoordinator(self.get_communicator())

return coordinator

def _create_communicator(self) -> 'RmqThreadCommunicator':
"""Return an instance of :class:`kiwipy.Communicator`."""
from kiwipy.rmq import RmqThreadCommunicator
Expand Down
7 changes: 3 additions & 4 deletions src/aiida/engine/processes/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@
import plumpy.processes

# from kiwipy.communications import UnroutableError
# from plumpy.processes import ConnectionClosed # type: ignore[attr-defined]
from plumpy.process_states import Finished, ProcessState

# from plumpy.processes import ConnectionClosed # type: ignore[attr-defined]
from plumpy.processes import Process as PlumpyProcess
from plumpy.utils import AttributesFrozendict

Expand Down Expand Up @@ -174,13 +174,12 @@ def __init__(
from aiida.manage import manager

self._runner = runner if runner is not None else manager.get_manager().get_runner()
# assert self._runner.communicator is not None, 'communicator not set for runner'

super().__init__(
inputs=self.spec().inputs.serialize(inputs),
logger=logger,
loop=self._runner.loop,
coordinator=self._runner.communicator,
coordinator=self._runner.coordinator,
)

self._node: Optional[orm.ProcessNode] = None
Expand Down Expand Up @@ -320,7 +319,7 @@ def load_instance_state(
else:
self._runner = manager.get_manager().get_runner()

load_context = load_context.copyextend(loop=self._runner.loop, coordinator=self._runner.communicator)
load_context = load_context.copyextend(loop=self._runner.loop, coordinator=self._runner.coordinator)
super().load_instance_state(saved_state, load_context)

if self.SaveKeys.CALC_ID.value in saved_state:
Expand Down
34 changes: 18 additions & 16 deletions src/aiida/engine/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple, Type, Union

import kiwipy
from plumpy.coordinator import Coordinator
from plumpy.events import reset_event_loop_policy, set_event_loop_policy
from plumpy.persistence import Persister
from plumpy.rmq import RemoteProcessThreadController, wrap_communicator
Expand Down Expand Up @@ -55,30 +56,30 @@ class Runner:
"""Class that can launch processes by running in the current interpreter or by submitting them to the daemon."""

_persister: Optional[Persister] = None
_communicator: Optional[kiwipy.Communicator] = None
_coordinator: Optional[Coordinator] = None
_controller: Optional[RemoteProcessThreadController] = None
_closed: bool = False

def __init__(
self,
poll_interval: Union[int, float] = 0,
loop: Optional[asyncio.AbstractEventLoop] = None,
communicator: Optional[kiwipy.Communicator] = None,
coordinator: Optional[Coordinator] = None,
broker_submit: bool = False,
persister: Optional[Persister] = None,
):
"""Construct a new runner.
:param poll_interval: interval in seconds between polling for status of active sub processes
:param loop: an asyncio event loop, if none is suppled a new one will be created
:param communicator: the communicator to use
:param coordinator: the coordinator to use
:param broker_submit: if True, processes will be submitted to the broker, otherwise they will be scheduled here
:param persister: the persister to use to persist processes
"""
assert not (
broker_submit and persister is None
), 'Must supply a persister if you want to submit using communicator'
), 'Must supply a persister if you want to submit using coordinator'

set_event_loop_policy()
self._loop = loop or asyncio.get_event_loop()
Expand All @@ -89,11 +90,12 @@ def __init__(
self._persister = persister
self._plugin_version_provider = PluginVersionProvider()

if communicator is not None:
self._communicator = wrap_communicator(communicator, self._loop)
self._controller = RemoteProcessThreadController(communicator)
if coordinator is not None:
# FIXME: the wrap is not needed, when passed in, the coordinator should already wrapped
self._coordinator = wrap_communicator(coordinator.communicator, self._loop)
self._controller = RemoteProcessThreadController(coordinator)
elif self._broker_submit:
LOGGER.warning('Disabling broker submission, no communicator provided')
LOGGER.warning('Disabling broker submission, no coordinator provided')
self._broker_submit = False

def __enter__(self) -> 'Runner':
Expand All @@ -117,9 +119,9 @@ def persister(self) -> Optional[Persister]:
return self._persister

@property
def communicator(self) -> Optional[kiwipy.Communicator]:
"""Get the communicator used by this runner."""
return self._communicator
def coordinator(self) -> Optional[Coordinator]:
"""Get the coordinator used by this runner."""
return self._coordinator

@property
def plugin_version_provider(self) -> PluginVersionProvider:
Expand Down Expand Up @@ -329,16 +331,16 @@ def inline_callback(event, *args, **kwargs):
callback()
finally:
event.set()
if self.communicator:
self.communicator.remove_broadcast_subscriber(subscriber_identifier)
if self.coordinator:
self.coordinator.remove_broadcast_subscriber(subscriber_identifier)

broadcast_filter = kiwipy.BroadcastFilter(functools.partial(inline_callback, event), sender=pk)
for state in [ProcessState.FINISHED, ProcessState.KILLED, ProcessState.EXCEPTED]:
broadcast_filter.add_subject_filter(f'state_changed.*.{state.value}')

if self.communicator:
if self.coordinator:
LOGGER.info('adding subscriber for broadcasts of %d', pk)
self.communicator.add_broadcast_subscriber(broadcast_filter, subscriber_identifier)
self.coordinator.add_broadcast_subscriber(broadcast_filter, subscriber_identifier)
self._poll_process(node, functools.partial(inline_callback, event))

def get_process_future(self, pk: int) -> futures.ProcessFuture:
Expand All @@ -348,7 +350,7 @@ def get_process_future(self, pk: int) -> futures.ProcessFuture:
:return: A future representing the completion of the process node
"""
return futures.ProcessFuture(pk, self._loop, self._poll_interval, self._communicator)
return futures.ProcessFuture(pk, self._loop, self._poll_interval, self._coordinator)

def _poll_process(self, node, callback):
"""Check whether the process state of the node is terminated and call the callback or reschedule it.
Expand Down
35 changes: 26 additions & 9 deletions src/aiida/manage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import asyncio
import kiwipy
from plumpy.coordinator import Coordinator

if TYPE_CHECKING:
from kiwipy.rmq import RmqThreadCommunicator
Expand Down Expand Up @@ -60,8 +61,8 @@ class Manager:
3. A single storage backend object for the profile, to connect to data storage resources
5. A single daemon client object for the profile, to connect to the AiiDA daemon
4. A single communicator object for the profile, to connect to the process control resources
6. A single process controller object for the profile, which uses the communicator to control process tasks
4. A single coordinator object for the profile, to connect to the process control resources
6. A single process controller object for the profile, which uses the coordinator to control process tasks
7. A single runner object for the profile, which uses the process controller to start and stop processes
8. A single persister object for the profile, which can persist running processes to the profile storage
Expand Down Expand Up @@ -343,6 +344,23 @@ def get_communicator(self) -> 'RmqThreadCommunicator':

return broker.get_communicator()

def get_coordinator(self) -> 'Coordinator':
"""Return the coordinator
:return: a global coordinator instance
"""
from aiida.common import ConfigurationError

broker = self.get_broker()

if broker is None:
assert self._profile is not None
raise ConfigurationError(
f'profile `{self._profile.name}` does not provide a coordinator because it does not define a broker'
)

return broker.get_coordinator()

def get_daemon_client(self) -> 'DaemonClient':
"""Return the daemon client for the current profile.
Expand Down Expand Up @@ -373,8 +391,7 @@ def get_process_controller(self) -> 'RemoteProcessThreadController':
from plumpy.rmq import RemoteProcessThreadController

if self._process_controller is None:
# FIXME: use coordinator wrapper
self._process_controller = RemoteProcessThreadController(self.get_communicator())
self._process_controller = RemoteProcessThreadController(self.get_coordinator())

return self._process_controller

Expand Down Expand Up @@ -402,7 +419,7 @@ def create_runner(
self,
poll_interval: Union[int, float] | None = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
communicator: Optional[kiwipy.Communicator] = None,
coordinator: Optional[Coordinator] = None,
broker_submit: bool = False,
persister: Optional[AiiDAPersister] = None,
) -> 'Runner':
Expand All @@ -423,13 +440,13 @@ def create_runner(

_default_poll_interval = 0.0 if profile.is_test_profile else self.get_option('runner.poll.interval')
_default_broker_submit = False
_default_communicator = self.get_communicator()
_default_coordinator = self.get_coordinator()
_default_persister = self.get_persister()

runner = runners.Runner(
poll_interval=poll_interval or _default_poll_interval,
loop=loop or asyncio.get_event_loop(),
communicator=communicator or _default_communicator,
coordinator=coordinator or _default_coordinator,
broker_submit=broker_submit or _default_broker_submit,
persister=persister or _default_persister,
)
Expand Down Expand Up @@ -461,8 +478,8 @@ def create_daemon_runner(self, loop: Optional['asyncio.AbstractEventLoop'] = Non
loader=persistence.get_object_loader(),
)

assert runner.communicator is not None, 'communicator not set for runner'
runner.communicator.add_task_subscriber(task_receiver)
assert runner.coordinator is not None, 'coordinator not set for runner'
runner.coordinator.add_task_subscriber(task_receiver)

return runner

Expand Down
2 changes: 1 addition & 1 deletion tests/engine/test_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_calculation_future_broadcasts(self):

# No polling
future = processes.futures.ProcessFuture(
pk=process.pid, loop=runner.loop, communicator=manager.get_coordinator()
pk=process.pid, loop=runner.loop, communicator=manager.get_communicator()
)

run(process)
Expand Down

0 comments on commit 28cdb1c

Please sign in to comment.