Skip to content

Commit

Permalink
Adapt after merge the message protocol PR
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Jan 8, 2025
1 parent b9dd887 commit 4a462a0
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 12 deletions.
10 changes: 6 additions & 4 deletions src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,17 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult':
result = await asyncio.wrap_future(future)
return result

async def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult':
async def kill_process(
self, pid: 'PID_TYPE', msg_text: Optional[str] = None, force_kill: bool = False
) -> 'ProcessResult':
"""
Kill the process
:param pid: the pid of the process to kill
:param msg: optional kill message
:return: True if killed, False otherwise
"""
msg = MessageBuilder.kill(text=msg_text)
msg = MessageBuilder.kill(text=msg_text, force_kill=force_kill)

# Wait for the communication to go through
kill_future = self._communicator.rpc_send(pid, msg)
Expand Down Expand Up @@ -401,15 +403,15 @@ def play_all(self) -> None:
"""
self._communicator.broadcast_send(None, subject=Intent.PLAY)

def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future:
def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None, force_kill: bool = False) -> kiwipy.Future:
"""
Kill the process
:param pid: the pid of the process to kill
:param msg: optional kill message
:return: a response future from the process to be killed
"""
msg = MessageBuilder.kill(text=msg_text)
msg = MessageBuilder.kill(text=msg_text, force_kill=force_kill)
return self._communicator.rpc_send(pid, msg)

def kill_all(self, msg_text: Optional[str]) -> None:
Expand Down
6 changes: 5 additions & 1 deletion src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ def __init__(self, msg_text: str | None):


class ForceKillInterruption(Interruption):
pass
def __init__(self, msg_text: str | None):
super().__init__()
msg = MessageBuilder.kill(text=msg_text)

self.msg: MessageType = msg


class PauseInterruption(Interruption):
Expand Down
34 changes: 27 additions & 7 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,11 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: MessageType) -> Any:
if intent == process_comms.Intent.PAUSE:
return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
if intent == process_comms.Intent.KILL:
return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
return self._schedule_rpc(
self.kill,
msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None),
force_kill=msg.get(process_comms.FORCE_KILL_KEY),
)
if intent == process_comms.Intent.STATUS:
status_info: Dict[str, Any] = {}
self.get_status_info(status_info)
Expand Down Expand Up @@ -998,7 +1002,11 @@ def broadcast_receive(
if subject == process_comms.Intent.PAUSE:
return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
if subject == process_comms.Intent.KILL:
return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
return self._schedule_rpc(
self.kill,
msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None),
force_kill=msg.get(process_comms.FORCE_KILL_KEY),
)
return None

def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future:
Expand Down Expand Up @@ -1222,12 +1230,12 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
)
self.transition_to(new_state)

def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]:
def kill(self, msg_text: Optional[str] = None, force_kill: bool = False) -> Union[bool, asyncio.Future]:
"""
Kill the process
:param msg: An optional kill message
:param force_kill: An optional whether force kill the process
"""
force_kill = isinstance(msg, str) and '-F' in msg

if self.state == process_states.ProcessState.KILLED:
# Already killed
Expand All @@ -1243,20 +1251,32 @@ def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]:

if force_kill:
# Skip interrupting the state and go straight to killed
interrupt_exception = process_states.ForceKillInterruption(msg)
interrupt_exception = process_states.ForceKillInterruption(msg_text)
# XXX: this line was not in ali's PR but to make the change align with _stepping,
# it seems it is needed to set the _interrupt_action to be used line after.
# Requires more check to test with aiida-core's PR.
#
# self._set_interrupt_action_from_exception(interrupt_exception)
#
self._killing = self._interrupt_action
self._state.interrupt(interrupt_exception)

elif self._stepping:
msg = MessageBuilder.kill(msg_text, force_kill=True)
new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg)
self.transition_to(new_state)
return True

if self._stepping:
# Ask the step function to pause by setting this flag and giving the
# caller back a future
interrupt_exception = process_states.KillInterruption(msg_text)
self._set_interrupt_action_from_exception(interrupt_exception)
self._killing = self._interrupt_action
self._state.interrupt(interrupt_exception)

return cast(futures.CancellableAction, self._interrupt_action)

msg = MessageBuilder.kill(msg_text)
msg = MessageBuilder.kill(msg_text, force_kill=False)
new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg)
self.transition_to(new_state)
return True
Expand Down

0 comments on commit 4a462a0

Please sign in to comment.