Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Jan 31, 2025
1 parent c6d5b8f commit 8465edb
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 10 deletions.
23 changes: 17 additions & 6 deletions src/plumpy/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@
from .processes import Process

get_event_loop = asyncio.get_event_loop
new_event_loop = asyncio.new_event_loop


def create_running_loop():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

return loop


def set_event_loop(*args: Any, **kwargs: Any) -> None:
Expand All @@ -24,16 +32,19 @@ class PlumpyEventLoopPolicy(asyncio.DefaultEventLoopPolicy):

_loop: asyncio.AbstractEventLoop | None = None

def get_event_loop(self) -> asyncio.AbstractEventLoop:
"""Return the patched event loop."""
def new_event_loop(self) -> asyncio.AbstractEventLoop:
"""Create new event loop and patch event loop as re-entrant."""
import nest_asyncio

if self._loop is None:
self._loop = super().get_event_loop()
nest_asyncio.apply(self._loop)
self._loop = super().new_event_loop()
nest_asyncio.apply(self._loop)

return self._loop

def get_event_loop(self) -> asyncio.AbstractEventLoop:
"""Return the patched event loop."""
return self._loop or self.new_event_loop()


def set_event_loop_policy() -> None:
"""Enable plumpy's event loop policy that will make event loop's reentrant."""
Expand All @@ -45,7 +56,7 @@ def set_event_loop_policy() -> None:

def reset_event_loop_policy() -> None:
"""Reset the event loop policy to the default."""
loop = get_event_loop()
loop = asyncio.get_event_loop()

cls = loop.__class__

Expand Down
2 changes: 1 addition & 1 deletion src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def __init__(
# Don't allow the spec to be changed anymore
self.spec().seal()

self._loop = loop if loop is not None else asyncio.get_event_loop()
self._loop = loop or asyncio.get_event_loop()

self._setup_event_hooks()

Expand Down
9 changes: 6 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
import pytest


@pytest.fixture(scope='session')
def set_event_loop_policy():
from plumpy import set_event_loop_policy
from plumpy.events import set_event_loop_policy, reset_event_loop_policy


@pytest.fixture(scope='function')
def custom_event_loop_policy():
set_event_loop_policy()
yield
reset_event_loop_policy()
6 changes: 6 additions & 0 deletions tests/rmq/test_communications.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __call__(self):
return Subscriber()


@pytest.mark.usefixtures('custom_event_loop_policy')
def test_add_rpc_subscriber(_coordinator, subscriber):
"""Test the `LoopCommunicator.add_rpc_subscriber` method."""
assert _coordinator.add_rpc_subscriber(subscriber) is not None
Expand All @@ -51,12 +52,14 @@ def test_add_rpc_subscriber(_coordinator, subscriber):
assert _coordinator.add_rpc_subscriber(subscriber, identifier) == identifier


@pytest.mark.usefixtures('custom_event_loop_policy')
def test_remove_rpc_subscriber(_coordinator, subscriber):
"""Test the `LoopCommunicator.remove_rpc_subscriber` method."""
identifier = _coordinator.add_rpc_subscriber(subscriber)
_coordinator.remove_rpc_subscriber(identifier)


@pytest.mark.usefixtures('custom_event_loop_policy')
def test_add_broadcast_subscriber(_coordinator, subscriber):
"""Test the `LoopCommunicator.add_broadcast_subscriber` method."""
assert _coordinator.add_broadcast_subscriber(subscriber) is not None
Expand All @@ -65,17 +68,20 @@ def test_add_broadcast_subscriber(_coordinator, subscriber):
assert _coordinator.add_broadcast_subscriber(subscriber, identifier=identifier) == identifier


@pytest.mark.usefixtures('custom_event_loop_policy')
def test_remove_broadcast_subscriber(_coordinator, subscriber):
"""Test the `LoopCommunicator.remove_broadcast_subscriber` method."""
identifier = _coordinator.add_broadcast_subscriber(subscriber)
_coordinator.remove_broadcast_subscriber(identifier)


@pytest.mark.usefixtures('custom_event_loop_policy')
def test_add_task_subscriber(_coordinator, subscriber):
"""Test the `LoopCommunicator.add_task_subscriber` method."""
assert _coordinator.add_task_subscriber(subscriber) is not None


@pytest.mark.usefixtures('custom_event_loop_policy')
def test_remove_task_subscriber(_coordinator, subscriber):
"""Test the `LoopCommunicator.remove_task_subscriber` method."""
identifier = _coordinator.add_task_subscriber(subscriber)
Expand Down
12 changes: 12 additions & 0 deletions tests/rmq/test_process_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def sync_controller(_coordinator):

class TestRemoteProcessController:
@pytest.mark.asyncio
@pytest.mark.usefixtures('custom_event_loop_policy')
async def test_pause(self, _coordinator, async_controller):
proc = utils.WaitForSignalProcess(coordinator=_coordinator)
# Run the process in the background
Expand All @@ -57,6 +58,7 @@ async def test_pause(self, _coordinator, async_controller):
assert proc.paused

@pytest.mark.asyncio
@pytest.mark.usefixtures('custom_event_loop_policy')
async def test_play(self, _coordinator, async_controller):
proc = utils.WaitForSignalProcess(coordinator=_coordinator)
# Run the process in the background
Expand All @@ -75,6 +77,7 @@ async def test_play(self, _coordinator, async_controller):
await async_controller.kill_process(proc.pid)

@pytest.mark.asyncio
@pytest.mark.usefixtures('custom_event_loop_policy')
async def test_kill(self, _coordinator, async_controller):
proc = utils.WaitForSignalProcess(coordinator=_coordinator)
# Run the process in the event loop
Expand All @@ -88,6 +91,7 @@ async def test_kill(self, _coordinator, async_controller):
assert proc.state_label == plumpy.ProcessState.KILLED

@pytest.mark.asyncio
@pytest.mark.usefixtures('custom_event_loop_policy')
async def test_status(self, _coordinator, async_controller):
proc = utils.WaitForSignalProcess(coordinator=_coordinator)
# Run the process in the background
Expand All @@ -101,6 +105,7 @@ async def test_status(self, _coordinator, async_controller):
# make sure proc reach the final state
await async_controller.kill_process(proc.pid)

@pytest.mark.usefixtures('custom_event_loop_policy')
def test_broadcast(self, _coordinator):
messages = []

Expand All @@ -123,6 +128,7 @@ def on_broadcast_receive(**msg):

class TestRemoteProcessThreadController:
@pytest.mark.asyncio
@pytest.mark.usefixtures('custom_event_loop_policy')
async def test_pause(self, _coordinator, sync_controller):
proc = utils.WaitForSignalProcess(coordinator=_coordinator)

Expand All @@ -137,6 +143,7 @@ async def test_pause(self, _coordinator, sync_controller):
assert proc.paused

@pytest.mark.asyncio
@pytest.mark.usefixtures('custom_event_loop_policy')
async def test_pause_all(self, _coordinator, sync_controller):
"""Test pausing all processes on a communicator"""
procs = []
Expand All @@ -148,6 +155,7 @@ async def test_pause_all(self, _coordinator, sync_controller):
await utils.wait_util(lambda: all([proc.paused for proc in procs]))

@pytest.mark.asyncio
@pytest.mark.usefixtures('custom_event_loop_policy')
async def test_play_all(self, _coordinator, sync_controller):
"""Test pausing all processes on a communicator"""
procs = []
Expand All @@ -162,6 +170,7 @@ async def test_play_all(self, _coordinator, sync_controller):
await utils.wait_util(lambda: all([not proc.paused for proc in procs]))

@pytest.mark.asyncio
@pytest.mark.usefixtures('custom_event_loop_policy')
async def test_play(self, _coordinator, sync_controller):
proc = utils.WaitForSignalProcess(coordinator=_coordinator)
assert proc.pause()
Expand All @@ -176,6 +185,7 @@ async def test_play(self, _coordinator, sync_controller):
assert proc.state_label == plumpy.ProcessState.CREATED

@pytest.mark.asyncio
@pytest.mark.usefixtures('custom_event_loop_policy')
async def test_kill(self, _coordinator, sync_controller):
proc = utils.WaitForSignalProcess(coordinator=_coordinator)

Expand All @@ -190,6 +200,7 @@ async def test_kill(self, _coordinator, sync_controller):
assert proc.state_label == plumpy.ProcessState.KILLED

@pytest.mark.asyncio
@pytest.mark.usefixtures('custom_event_loop_policy')
async def test_kill_all(self, _coordinator, sync_controller):
"""Test pausing all processes on a communicator"""
procs = []
Expand All @@ -201,6 +212,7 @@ async def test_kill_all(self, _coordinator, sync_controller):
assert all([proc.state_label == plumpy.ProcessState.KILLED for proc in procs])

@pytest.mark.asyncio
@pytest.mark.usefixtures('custom_event_loop_policy')
async def test_status(self, _coordinator, sync_controller):
proc = utils.WaitForSignalProcess(coordinator=_coordinator)
# Run the process in the background
Expand Down
3 changes: 3 additions & 0 deletions tests/test_event_helper.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# -*- coding: utf-8 -*-
import pytest

from plumpy.event_helper import EventHelper
from plumpy.persistence import Savable, load
from tests.utils import DummyProcess, ProcessListenerTester


@pytest.mark.usefixtures('custom_event_loop_policy')
def test_event_helper_savable():
eh = EventHelper(ProcessListenerTester)

Expand Down
3 changes: 3 additions & 0 deletions tests/test_message.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import asyncio
import pytest

import plumpy
Expand Down Expand Up @@ -26,6 +27,7 @@ def identify_object(self, obj):


@pytest.mark.asyncio
@pytest.mark.usefixtures('custom_event_loop_policy')
async def test_continue():
persister = plumpy.InMemoryPersister()
load_context = plumpy.LoadSaveContext()
Expand All @@ -42,6 +44,7 @@ async def test_continue():


@pytest.mark.asyncio
@pytest.mark.usefixtures('custom_event_loop_policy')
async def test_loader_is_used():
"""Make sure that the provided class loader is used by the process launcher"""
loader = CustomObjectLoader()
Expand Down
5 changes: 5 additions & 0 deletions tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def test_process_is_savable():


@pytest.mark.asyncio
@pytest.mark.usefixtures('custom_event_loop_policy')
async def test_process_scope():
class ProcessTaskInterleave(plumpy.Process):
async def task(self, steps: list):
Expand All @@ -70,6 +71,7 @@ async def task(self, steps: list):


class TestProcess:
@pytest.mark.usefixtures('custom_event_loop_policy')
def test_spec(self):
"""
Check that the references to specs are doing the right thing...
Expand Down Expand Up @@ -392,6 +394,7 @@ async def async_test():
loop.create_task(proc.step_until_terminated())
loop.run_until_complete(async_test())

@pytest.mark.usefixtures('custom_event_loop_policy')
def test_pause_play_status_messaging(self):
"""
Test the setting of a processes' status through pause and play works correctly.
Expand Down Expand Up @@ -619,6 +622,7 @@ def run(self):

assert len(expect_true) == n_run * 3

@pytest.mark.usefixtures('custom_event_loop_policy')
def test_process_nested(self):
"""
Run multiple and nested processes to make sure the process stack is always correct
Expand All @@ -634,6 +638,7 @@ def run(self):

ParentProcess().execute()

@pytest.mark.usefixtures('custom_event_loop_policy')
def test_call_soon(self):
class CallSoon(plumpy.Process):
def run(self):
Expand Down

0 comments on commit 8465edb

Please sign in to comment.