diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 2f354987..e1d6d969 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -18,7 +18,7 @@ py:class kiwipy.communications.Communicator # unavailable forward references py:class plumpy.process_states.Command -py:class plumpy.process_states.State +py:class plumpy.state_machine.State py:class plumpy.base.state_machine.State py:class State py:class Process diff --git a/docs/source/tutorial.ipynb b/docs/source/tutorial.ipynb index c1fdb3b2..af1ed795 100644 --- a/docs/source/tutorial.ipynb +++ b/docs/source/tutorial.ipynb @@ -281,7 +281,7 @@ " def continue_fn(self):\n", " print('continuing')\n", " # message is stored in the process status\n", - " return plumpy.Kill('I was killed')\n", + " return plumpy.Kill(plumpy.KillMessage.build('I was killed'))\n", "\n", "\n", "process = ContinueProcess()\n", @@ -1118,7 +1118,7 @@ "\n", "process = SimpleProcess(communicator=communicator)\n", "\n", - "pprint(communicator.rpc_send(str(process.pid), plumpy.STATUS_MSG).result())" + "pprint(communicator.rpc_send(str(process.pid), plumpy.StatusMessage.build()).result())" ] }, { diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index d99d0705..fc926008 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """The state machine for processes""" +from __future__ import annotations + import enum import functools import inspect @@ -8,7 +10,21 @@ import os import sys from types import TracebackType -from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Sequence, Set, Type, Union, cast +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Hashable, + Iterable, + List, + Optional, + Protocol, + Sequence, + Type, + Union, + runtime_checkable, +) from plumpy.futures import Future @@ -18,7 +34,6 @@ _LOGGER = logging.getLogger(__name__) -LABEL_TYPE = Union[None, enum.Enum, str] EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None] @@ -31,7 +46,7 @@ class StateEntryFailed(Exception): # noqa: N818 Failed to enter a state, can provide the next state to go to via this exception """ - def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None: + def __init__(self, state: State, *args: Any, **kwargs: Any) -> None: super().__init__('failed to enter state') self.state = state self.args = args @@ -74,12 +89,12 @@ def event( if from_states != '*': if inspect.isclass(from_states): from_states = (from_states,) - if not all(issubclass(state, State) for state in from_states): # type: ignore + if not all(isinstance(state, State) for state in from_states): # type: ignore raise TypeError(f'from_states: {from_states}') if to_states != '*': if inspect.isclass(to_states): to_states = (to_states,) - if not all(issubclass(state, State) for state in to_states): # type: ignore + if not all(isinstance(state, State) for state in to_states): # type: ignore raise TypeError(f'to_states: {to_states}') def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: @@ -113,57 +128,40 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: return wrapper -class State: - LABEL: LABEL_TYPE = None - # A set containing the labels of states that can be entered - # from this one - ALLOWED: Set[LABEL_TYPE] = set() +@runtime_checkable +class State(Protocol): + LABEL: ClassVar[Any] + ALLOWED: ClassVar[set[Any]] + is_terminal: ClassVar[bool] - @classmethod - def is_terminal(cls) -> bool: - return not cls.ALLOWED + def __init__(self, *args: Any, **kwargs: Any): ... - def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): - """ - :param state_machine: The process this state belongs to - """ - self.state_machine = state_machine - self.in_state: bool = False + def enter(self) -> None: ... - def __str__(self) -> str: - return str(self.LABEL) + def exit(self) -> None: ... - @property - def label(self) -> LABEL_TYPE: - """Convenience property to get the state label""" - return self.LABEL - @super_check - def enter(self) -> None: - """Entering the state""" +@runtime_checkable +class Interruptable(Protocol): + def interrupt(self, reason: Exception) -> None: ... + - def execute(self) -> Optional['State']: +@runtime_checkable +class Proceedable(Protocol): + def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. """ + ... - @super_check - def exit(self) -> None: - """Exiting the state""" - if self.is_terminal(): - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': - return self.state_machine.create_state(state_label, *args, **kwargs) +def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State: + if state_label not in st.get_states_map(): + raise ValueError(f'{state_label} is not a valid state') - def do_enter(self) -> None: - call_with_super_check(self.enter) - self.in_state = True - - def do_exit(self) -> None: - call_with_super_check(self.exit) - self.in_state = False + state_cls = st.get_states_map()[state_label] + return state_cls(*args, **kwargs) class StateEventHook(enum.Enum): @@ -187,7 +185,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine': :param kwargs: Any keyword arguments to be passed to the constructor :return: An instance of the state machine """ - inst = super().__call__(*args, **kwargs) + inst: StateMachine = super().__call__(*args, **kwargs) inst.transition_to(inst.create_initial_state()) call_with_super_check(inst.init) return inst @@ -214,13 +212,13 @@ def get_states(cls) -> Sequence[Type[State]]: raise RuntimeError('States not defined') @classmethod - def initial_state_label(cls) -> LABEL_TYPE: + def initial_state_label(cls) -> Any: cls.__ensure_built() assert cls.STATES is not None return cls.STATES[0].LABEL @classmethod - def get_state_class(cls, label: LABEL_TYPE) -> Type[State]: + def get_state_class(cls, label: Any) -> Type[State]: cls.__ensure_built() assert cls._STATES_MAP is not None return cls._STATES_MAP[label] @@ -240,7 +238,7 @@ def __ensure_built(cls) -> None: # Build the states map cls._STATES_MAP = {} for state_cls in cls.STATES: - assert issubclass(state_cls, State) + assert isinstance(state_cls, State) label = state_cls.LABEL assert label not in cls._STATES_MAP, f"Duplicate label '{label}'" cls._STATES_MAP[label] = state_cls @@ -264,11 +262,11 @@ def init(self) -> None: def __str__(self) -> str: return f'<{self.__class__.__name__}> ({self.state})' - def create_initial_state(self) -> State: - return self.get_state_class(self.initial_state_label())(self) + def create_initial_state(self, *args: Any, **kwargs: Any) -> State: + return self.get_state_class(self.initial_state_label())(self, *args, **kwargs) @property - def state(self) -> Optional[LABEL_TYPE]: + def state(self) -> Any: if self._state is None: return None return self._state.LABEL @@ -300,16 +298,24 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: def on_terminated(self) -> None: """Called when a terminal state is entered""" - def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None: + def transition_to(self, new_state: State | None, **kwargs: Any) -> None: + """Transite to the new state. + + The new target state will be create lazily when the state is not yet instantiated, + which will happened for states not in the expect path such as pause and kill. + The arguments are passed to the state class to create state instance. + (process arg does not need to pass since it will always call with 'self' as process) + """ + print(f'try: {self._state} -> {new_state}') assert not self._transitioning, 'Cannot call transition_to when already transitioning state' + if new_state is None: + return None + initial_state_label = self._state.LABEL if self._state is not None else None label = None try: self._transitioning = True - - # Make sure we have a state instance - new_state = self._create_state_instance(new_state, *args, **kwargs) label = new_state.LABEL # If the previous transition failed, do not try to exit it but go straight to next state @@ -319,13 +325,12 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A try: self._enter_next_state(new_state) except StateEntryFailed as exception: - # Make sure we have a state instance - new_state = self._create_state_instance(exception.state, *exception.args, **exception.kwargs) + new_state = exception.state label = new_state.LABEL self._exit_current_state(new_state) self._enter_next_state(new_state) - if self._state is not None and self._state.is_terminal(): + if self._state is not None and self._state.is_terminal: call_with_super_check(self.on_terminated) except Exception: self._transitioning = False @@ -338,7 +343,11 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A self._transitioning = False def transition_failed( - self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType + self, + initial_state: Hashable, + final_state: Hashable, + exception: Exception, + trace: TracebackType, ) -> None: """Called when a state transitions fails. @@ -354,49 +363,25 @@ def get_debug(self) -> bool: def set_debug(self, enabled: bool) -> None: self._debug: bool = enabled - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: - try: - return self.get_states_map()[state_label](self, *args, **kwargs) - except KeyError: - raise ValueError(f'{state_label} is not a valid state') - def _exit_current_state(self, next_state: State) -> None: """Exit the given state""" # If we're just being constructed we may not have a state yet to exit, # in which case check the new state is the initial state if self._state is None: - if next_state.label != self.initial_state_label(): + if next_state.LABEL != self.initial_state_label(): raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state") return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: - raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') + raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.LABEL}') self._fire_state_event(StateEventHook.EXITING_STATE, next_state) - self._state.do_exit() + self._state.exit() def _enter_next_state(self, next_state: State) -> None: last_state = self._state self._fire_state_event(StateEventHook.ENTERING_STATE, next_state) # Enter the new state - next_state.do_enter() + next_state.enter() self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - - def _create_state_instance(self, state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> State: - if isinstance(state, State): - # It's already a state instance - return state - - # OK, have to create it - state_cls = self._ensure_state_class(state) - return state_cls(self, *args, **kwargs) - - def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State]: - if inspect.isclass(state) and issubclass(state, State): - return state - - try: - return self.get_states_map()[cast(Hashable, state)] - except KeyError: - raise ValueError(f'{state} is not a valid state') diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 293c680b..cd6e7238 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- """Module for process level communication functions and classes""" +from __future__ import annotations + import asyncio -import copy import logging from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast @@ -12,13 +13,13 @@ from .utils import PID_TYPE __all__ = [ - 'KILL_MSG', - 'PAUSE_MSG', - 'PLAY_MSG', - 'STATUS_MSG', + 'KillMessage', + 'PauseMessage', + 'PlayMessage', 'ProcessLauncher', 'RemoteProcessController', 'RemoteProcessThreadController', + 'StatusMessage', 'create_continue_body', 'create_launch_body', ] @@ -31,6 +32,7 @@ INTENT_KEY = 'intent' MESSAGE_KEY = 'message' +FORCE_KILL_KEY = 'force_kill' class Intent: @@ -42,10 +44,45 @@ class Intent: STATUS: str = 'status' -PAUSE_MSG = {INTENT_KEY: Intent.PAUSE} -PLAY_MSG = {INTENT_KEY: Intent.PLAY} -KILL_MSG = {INTENT_KEY: Intent.KILL} -STATUS_MSG = {INTENT_KEY: Intent.STATUS} +MessageType = Dict[str, Any] + + +class PlayMessage: + @classmethod + def build(cls, message: str | None = None) -> MessageType: + return { + INTENT_KEY: Intent.PLAY, + MESSAGE_KEY: message, + } + + +class PauseMessage: + @classmethod + def build(cls, message: str | None = None) -> MessageType: + return { + INTENT_KEY: Intent.PAUSE, + MESSAGE_KEY: message, + } + + +class KillMessage: + @classmethod + def build(cls, message: str | None = None, force: bool = False) -> MessageType: + return { + INTENT_KEY: Intent.KILL, + MESSAGE_KEY: message, + FORCE_KILL_KEY: force, + } + + +class StatusMessage: + @classmethod + def build(cls, message: str | None = None) -> MessageType: + return { + INTENT_KEY: Intent.STATUS, + MESSAGE_KEY: message, + } + TASK_KEY = 'task' TASK_ARGS = 'args' @@ -162,7 +199,7 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus': :param pid: the process id :return: the status response from the process """ - future = self._communicator.rpc_send(pid, STATUS_MSG) + future = self._communicator.rpc_send(pid, StatusMessage.build()) result = await asyncio.wrap_future(future) return result @@ -174,11 +211,9 @@ async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pr :param msg: optional pause message :return: True if paused, False otherwise """ - message = copy.copy(PAUSE_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + msg = PauseMessage.build(message=msg) - pause_future = self._communicator.rpc_send(pid, message) + pause_future = self._communicator.rpc_send(pid, msg) # rpc_send return a thread future from communicator future = await asyncio.wrap_future(pause_future) # future is just returned from rpc call which return a kiwipy future @@ -192,12 +227,12 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult': :param pid: the pid of the process to play :return: True if played, False otherwise """ - play_future = self._communicator.rpc_send(pid, PLAY_MSG) + play_future = self._communicator.rpc_send(pid, PlayMessage.build()) future = await asyncio.wrap_future(play_future) result = await asyncio.wrap_future(future) return result - async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult': + async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult': """ Kill the process @@ -205,12 +240,11 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pro :param msg: optional kill message :return: True if killed, False otherwise """ - message = copy.copy(KILL_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + if msg is None: + msg = KillMessage.build() # Wait for the communication to go through - kill_future = self._communicator.rpc_send(pid, message) + kill_future = self._communicator.rpc_send(pid, msg) future = await asyncio.wrap_future(kill_future) # Now wait for the kill to be enacted result = await asyncio.wrap_future(future) @@ -331,7 +365,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future: :param pid: the process id :return: the status response from the process """ - return self._communicator.rpc_send(pid, STATUS_MSG) + return self._communicator.rpc_send(pid, StatusMessage.build()) def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: """ @@ -342,11 +376,9 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu :return: a response future from the process to be paused """ - message = copy.copy(PAUSE_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + msg = PauseMessage.build(message=msg) - return self._communicator.rpc_send(pid, message) + return self._communicator.rpc_send(pid, msg) def pause_all(self, msg: Any) -> None: """ @@ -364,7 +396,7 @@ def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future: :return: a response future from the process to be played """ - return self._communicator.rpc_send(pid, PLAY_MSG) + return self._communicator.rpc_send(pid, PlayMessage.build()) def play_all(self) -> None: """ @@ -372,7 +404,7 @@ def play_all(self) -> None: """ self._communicator.broadcast_send(None, subject=Intent.PLAY) - def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: + def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future: """ Kill the process @@ -381,18 +413,20 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fut :return: a response future from the process to be killed """ - message = copy.copy(KILL_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + if msg is None: + msg = KillMessage.build() - return self._communicator.rpc_send(pid, message) + return self._communicator.rpc_send(pid, msg) - def kill_all(self, msg: Optional[Any]) -> None: + def kill_all(self, msg: Optional[MessageType]) -> None: """ Kill all processes that are subscribed to the same communicator :param msg: an optional pause message """ + if msg is None: + msg = KillMessage.build() + self._communicator.broadcast_send(msg, subject=Intent.KILL) def continue_process( diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 7ae6e9bd..5f3e8237 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,13 +1,30 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import sys import traceback from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Optional, + Protocol, + Tuple, + Type, + Union, + cast, + final, + runtime_checkable, +) import yaml from yaml.loader import Loader +from plumpy.process_comms import KillMessage, MessageType + try: import tblib @@ -16,9 +33,9 @@ _HAS_TBLIB = False from . import exceptions, futures, persistence, utils -from .base import state_machine +from .base import state_machine as st from .lang import NULL -from .persistence import auto_persist +from .persistence import LoadSaveContext, auto_persist from .utils import SAVED_STATE_TYPE __all__ = [ @@ -48,7 +65,12 @@ class Interruption(Exception): # noqa: N818 class KillInterruption(Interruption): - pass + def __init__(self, msg: MessageType | None): + super().__init__() + if msg is None: + msg = KillMessage.build() + + self.msg: MessageType = msg class PauseInterruption(Interruption): @@ -64,7 +86,7 @@ class Command(persistence.Savable): @auto_persist('msg') class Kill(Command): - def __init__(self, msg: Optional[Any] = None): + def __init__(self, msg: Optional[MessageType] = None): super().__init__() self.msg = msg @@ -76,7 +98,10 @@ class Pause(Command): @auto_persist('msg', 'data') class Wait(Command): def __init__( - self, continue_fn: Optional[Callable[..., Any]] = None, msg: Optional[Any] = None, data: Optional[Any] = None + self, + continue_fn: Optional[Callable[..., Any]] = None, + msg: Optional[Any] = None, + data: Optional[Any] = None, ): super().__init__() self.continue_fn = continue_fn @@ -108,6 +133,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process try: self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN]) except ValueError: @@ -125,41 +151,32 @@ class ProcessState(Enum): The possible states that a :class:`~plumpy.processes.Process` can be in. """ - CREATED: str = 'created' - RUNNING: str = 'running' - WAITING: str = 'waiting' - FINISHED: str = 'finished' - EXCEPTED: str = 'excepted' - KILLED: str = 'killed' + # FIXME: see LSP error of return a exception, the type is Literal[str] which is invariant, tricky + CREATED = 'created' + RUNNING = 'running' + WAITING = 'waiting' + FINISHED = 'finished' + EXCEPTED = 'excepted' + KILLED = 'killed' -@auto_persist('in_state') -class State(state_machine.State, persistence.Savable): - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process - - def interrupt(self, reason: Any) -> None: - pass +@runtime_checkable +class Savable(Protocol): + def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... +@final @auto_persist('args', 'kwargs') -class Created(State): - LABEL = ProcessState.CREATED - ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} +class Created(persistence.Savable): + LABEL: ClassVar = ProcessState.CREATED + ALLOWED: ClassVar = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} RUN_FN = 'run_fn' + is_terminal: ClassVar[bool] = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - super().__init__(process) assert run_fn is not None + self.process = process self.run_fn = run_fn self.args = args self.kwargs = kwargs @@ -170,16 +187,25 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.process = load_context.process + self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) - def execute(self) -> state_machine.State: - return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) + def execute(self) -> st.State: + return st.create_state( + self.process, ProcessState.RUNNING, process=self.process, run_fn=self.run_fn, *self.args, **self.kwargs + ) + + def enter(self) -> None: ... + + def exit(self) -> None: ... +@final @auto_persist('args', 'kwargs') -class Running(State): - LABEL = ProcessState.RUNNING - ALLOWED = { +class Running(persistence.Savable): + LABEL: ClassVar = ProcessState.RUNNING + ALLOWED: ClassVar = { ProcessState.RUNNING, ProcessState.WAITING, ProcessState.FINISHED, @@ -195,9 +221,11 @@ class Running(State): _running: bool = False _run_handle = None + is_terminal: ClassVar[bool] = False + def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - super().__init__(process) assert run_fn is not None + self.process = process self.run_fn = run_fn self.args = args self.kwargs = kwargs @@ -211,6 +239,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.process = load_context.process + self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) if self.COMMAND in saved_state: self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore @@ -218,7 +248,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> State: # type: ignore + def execute(self) -> st.State: if self._command is not None: command = self._command else: @@ -232,8 +262,10 @@ async def execute(self) -> State: # type: ignore # Let this bubble up to the caller raise except Exception: - excepted = self.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) - return cast(State, excepted) + _, exception, traceback = sys.exc_info() + # excepted = state_cls(exception=exception, traceback=traceback) + excepted = Excepted(exception=exception, traceback=traceback) + return excepted else: if not isinstance(result, Command): if isinstance(result, exceptions.UnsuccessfulResult): @@ -242,32 +274,52 @@ async def execute(self) -> State: # type: ignore # Got passed a basic return type result = Stop(result, True) - command = result + command = cast(Stop, result) next_state = self._action_command(command) return next_state - def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: + def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> st.State: if isinstance(command, Kill): - state = self.create_state(ProcessState.KILLED, command.msg) + state = st.create_state(self.process, ProcessState.KILLED, msg=command.msg) # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.create_state(ProcessState.FINISHED, command.result, command.successful) + state = st.create_state( + self.process, ProcessState.FINISHED, result=command.result, successful=command.successful + ) elif isinstance(command, Wait): - state = self.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) + state = st.create_state( + self.process, + ProcessState.WAITING, + process=self.process, + done_callback=command.continue_fn, + msg=command.msg, + data=command.data, + ) elif isinstance(command, Continue): - state = self.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) + state = st.create_state( + self.process, + ProcessState.RUNNING, + process=self.process, + run_fn=command.continue_fn, + *command.args, + **command.kwargs, + ) else: raise ValueError('Unrecognised command') - return cast(State, state) # casting from base.State to process.State + return state + + def enter(self) -> None: ... + + def exit(self) -> None: ... @auto_persist('msg', 'data') -class Waiting(State): - LABEL = ProcessState.WAITING - ALLOWED = { +class Waiting(persistence.Savable): + LABEL: ClassVar = ProcessState.WAITING + ALLOWED: ClassVar = { ProcessState.RUNNING, ProcessState.WAITING, ProcessState.KILLED, @@ -279,6 +331,8 @@ class Waiting(State): _interruption = None + is_terminal: ClassVar[bool] = False + def __str__(self) -> str: state_info = super().__str__() if self.msg is not None: @@ -292,7 +346,7 @@ def __init__( msg: Optional[str] = None, data: Optional[Any] = None, ) -> None: - super().__init__(process) + self.process = process self.done_callback = done_callback self.msg = msg self.data = data @@ -305,6 +359,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.process = load_context.process + callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: self.done_callback = getattr(self.process, callback_name) @@ -312,11 +368,11 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self.done_callback = None self._waiting_future = futures.Future() - def interrupt(self, reason: Any) -> None: + def interrupt(self, reason: Exception) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) - async def execute(self) -> State: # type: ignore + async def execute(self) -> st.State: try: result = await self._waiting_future except Interruption: @@ -327,11 +383,15 @@ async def execute(self) -> State: # type: ignore raise if result == NULL: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback) + next_state = st.create_state( + self.process, ProcessState.RUNNING, process=self.process, run_fn=self.done_callback + ) else: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) + next_state = st.create_state( + self.process, ProcessState.RUNNING, process=self.process, done_callback=self.done_callback, *result + ) - return cast(State, next_state) # casting from base.State to process.State + return next_state def resume(self, value: Any = NULL) -> None: assert self._waiting_future is not None, 'Not yet waiting' @@ -341,24 +401,39 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) + def enter(self) -> None: ... + + def exit(self) -> None: ... + + +@final +class Excepted(persistence.Savable): + """ + Excepted state, can optionally provide exception and traceback + + :param exception: The exception instance + :param traceback: An optional exception traceback + """ -class Excepted(State): - LABEL = ProcessState.EXCEPTED + LABEL: ClassVar = ProcessState.EXCEPTED + ALLOWED: ClassVar[set[str]] = set() EXC_VALUE = 'ex_value' TRACEBACK = 'traceback' + is_terminal: ClassVar = True + def __init__( - self, process: 'Process', exception: Optional[BaseException], trace_back: Optional[TracebackType] = None + self, + exception: Optional[BaseException], + traceback: Optional[TracebackType] = None, ): """ - :param process: The associated process :param exception: The exception instance - :param trace_back: An optional exception traceback + :param traceback: An optional exception traceback """ - super().__init__(process) self.exception = exception - self.traceback = trace_back + self.traceback = traceback def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] @@ -372,6 +447,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: try: @@ -381,35 +457,78 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi else: self.traceback = None - def get_exc_info(self) -> Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]: + def get_exc_info( + self, + ) -> Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]: """ Recreate the exc_info tuple and return it """ - return type(self.exception) if self.exception else None, self.exception, self.traceback + return ( + type(self.exception) if self.exception else None, + self.exception, + self.traceback, + ) + + def enter(self) -> None: ... + def exit(self) -> None: ... + +@final @auto_persist('result', 'successful') -class Finished(State): - LABEL = ProcessState.FINISHED +class Finished(persistence.Savable): + """State for process is finished. + + :param result: The result of process + :param successful: Boolean for the exit code is ``0`` the process is successful. + """ + + LABEL: ClassVar = ProcessState.FINISHED + ALLOWED: ClassVar[set[str]] = set() - def __init__(self, process: 'Process', result: Any, successful: bool) -> None: - super().__init__(process) + is_terminal: ClassVar[bool] = True + + def __init__(self, result: Any, successful: bool) -> None: self.result = result self.successful = successful + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + super().load_instance_state(saved_state, load_context) + def enter(self) -> None: ... + + def exit(self) -> None: ... + + +@final @auto_persist('msg') -class Killed(State): - LABEL = ProcessState.KILLED +class Killed(persistence.Savable): + """ + Represents a state where a process has been killed. + + This state is used to indicate that a process has been terminated and can optionally + include a message providing details about the termination. + + :param msg: An optional message explaining the reason for the process termination. + """ + + LABEL: ClassVar = ProcessState.KILLED + ALLOWED: ClassVar[set[str]] = set() - def __init__(self, process: 'Process', msg: Optional[str]): + is_terminal: ClassVar[bool] = True + + def __init__(self, msg: Optional[MessageType]): """ - :param process: The associated process :param msg: Optional kill message - """ - super().__init__(process) self.msg = msg + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + super().load_instance_state(saved_state, load_context) + + def enter(self) -> None: ... + + def exit(self) -> None: ... + # endregion diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index ba7967d3..bae08dd4 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """The main Process module""" +from __future__ import annotations + import abc import asyncio import contextlib @@ -26,6 +28,7 @@ Sequence, Tuple, Type, + TypeVar, Union, cast, ) @@ -39,15 +42,36 @@ import yaml from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed -from . import events, exceptions, futures, persistence, ports, process_comms, process_states, utils +from . import ( + events, + exceptions, + futures, + persistence, + ports, + process_comms, + process_states, + utils, +) from .base import state_machine -from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event +from .base.state_machine import ( + Interruptable, + Proceedable, + StateEntryFailed, + StateMachine, + StateMachineError, + TransitionFailed, + create_state, + event, +) from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper +from .process_comms import MESSAGE_KEY, KillMessage, MessageType from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected +T = TypeVar('T') + __all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] _LOGGER = logging.getLogger(__name__) @@ -91,7 +115,13 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: @persistence.auto_persist( - '_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper' + '_pid', + '_creation_time', + '_future', + '_paused', + '_status', + '_pre_paused_status', + '_event_helper', ) class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): """ @@ -158,7 +188,7 @@ def current(cls) -> Optional['Process']: return None @classmethod - def get_states(cls) -> Sequence[Type[process_states.State]]: + def get_states(cls) -> Sequence[Type[state_machine.State]]: """Return all allowed states of the process.""" state_classes = cls.get_state_classes() return ( @@ -167,7 +197,7 @@ def get_states(cls) -> Sequence[Type[process_states.State]]: ) @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: + def get_state_classes(cls) -> dict[process_states.ProcessState, Type[state_machine.State]]: # A mapping of the State constants to the corresponding state class return { process_states.ProcessState.CREATED: process_states.Created, @@ -231,7 +261,9 @@ def get_description(cls) -> Dict[str, Any]: @classmethod def recreate_from( - cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None + cls, + saved_state: SAVED_STATE_TYPE, + load_context: Optional[persistence.LoadSaveContext] = None, ) -> 'Process': """ Recreate a process from a saved state, passing any positional and @@ -314,14 +346,21 @@ def init(self) -> None: identifier = self._communicator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier)) except kiwipy.TimeoutError: - self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) + self.logger.exception( + 'Process<%s>: failed to register as a broadcast subscriber', + self.pid, + ) if not self._future.done(): def try_killing(future: futures.Future) -> None: if future.cancelled(): - if not self.kill('Killed by future being cancelled'): - self.logger.warning('Process<%s>: Failed to kill process on future cancel', self.pid) + msg = KillMessage.build(message='Killed by future being cancelled') + if not self.kill(msg): + self.logger.warning( + 'Process<%s>: Failed to kill process on future cancel', + self.pid, + ) self._future.add_done_callback(try_killing) @@ -329,10 +368,10 @@ def _setup_event_hooks(self) -> None: """Set the event hooks to process, when it is created or loaded(recreated).""" event_hooks = { state_machine.StateEventHook.ENTERING_STATE: lambda _s, _h, state: self.on_entering( - cast(process_states.State, state) + cast(state_machine.State, state) ), state_machine.StateEventHook.ENTERED_STATE: lambda _s, _h, from_state: self.on_entered( - cast(Optional[process_states.State], from_state) + cast(Optional[state_machine.State], from_state) ), state_machine.StateEventHook.EXITING_STATE: lambda _s, _h, _state: self.on_exiting(), } @@ -425,7 +464,13 @@ def launch( The process is started asynchronously, without blocking other task in the event loop. """ - process = process_class(inputs=inputs, pid=pid, logger=logger, loop=self.loop, communicator=self._communicator) + process = process_class( + inputs=inputs, + pid=pid, + logger=logger, + loop=self.loop, + communicator=self._communicator, + ) self.loop.create_task(process.step_until_terminated()) return process @@ -433,7 +478,7 @@ def launch( def has_terminated(self) -> bool: """Return whether the process was terminated.""" - return self._state.is_terminal() + return self._state.is_terminal def result(self) -> Any: """ @@ -477,7 +522,7 @@ def killed(self) -> bool: """Return whether the process is killed.""" return self.state == process_states.ProcessState.KILLED - def killed_msg(self) -> Optional[str]: + def killed_msg(self) -> Optional[MessageType]: """Return the killed message.""" if isinstance(self._state, process_states.Killed): return self._state.msg @@ -506,7 +551,7 @@ def done(self) -> bool: Use the `has_terminated` method instead """ warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning) - return self._state.is_terminal() + return self._state.is_terminal # endregion @@ -529,7 +574,10 @@ def call_soon(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> return handle def callback_excepted( - self, _callback: Callable[..., Any], exception: Optional[BaseException], trace: Optional[TracebackType] + self, + _callback: Callable[..., Any], + exception: Optional[BaseException], + trace: Optional[TracebackType], ) -> None: if self.state != process_states.ProcessState.EXCEPTED: self.fail(exception, trace) @@ -555,7 +603,7 @@ def _process_scope(self) -> Generator[None, None, None]: stack_copy.pop() PROCESS_STACK.set(stack_copy) - async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + async def _run_task(self, callback: Callable[..., T], *args: Any, **kwargs: Any) -> T: """ This method should be used to run all Process related functions and coroutines. If there is an exception the process will enter the EXCEPTED state. @@ -576,7 +624,9 @@ async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: An # region Persistence def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] + self, + out_state: SAVED_STATE_TYPE, + save_context: Optional[persistence.LoadSaveContext], ) -> None: """ Ask the process to save its current instance state. @@ -586,7 +636,9 @@ def save_instance_state( """ super().save_instance_state(out_state, save_context) - out_state['_state'] = self._state.save() + # FIXME: the combined ProcessState protocol should cover the case + if isinstance(self._state, process_states.Savable): + out_state['_state'] = self._state.save() # Inputs/outputs if self.raw_inputs is not None: @@ -622,7 +674,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi else: self._loop = asyncio.get_event_loop() - self._state: process_states.State = self.recreate_state(saved_state['_state']) + self._state: state_machine.State = self.recreate_state(saved_state['_state']) if 'communicator' in load_context: self._communicator = load_context.communicator @@ -680,7 +732,7 @@ def log_with_pid(self, level: int, msg: str) -> None: # region Events - def on_entering(self, state: process_states.State) -> None: + def on_entering(self, state: state_machine.State) -> None: # Map these onto direct functions that the subclass can implement state_label = state.LABEL if state_label == process_states.ProcessState.CREATED: @@ -696,7 +748,7 @@ def on_entering(self, state: process_states.State) -> None: elif state_label == process_states.ProcessState.EXCEPTED: call_with_super_check(self.on_except, state.get_exc_info()) # type: ignore - def on_entered(self, from_state: Optional[process_states.State]) -> None: + def on_entered(self, from_state: Optional[state_machine.State]) -> None: # Map these onto direct functions that the subclass can implement state_label = self._state.LABEL if state_label == process_states.ProcessState.RUNNING: @@ -828,7 +880,9 @@ def on_finish(self, result: Any, successful: bool) -> None: if successful: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: - raise StateEntryFailed(process_states.ProcessState.FINISHED, result, False) + state_cls = self.get_states_map()[process_states.ProcessState.FINISHED] + finished_state = state_cls(result=result, successful=False) + raise StateEntryFailed(finished_state) self.future().set_result(self.outputs) @@ -857,10 +911,15 @@ def on_excepted(self) -> None: self._fire_event(ProcessListener.on_process_excepted, str(self.future().exception())) @super_check - def on_kill(self, msg: Optional[str]) -> None: + def on_kill(self, msg: Optional[MessageType]) -> None: """Entering the KILLED state.""" - self.set_status(msg) - self.future().set_exception(exceptions.KilledError(msg)) + if msg is None: + msg_txt = '' + else: + msg_txt = msg[MESSAGE_KEY] or '' + + self.set_status(msg_txt) + self.future().set_exception(exceptions.KilledError(msg_txt)) @super_check def on_killed(self) -> None: @@ -906,7 +965,12 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An :param msg: the message :return: the outcome of processing the message, the return value will be sent back as a response to the sender """ - self.logger.debug("Process<%s>: received RPC message with communicator '%s': %r", self.pid, _comm, msg) + self.logger.debug( + "Process<%s>: received RPC message with communicator '%s': %r", + self.pid, + _comm, + msg, + ) intent = msg[process_comms.INTENT_KEY] @@ -915,7 +979,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An if intent == process_comms.Intent.PAUSE: return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None)) if intent == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg=msg.get(process_comms.MESSAGE_KEY, None)) + return self._schedule_rpc(self.kill, msg=msg) if intent == process_comms.Intent.STATUS: status_info: Dict[str, Any] = {} self.get_status_info(status_info) @@ -935,7 +999,11 @@ def broadcast_receive( """ self.logger.debug( - "Process<%s>: received broadcast message '%s' with communicator '%s': %r", self.pid, subject, _comm, body + "Process<%s>: received broadcast message '%s' with communicator '%s': %r", + self.pid, + subject, + _comm, + body, ) # If we get a message we recognise then action it, otherwise ignore @@ -1001,13 +1069,19 @@ def close(self) -> None: # region State related methods def transition_failed( - self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType + self, + initial_state: Hashable, + final_state: Hashable, + exception: Exception, + trace: TracebackType, ) -> None: # If we are creating, then reraise instead of failing. if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace) + # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace) + self.transition_to(new_state) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. @@ -1031,6 +1105,11 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._pausing if self._stepping: + if not isinstance(self._state, Interruptable): + raise exceptions.InvalidStateError( + f'cannot interrupt {self._state.__class__}, method `interrupt` not implement' + ) + # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.PauseInterruption(msg) @@ -1042,7 +1121,11 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._do_pause(msg) - def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool: + @staticmethod + def _interrupt(state: Interruptable, reason: Exception) -> None: + state.interrupt(reason) + + def _do_pause(self, state_msg: Optional[str], next_state: Optional[state_machine.State] = None) -> bool: """Carry out the pause procedure, optionally transitioning to the next state first""" try: if next_state is not None: @@ -1068,11 +1151,13 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu if isinstance(exception, process_states.KillInterruption): - def do_kill(_next_state: process_states.State) -> Any: + def do_kill(_next_state: state_machine.State) -> Any: try: - # Ignore the next state - self.transition_to(process_states.ProcessState.KILLED, str(exception)) + new_state = create_state(self, process_states.ProcessState.KILLED, msg=exception.msg) + self.transition_to(new_state) return True + # FIXME: if try block except, will hit deadlock in event loop + # need to know how to debug it, and where to set a timeout. finally: self._killing = None @@ -1117,15 +1202,17 @@ def resume(self, *args: Any) -> None: return self._state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) - def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None: + def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None: """ Fail the process in response to an exception :param exception: The exception that caused the failure - :param trace_back: Optional exception traceback + :param traceback: Optional exception traceback """ - self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back) + # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback) + self.transition_to(new_state) - def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]: + def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ Kill the process :param msg: An optional kill message @@ -1142,7 +1229,7 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]: # Already killing return self._killing - if self._stepping: + if self._stepping and isinstance(self._state, Interruptable): # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.KillInterruption(msg) @@ -1151,7 +1238,8 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]: self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - self.transition_to(process_states.ProcessState.KILLED, msg) + new_state = create_state(self, process_states.ProcessState.KILLED, msg=msg) + self.transition_to(new_state) return True @property @@ -1161,16 +1249,16 @@ def is_killing(self) -> bool: # endregion - def create_initial_state(self) -> process_states.State: + def create_initial_state(self) -> state_machine.State: """This method is here to override its superclass. Automatically enter the CREATED state when the process is created. :return: A Created state """ - return cast(process_states.State, self.get_state_class(process_states.ProcessState.CREATED)(self, self.run)) + return self.get_state_class(process_states.ProcessState.CREATED)(self, self.run) - def recreate_state(self, saved_state: persistence.Bundle) -> process_states.State: + def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State: """ Create a state object from a saved state @@ -1178,7 +1266,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat :return: An instance of the object with its state loaded from the save state. """ load_context = persistence.LoadSaveContext(process=self) - return cast(process_states.State, persistence.Savable.load(saved_state, load_context)) + return cast(state_machine.State, persistence.Savable.load(saved_state, load_context)) # endregion @@ -1216,6 +1304,9 @@ async def step(self) -> None: if self.paused and self._paused is not None: await self._paused + if not isinstance(self._state, Proceedable): + raise StateMachineError(f'cannot step from {self._state.__class__}, async method `execute` not implemented') + try: self._stepping = True next_state = None @@ -1236,7 +1327,9 @@ async def step(self) -> None: raise except Exception: # Overwrite the next state to go to excepted directly - next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:]) + next_state = create_state( + self, process_states.ProcessState.EXCEPTED, exception=sys.exc_info()[1], traceback=sys.exc_info()[2] + ) self._set_interrupt_action(None) if self._interrupt_action: diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 748a44d7..865a5b61 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -11,7 +11,6 @@ Any, Callable, Dict, - Hashable, List, Mapping, MutableSequence, @@ -25,6 +24,9 @@ import kiwipy +from plumpy.base import state_machine +from plumpy.exceptions import InvalidStateError + from . import lang, mixins, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE @@ -68,6 +70,7 @@ def get_outline(self) -> Union['_Instruction', '_FunctionCall']: return self._outline +# FIXME: better use composition here @persistence.auto_persist('_awaiting') class Waiting(process_states.Waiting): """Overwrite the waiting state""" @@ -77,24 +80,14 @@ def __init__( process: 'WorkChain', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, - awaiting: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, + data: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, ) -> None: - super().__init__(process, done_callback, msg, awaiting) + super().__init__(process, done_callback, msg, data) self._awaiting: Dict[asyncio.Future, str] = {} - for awaitable, key in (awaiting or {}).items(): + for awaitable, key in (data or {}).items(): resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable self._awaiting[resolved_awaitable] = key - def enter(self) -> None: - super().enter() - for awaitable in self._awaiting: - awaitable.add_done_callback(self._awaitable_done) - - def exit(self) -> None: - super().exit() - for awaitable in self._awaiting: - awaitable.remove_done_callback(self._awaitable_done) - def _awaitable_done(self, awaitable: asyncio.Future) -> None: key = self._awaiting.pop(awaitable) try: @@ -105,6 +98,20 @@ def _awaitable_done(self, awaitable: asyncio.Future) -> None: if not self._awaiting: self._waiting_future.set_result(lang.NULL) + def enter(self) -> None: + for awaitable in self._awaiting: + awaitable.add_done_callback(self._awaitable_done) + + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + self.in_state = False + + for awaitable in self._awaiting: + awaitable.remove_done_callback(self._awaitable_done) + class WorkChain(mixins.ContextMixin, processes.Process): """ @@ -117,7 +124,7 @@ class WorkChain(mixins.ContextMixin, processes.Process): _CONTEXT = 'CONTEXT' @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: + def get_state_classes(cls) -> Dict[process_states.ProcessState, Type[state_machine.State]]: states_map = super().get_state_classes() states_map[process_states.ProcessState.WAITING] = Waiting return states_map diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index 5b4b73d8..6a61fe00 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- import time +from typing import final import unittest from plumpy.base import state_machine +from plumpy.exceptions import InvalidStateError # Events PLAY = 'Play' @@ -15,31 +17,25 @@ STOPPED = 'Stopped' -class Playing(state_machine.State): +class Playing: LABEL = PLAYING ALLOWED = {PAUSED, STOPPED} TRANSITIONS = {STOP: STOPPED} + is_terminal = False + def __init__(self, player, track): assert track is not None, 'Must provide a track name' - super().__init__(player) self.track = track self._last_time = None self._played = 0.0 + self.in_state = False def __str__(self): if self.in_state: self._update_time() return f'> {self.track} ({self._played}s)' - def enter(self): - super().enter() - self._last_time = time.time() - - def exit(self): - super().exit() - self._update_time() - def play(self, track=None): # pylint: disable=no-self-use, unused-argument return False @@ -48,15 +44,28 @@ def _update_time(self): self._played += current_time - self._last_time self._last_time = current_time + def enter(self) -> None: + self._last_time = time.time() + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self._update_time() + self.in_state = False + -class Paused(state_machine.State): +class Paused: LABEL = PAUSED ALLOWED = {PLAYING, STOPPED} TRANSITIONS = {STOP: STOPPED} + is_terminal = False + def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' - super().__init__(player) + self._player = player self.playing_state = playing_state def __str__(self): @@ -64,23 +73,46 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing, track=track) + self._player.transition_to(Playing(player=self.state_machine, track=track)) else: - self.state_machine.transition_to(self.playing_state) + self._player.transition_to(self.playing_state) + + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + self.in_state = False -class Stopped(state_machine.State): + +class Stopped: LABEL = STOPPED ALLOWED = { PLAYING, } TRANSITIONS = {PLAY: PLAYING} + is_terminal = False + + def __init__(self, player): + self._player = player + def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing, track=track) + self._player.transition_to(Playing(self._player, track=track)) + + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False class CdPlayer(state_machine.StateMachine): @@ -107,12 +139,12 @@ def play(self, track=None): @state_machine.event(from_states=Playing, to_states=Paused) def pause(self): - self.transition_to(Paused, playing_state=self._state) + self.transition_to(Paused(self, playing_state=self._state)) return True @state_machine.event(from_states=(Playing, Paused), to_states=Stopped) def stop(self): - self.transition_to(Stopped) + self.transition_to(Stopped(self)) class TestStateMachine(unittest.TestCase): diff --git a/tests/persistence/test_inmemory.py b/tests/persistence/test_inmemory.py index b0db46e7..9e3141de 100644 --- a/tests/persistence/test_inmemory.py +++ b/tests/persistence/test_inmemory.py @@ -1,11 +1,9 @@ # -*- coding: utf-8 -*- import unittest -from ..utils import ProcessWithCheckpoint - import plumpy -import plumpy +from ..utils import ProcessWithCheckpoint class TestInMemoryPersister(unittest.TestCase): diff --git a/tests/persistence/test_pickle.py b/tests/persistence/test_pickle.py index dd68b4fd..da4ede51 100644 --- a/tests/persistence/test_pickle.py +++ b/tests/persistence/test_pickle.py @@ -5,10 +5,10 @@ if getattr(tempfile, 'TemporaryDirectory', None) is None: from backports import tempfile -from ..utils import ProcessWithCheckpoint - import plumpy +from ..utils import ProcessWithCheckpoint + class TestPicklePersister(unittest.TestCase): def test_save_load_roundtrip(self): diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index 7223b888..c6826a24 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- import asyncio -import copy import kiwipy import pytest @@ -196,8 +195,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): for _ in range(10): procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) - msg = copy.copy(process_comms.KILL_MSG) - msg[process_comms.MESSAGE_KEY] = 'bang bang, I shot you down' + msg = process_comms.KillMessage.build(message='bang bang, I shot you down') sync_controller.kill_all(msg) await utils.wait_util(lambda: all([proc.killed() for proc in procs])) diff --git a/tests/test_expose.py b/tests/test_expose.py index 0f6f8087..c5e6014c 100644 --- a/tests/test_expose.py +++ b/tests/test_expose.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- import unittest -from .utils import NewLoopProcess - from plumpy.ports import PortNamespace from plumpy.process_spec import ProcessSpec from plumpy.processes import Process +from .utils import NewLoopProcess + def validator_function(input, port): pass diff --git a/tests/test_process_comms.py b/tests/test_process_comms.py index c59737ac..44947230 100644 --- a/tests/test_process_comms.py +++ b/tests/test_process_comms.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- import pytest -from tests import utils import plumpy from plumpy import process_comms +from tests import utils class Process(plumpy.Process): diff --git a/tests/test_processes.py b/tests/test_processes.py index faea9eae..4b8cc606 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -2,18 +2,17 @@ """Process tests""" import asyncio -import copy import enum import unittest import kiwipy import pytest -from tests import utils import plumpy from plumpy import BundleKeys, Process, ProcessState -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from plumpy.process_comms import KillMessage from plumpy.utils import AttributesFrozendict +from tests import utils class ForgetToCallParent(plumpy.Process): @@ -323,8 +322,7 @@ def run(self, **kwargs): def test_kill(self): proc: Process = utils.DummyProcess() - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'Farewell!' + msg = KillMessage.build(message='Farewell!') proc.kill(msg) self.assertTrue(proc.killed()) self.assertEqual(proc.killed_msg(), msg) @@ -430,8 +428,7 @@ class KillProcess(Process): after_kill = False def run(self, **kwargs): - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'killed' + msg = KillMessage.build(message='killed') self.kill(msg) # The following line should be executed because kill will not # interrupt execution of a method call in the RUNNING state @@ -656,7 +653,7 @@ def test_exception_during_on_entered(self): class RaisingProcess(Process): def on_entered(self, from_state): - if from_state is not None and from_state.label == ProcessState.RUNNING: + if from_state is not None and from_state.LABEL == ProcessState.RUNNING: raise RuntimeError('exception during on_entered') super().on_entered(from_state) diff --git a/tests/utils.py b/tests/utils.py index f2a58dfc..88638e01 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,13 +3,12 @@ import asyncio import collections -import copy import unittest from collections.abc import Mapping import plumpy from plumpy import persistence, process_states, processes, utils -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from plumpy.process_comms import KillMessage Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs']) @@ -86,8 +85,7 @@ def last_step(self): class KillProcess(processes.Process): @utils.override def run(self): - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'killed' + msg = KillMessage.build(message='killed') return process_states.Kill(msg=msg)