diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index be27e0cd..681858f0 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -325,6 +325,8 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None: assert not self._transitioning, 'Cannot call transition_to when already transitioning state' if new_state is None: + # early return if the new state is `None` + # it can happened when transit from terminal state return None initial_state_label = self._state.LABEL if self._state is not None else None @@ -411,8 +413,10 @@ def _enter_next_state(self, next_state: State) -> None: self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State: - if state_cls.LABEL not in self.get_states_map(): - raise ValueError(f'{state_cls.LABEL} is not a valid state') + def _create_state_instance(self, state_cls: Hashable, **kwargs: Any) -> State: + if state_cls not in self.get_states_map(): + raise ValueError(f'{state_cls} is not a valid state') - return state_cls(self, **kwargs) + cls = self.get_states_map()[state_cls] + + return cls(self, **kwargs) diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index cd6e7238..9558b2db 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -48,6 +48,8 @@ class Intent: class PlayMessage: + """The play message send over communicator.""" + @classmethod def build(cls, message: str | None = None) -> MessageType: return { @@ -57,6 +59,8 @@ def build(cls, message: str | None = None) -> MessageType: class PauseMessage: + """The pause message send over communicator.""" + @classmethod def build(cls, message: str | None = None) -> MessageType: return { @@ -66,16 +70,20 @@ def build(cls, message: str | None = None) -> MessageType: class KillMessage: + """The kill message send over communicator.""" + @classmethod - def build(cls, message: str | None = None, force: bool = False) -> MessageType: + def build(cls, message: str | None = None, force_kill: bool = False) -> MessageType: return { INTENT_KEY: Intent.KILL, MESSAGE_KEY: message, - FORCE_KILL_KEY: force, + FORCE_KILL_KEY: force_kill, } class StatusMessage: + """The status message send over communicator.""" + @classmethod def build(cls, message: str | None = None) -> MessageType: return { diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index b762b672..f846c052 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1066,8 +1066,9 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] - new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace) + new_state = self._create_state_instance( + process_states.ProcessState.EXCEPTED, exception=exception, trace_back=trace + ) self.transition_to(new_state) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: @@ -1131,8 +1132,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: process_states.State) -> Any: try: - state_class = self.get_states_map()[process_states.ProcessState.KILLED] - new_state = self._create_state_instance(state_class, msg=exception.msg) + new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=exception.msg) self.transition_to(new_state) return True finally: @@ -1185,8 +1185,9 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac :param exception: The exception that caused the failure :param trace_back: Optional exception traceback """ - state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] - new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace_back) + new_state = self._create_state_instance( + process_states.ProcessState.EXCEPTED, exception=exception, trace_back=trace_back + ) self.transition_to(new_state) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: @@ -1215,8 +1216,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - state_class = self.get_states_map()[process_states.ProcessState.KILLED] - new_state = self._create_state_instance(state_class, msg=msg) + new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg) self.transition_to(new_state) return True