Skip to content

Commit

Permalink
Merge branch 'master' into force-kill
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz authored Jan 8, 2025
2 parents f1f8095 + ecef9b9 commit b9dd887
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 59 deletions.
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@ classifiers = [
keywords = ['workflow', 'multithreaded', 'rabbitmq']
requires-python = '>=3.8'
dependencies = [
'kiwipy[rmq]~=0.8.3',
'kiwipy[rmq]~=0.8.5',
'nest_asyncio~=1.5,>=1.5.1',
'pyyaml~=6.0',
# XXX: workaround for https://github.com/mosquito/aio-pika/issues/649
'typing-extensions~=4.12',
]

[project.urls]
Expand Down
38 changes: 17 additions & 21 deletions src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
ProcessStatus = Any

INTENT_KEY = 'intent'
MESSAGE_KEY = 'message'
MESSAGE_TEXT_KEY = 'message'
FORCE_KILL_KEY = 'force_kill'


Expand All @@ -52,23 +52,23 @@ def play(cls, text: str | None = None) -> MessageType:
"""The play message send over communicator."""
return {
INTENT_KEY: Intent.PLAY,
MESSAGE_KEY: text,
MESSAGE_TEXT_KEY: text,
}

@classmethod
def pause(cls, text: str | None = None) -> MessageType:
"""The pause message send over communicator."""
return {
INTENT_KEY: Intent.PAUSE,
MESSAGE_KEY: text,
MESSAGE_TEXT_KEY: text,
}

@classmethod
def kill(cls, text: str | None = None, force_kill: bool = False) -> MessageType:
"""The kill message send over communicator."""
return {
INTENT_KEY: Intent.KILL,
MESSAGE_KEY: text,
MESSAGE_TEXT_KEY: text,
FORCE_KILL_KEY: force_kill,
}

Expand All @@ -77,7 +77,7 @@ def status(cls, text: str | None = None) -> MessageType:
"""The status message send over communicator."""
return {
INTENT_KEY: Intent.STATUS,
MESSAGE_KEY: text,
MESSAGE_TEXT_KEY: text,
}


Expand Down Expand Up @@ -200,15 +200,15 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus':
result = await asyncio.wrap_future(future)
return result

async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult':
async def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult':
"""
Pause the process
:param pid: the pid of the process to pause
:param msg: optional pause message
:return: True if paused, False otherwise
"""
msg = MessageBuilder.pause(text=msg)
msg = MessageBuilder.pause(text=msg_text)

pause_future = self._communicator.rpc_send(pid, msg)
# rpc_send return a thread future from communicator
Expand All @@ -229,16 +229,15 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult':
result = await asyncio.wrap_future(future)
return result

async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult':
async def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult':
"""
Kill the process
:param pid: the pid of the process to kill
:param msg: optional kill message
:return: True if killed, False otherwise
"""
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(text=msg_text)

# Wait for the communication to go through
kill_future = self._communicator.rpc_send(pid, msg)
Expand Down Expand Up @@ -364,7 +363,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future:
"""
return self._communicator.rpc_send(pid, MessageBuilder.status())

def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future:
def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future:
"""
Pause the process
Expand All @@ -373,16 +372,17 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu
:return: a response future from the process to be paused
"""
msg = MessageBuilder.pause(text=msg)
msg = MessageBuilder.pause(text=msg_text)

return self._communicator.rpc_send(pid, msg)

def pause_all(self, msg: Any) -> None:
def pause_all(self, msg_text: Optional[str]) -> None:
"""
Pause all processes that are subscribed to the same communicator
:param msg: an optional pause message
"""
msg = MessageBuilder.pause(text=msg_text)
self._communicator.broadcast_send(msg, subject=Intent.PAUSE)

def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future:
Expand All @@ -401,28 +401,24 @@ def play_all(self) -> None:
"""
self._communicator.broadcast_send(None, subject=Intent.PLAY)

def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future:
def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future:
"""
Kill the process
:param pid: the pid of the process to kill
:param msg: optional kill message
:return: a response future from the process to be killed
"""
if msg is None:
msg = MessageBuilder.kill()

msg = MessageBuilder.kill(text=msg_text)
return self._communicator.rpc_send(pid, msg)

def kill_all(self, msg: Optional[MessageType]) -> None:
def kill_all(self, msg_text: Optional[str]) -> None:
"""
Kill all processes that are subscribed to the same communicator
:param msg: an optional pause message
"""
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(msg_text)

self._communicator.broadcast_send(msg, subject=Intent.KILL)

Expand Down
11 changes: 7 additions & 4 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@ class Interruption(Exception): # noqa: N818


class KillInterruption(Interruption):
def __init__(self, msg: MessageType | None):
def __init__(self, msg_text: str | None):
super().__init__()
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(text=msg_text)

self.msg: MessageType = msg

Expand All @@ -66,7 +65,11 @@ class ForceKillInterruption(Interruption):


class PauseInterruption(Interruption):
pass
def __init__(self, msg_text: str | None):
super().__init__()
msg = MessageBuilder.pause(text=msg_text)

self.msg: MessageType = msg


# region Commands
Expand Down
81 changes: 57 additions & 24 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event
from .base.utils import call_with_super_check, super_check
from .event_helper import EventHelper
from .process_comms import MESSAGE_KEY, MessageBuilder, MessageType
from .process_comms import MESSAGE_TEXT_KEY, MessageBuilder, MessageType
from .process_listener import ProcessListener
from .process_spec import ProcessSpec
from .utils import PID_TYPE, SAVED_STATE_TYPE, protected
Expand Down Expand Up @@ -344,8 +344,7 @@ def init(self) -> None:

def try_killing(future: futures.Future) -> None:
if future.cancelled():
msg = MessageBuilder.kill(text='Killed by future being cancelled')
if not self.kill(msg):
if not self.kill('Killed by future being cancelled'):
self.logger.warning(
'Process<%s>: Failed to kill process on future cancel',
self.pid,
Expand Down Expand Up @@ -903,7 +902,7 @@ def on_kill(self, msg: Optional[MessageType]) -> None:
if msg is None:
msg_txt = ''
else:
msg_txt = msg[MESSAGE_KEY] or ''
msg_txt = msg[MESSAGE_TEXT_KEY] or ''

self.set_status(msg_txt)
self.future().set_exception(exceptions.KilledError(msg_txt))
Expand Down Expand Up @@ -944,7 +943,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non

# region Communication

def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> Any:
def message_receive(self, _comm: kiwipy.Communicator, msg: MessageType) -> Any:
"""
Coroutine called when the process receives a message from the communicator
Expand All @@ -964,9 +963,9 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An
if intent == process_comms.Intent.PLAY:
return self._schedule_rpc(self.play)
if intent == process_comms.Intent.PAUSE:
return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None))
return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
if intent == process_comms.Intent.KILL:
return self._schedule_rpc(self.kill, msg=msg)
return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
if intent == process_comms.Intent.STATUS:
status_info: Dict[str, Any] = {}
self.get_status_info(status_info)
Expand All @@ -976,7 +975,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An
raise RuntimeError('Unknown intent')

def broadcast_receive(
self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any
self, _comm: kiwipy.Communicator, msg: MessageType, sender: Any, subject: Any, correlation_id: Any
) -> Optional[kiwipy.Future]:
"""
Coroutine called when the process receives a message from the communicator
Expand All @@ -990,16 +989,16 @@ def broadcast_receive(
self.pid,
subject,
_comm,
body,
msg,
)

# If we get a message we recognise then action it, otherwise ignore
if subject == process_comms.Intent.PLAY:
return self._schedule_rpc(self.play)
if subject == process_comms.Intent.PAUSE:
return self._schedule_rpc(self.pause, msg=body)
return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
if subject == process_comms.Intent.KILL:
return self._schedule_rpc(self.kill, msg=body)
return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
return None

def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future:
Expand All @@ -1021,11 +1020,37 @@ def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any)

async def run_callback() -> None:
with kiwipy.capture_exceptions(kiwi_future):
result = callback(*args, **kwargs)
while asyncio.isfuture(result):
result = await result
try:
result = callback(*args, **kwargs)
except Exception as exc:
import inspect
import traceback

# Get traceback as a string
tb_str = ''.join(traceback.format_exception(type(exc), exc, exc.__traceback__))

# Attempt to get file and line number where the callback is defined
# Note: This might fail for certain built-in or dynamically generated functions.
# If it fails, just skip that part.
try:
source_file = inspect.getfile(callback)
# getsourcelines returns a tuple (list_of_source_lines, starting_line_number)
_, start_line = inspect.getsourcelines(callback)
callback_location = f'{source_file}:{start_line}'
except Exception:
callback_location = '<unknown location>'

# Include the callback name, file/line info, and the full traceback in the message
raise RuntimeError(
f"Error invoking callback '{callback.__name__}' at {callback_location}.\n"
f'Exception: {type(exc).__name__}: {exc}\n\n'
f'Full Traceback:\n{tb_str}'
) from exc
else:
while asyncio.isfuture(result):
result = await result

kiwi_future.set_result(result)
kiwi_future.set_result(result)

# Schedule the task and give back a kiwi future
asyncio.run_coroutine_threadsafe(run_callback(), self.loop)
Expand Down Expand Up @@ -1071,7 +1096,7 @@ def transition_failed(
)
self.transition_to(new_state)

def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]:
def pause(self, msg_text: Optional[str] = None) -> Union[bool, futures.CancellableAction]:
"""Pause the process.
:param msg: an optional message to set as the status. The current status will be saved in the private
Expand All @@ -1095,22 +1120,29 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable
if self._stepping:
# Ask the step function to pause by setting this flag and giving the
# caller back a future
interrupt_exception = process_states.PauseInterruption(msg)
interrupt_exception = process_states.PauseInterruption(msg_text)
self._set_interrupt_action_from_exception(interrupt_exception)
self._pausing = self._interrupt_action
# Try to interrupt the state
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

return self._do_pause(msg)
msg = MessageBuilder.pause(msg_text)
return self._do_pause(state_msg=msg)

def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool:
def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[process_states.State] = None) -> bool:
"""Carry out the pause procedure, optionally transitioning to the next state first"""
try:
if next_state is not None:
self.transition_to(next_state)
call_with_super_check(self.on_pausing, state_msg)
call_with_super_check(self.on_paused, state_msg)

if state_msg is None:
msg_text = ''
else:
msg_text = state_msg[MESSAGE_TEXT_KEY]

call_with_super_check(self.on_pausing, msg_text)
call_with_super_check(self.on_paused, msg_text)
finally:
self._pausing = None

Expand All @@ -1125,7 +1157,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu
"""
if isinstance(exception, process_states.PauseInterruption):
do_pause = functools.partial(self._do_pause, str(exception))
do_pause = functools.partial(self._do_pause, exception.msg)
return futures.CancellableAction(do_pause, cookie=exception)

if isinstance(exception, (process_states.KillInterruption, process_states.ForceKillInterruption)):
Expand Down Expand Up @@ -1190,7 +1222,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
)
self.transition_to(new_state)

def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]:
"""
Kill the process
:param msg: An optional kill message
Expand Down Expand Up @@ -1218,12 +1250,13 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
elif self._stepping:
# Ask the step function to pause by setting this flag and giving the
# caller back a future
interrupt_exception = process_states.KillInterruption(msg) # type: ignore
interrupt_exception = process_states.KillInterruption(msg_text)
self._set_interrupt_action_from_exception(interrupt_exception)
self._killing = self._interrupt_action
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

msg = MessageBuilder.kill(msg_text)
new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg)
self.transition_to(new_state)
return True
Expand Down
4 changes: 1 addition & 3 deletions tests/rmq/test_process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ async def test_kill_all(self, thread_communicator, sync_controller):
for _ in range(10):
procs.append(utils.WaitForSignalProcess(communicator=thread_communicator))

msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down')

sync_controller.kill_all(msg)
sync_controller.kill_all(msg_text='bang bang, I shot you down')
await utils.wait_util(lambda: all([proc.killed() for proc in procs]))
assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs])

Expand Down
Loading

0 comments on commit b9dd887

Please sign in to comment.