From b88a0e2d0f33be727c551a7ae485cf02223b8b0d Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 3 Dec 2024 01:48:02 +0100 Subject: [PATCH] create_state refact Hashable initialized + parameters passed to Hashable --- src/plumpy/base/state_machine.py | 27 ++++------- src/plumpy/process_states.py | 82 +++++++++++++++++++------------- src/plumpy/processes.py | 28 ++++++----- src/plumpy/workchains.py | 6 +-- tests/base/test_statemachine.py | 15 +++--- 5 files changed, 84 insertions(+), 74 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 2fffa9dc..9d3a1c96 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -147,7 +147,6 @@ def interrupt(self, reason: Exception) -> None: ... @runtime_checkable class Proceedable(Protocol): - def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. @@ -156,6 +155,14 @@ def execute(self) -> State | None: ... +def create_state(st: state_machine.StateMachine, state_label: Hashable, *args, **kwargs: Any) -> state_machine.State: + if state_label not in st.get_states_map(): + raise ValueError(f'{state_label} is not a valid state') + + state_cls = st.get_states_map()[state_label] + return state_cls(*args, **kwargs) + + class StateEventHook(enum.Enum): """ Hooks that can be used to register callback at various points in the state transition @@ -298,6 +305,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None: The arguments are passed to the state class to create state instance. (process arg does not need to pass since it will always call with 'self' as process) """ + print(f'try: {self._state} -> {new_state}') assert not self._transitioning, 'Cannot call transition_to when already transitioning state' if new_state is None: @@ -354,17 +362,6 @@ def get_debug(self) -> bool: def set_debug(self, enabled: bool) -> None: self._debug: bool = enabled - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: - # XXX: this method create state from label, which is duplicate as _create_state_instance and less generic - # because the label is defined after the state and required to be know before calling this function. - # This method should be replaced by `_create_state_instance`. - # aiida-core using this method for its Waiting state override. - try: - state_cls = self.get_states_map()[state_label] - return state_cls(self, *args, **kwargs) - except KeyError: - raise ValueError(f'{state_label} is not a valid state') - def _exit_current_state(self, next_state: State) -> None: """Exit the given state""" @@ -387,9 +384,3 @@ def _enter_next_state(self, next_state: State) -> None: next_state.enter() self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - - def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State: - if state_cls.LABEL not in self.get_states_map(): - raise ValueError(f'{state_cls.LABEL} is not a valid state') - - return state_cls(self, **kwargs) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 4348803e..58ece817 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -5,7 +5,7 @@ import traceback from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast, final +from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional, Tuple, Type, Union, cast, final import yaml from yaml.loader import Loader @@ -20,7 +20,7 @@ _HAS_TBLIB = False from . import exceptions, futures, persistence, utils -from .base import state_machine +from .base import state_machine as st from .lang import NULL from .persistence import auto_persist from .utils import SAVED_STATE_TYPE @@ -138,6 +138,7 @@ class ProcessState(Enum): The possible states that a :class:`~plumpy.processes.Process` can be in. """ + # FIXME: see LSP error of return a exception, the type is Literal[str] which is invariant, tricky CREATED = 'created' RUNNING = 'running' WAITING = 'waiting' @@ -172,8 +173,10 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) - async def execute(self) -> state_machine.State: - return self.process.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) + def execute(self) -> st.State: + return st.create_state( + self.process, ProcessState.RUNNING, process=self.process, run_fn=self.run_fn, *self.args, **self.kwargs + ) def enter(self) -> None: ... @@ -227,7 +230,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> state_machine.State: + def execute(self) -> st.State: if self._command is not None: command = self._command else: @@ -241,8 +244,10 @@ async def execute(self) -> state_machine.State: # Let this bubble up to the caller raise except Exception: - excepted = self.process.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) - return cast(state_machine.State, excepted) + _, exception, traceback = sys.exc_info() + # excepted = state_cls(exception=exception, traceback=traceback) + excepted = Excepted(exception=exception, traceback=traceback) + return excepted else: if not isinstance(result, Command): if isinstance(result, exceptions.UnsuccessfulResult): @@ -256,21 +261,37 @@ async def execute(self) -> state_machine.State: next_state = self._action_command(command) return next_state - def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_machine.State: + def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> st.State: if isinstance(command, Kill): - state = self.process.create_state(ProcessState.KILLED, command.msg) + state = st.create_state(self.process, ProcessState.KILLED, msg=command.msg) # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.process.create_state(ProcessState.FINISHED, command.result, command.successful) + state = st.create_state( + self.process, ProcessState.FINISHED, result=command.result, successful=command.successful + ) elif isinstance(command, Wait): - state = self.process.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) + state = st.create_state( + self.process, + ProcessState.WAITING, + process=self.process, + done_callback=command.continue_fn, + msg=command.msg, + data=command.data, + ) elif isinstance(command, Continue): - state = self.process.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) + state = st.create_state( + self.process, + ProcessState.RUNNING, + process=self.process, + run_fn=command.continue_fn, + *command.args, + **command.kwargs, + ) else: raise ValueError('Unrecognised command') - return cast(state_machine.State, state) # casting from base.State to process.State + return cast(st.State, state) # casting from base.State to process.State def enter(self) -> None: ... @@ -334,7 +355,7 @@ def interrupt(self, reason: Exception) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) - async def execute(self) -> state_machine.State: # type: ignore + async def execute(self) -> st.State: # type: ignore try: result = await self._waiting_future except Interruption: @@ -345,11 +366,15 @@ async def execute(self) -> state_machine.State: # type: ignore raise if result == NULL: - next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback) + next_state = st.create_state( + self.process, ProcessState.RUNNING, process=self.process, run_fn=self.done_callback + ) else: - next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback, result) + next_state = st.create_state( + self.process, ProcessState.RUNNING, process=self.process, done_callback=self.done_callback, *result + ) - return cast(state_machine.State, next_state) # casting from base.State to process.State + return cast(st.State, next_state) # casting from base.State to process.State def resume(self, value: Any = NULL) -> None: assert self._waiting_future is not None, 'Not yet waiting' @@ -367,10 +392,10 @@ def exit(self) -> None: ... @final class Excepted(persistence.Savable): """ - Excepted state, can optionally provide exception and trace_back + Excepted state, can optionally provide exception and traceback :param exception: The exception instance - :param trace_back: An optional exception traceback + :param traceback: An optional exception traceback """ LABEL = ProcessState.EXCEPTED @@ -383,18 +408,15 @@ class Excepted(persistence.Savable): def __init__( self, - process: 'Process', exception: Optional[BaseException], - trace_back: Optional[TracebackType] = None, + traceback: Optional[TracebackType] = None, ): """ - :param process: The associated process :param exception: The exception instance - :param trace_back: An optional exception traceback + :param traceback: An optional exception traceback """ - self.process = process self.exception = exception - self.traceback = trace_back + self.traceback = traceback def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] @@ -408,7 +430,6 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.process = load_context.process self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -450,14 +471,12 @@ class Finished(persistence.Savable): is_terminal = True - def __init__(self, process: 'Process', result: Any, successful: bool) -> None: - self.process = process + def __init__(self, result: Any, successful: bool) -> None: self.result = result self.successful = successful def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.process = load_context.process def enter(self) -> None: ... @@ -481,17 +500,14 @@ class Killed(persistence.Savable): is_terminal = True - def __init__(self, process: 'Process', msg: Optional[MessageType]): + def __init__(self, msg: Optional[MessageType]): """ - :param process: The associated process :param msg: Optional kill message """ - self.process = process self.msg = msg def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.process = load_context.process def enter(self) -> None: ... diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 74808291..8e101900 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -59,6 +59,7 @@ StateMachineError, TransitionFailed, event, + create_state, ) from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper @@ -876,7 +877,7 @@ def on_finish(self, result: Any, successful: bool) -> None: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: state_cls = self.get_states_map()[process_states.ProcessState.FINISHED] - finished_state = state_cls(self, result=result, successful=False) + finished_state = state_cls(result=result, successful=False) raise StateEntryFailed(finished_state) self.future().set_result(self.outputs) @@ -1074,8 +1075,8 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] - new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace) + # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace) self.transition_to(new_state) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: @@ -1148,10 +1149,11 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: state_machine.State) -> Any: try: - state_class = self.get_states_map()[process_states.ProcessState.KILLED] - new_state = self._create_state_instance(state_class, msg=exception.msg) + new_state = create_state(self, process_states.ProcessState.KILLED, msg=exception.msg) self.transition_to(new_state) return True + # FIXME: if try block except, will hit deadlock in event loop + # need to know how to debug it, and where to set a timeout. finally: self._killing = None @@ -1196,14 +1198,14 @@ def resume(self, *args: Any) -> None: return self._state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) - def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None: + def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None: """ Fail the process in response to an exception :param exception: The exception that caused the failure - :param trace_back: Optional exception traceback + :param traceback: Optional exception traceback """ - state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] - new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace_back) + # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace_back) self.transition_to(new_state) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: @@ -1232,8 +1234,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - state_class = self.get_states_map()[process_states.ProcessState.KILLED] - new_state = self._create_state_instance(state_class, msg=msg) + new_state = create_state(self, process_states.ProcessState.KILLED, msg=msg) self.transition_to(new_state) return True @@ -1325,7 +1326,10 @@ async def step(self) -> None: raise except Exception: # Overwrite the next state to go to excepted directly - next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:]) + _, exception, traceback = sys.exc_info() + next_state = create_state( + self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback + ) self._set_interrupt_action(None) if self._interrupt_action: diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index eefd57f1..2389942b 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -80,11 +80,11 @@ def __init__( process: 'WorkChain', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, - awaiting: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, + data: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, ) -> None: - super().__init__(process, done_callback, msg, awaiting) + super().__init__(process, done_callback, msg, data) self._awaiting: Dict[asyncio.Future, str] = {} - for awaitable, key in (awaiting or {}).items(): + for awaitable, key in (data or {}).items(): resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable self._awaiting[resolved_awaitable] = key diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index b6100614..6a61fe00 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -17,7 +17,7 @@ STOPPED = 'Stopped' -class Playing(state_machine.State): +class Playing: LABEL = PLAYING ALLOWED = {PAUSED, STOPPED} TRANSITIONS = {STOP: STOPPED} @@ -56,7 +56,7 @@ def exit(self) -> None: self.in_state = False -class Paused(state_machine.State): +class Paused: LABEL = PAUSED ALLOWED = {PLAYING, STOPPED} TRANSITIONS = {STOP: STOPPED} @@ -65,7 +65,6 @@ class Paused(state_machine.State): def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' - super().__init__(player) self._player = player self.playing_state = playing_state @@ -74,9 +73,9 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing(player=self.state_machine, track=track)) + self._player.transition_to(Playing(player=self.state_machine, track=track)) else: - self.state_machine.transition_to(self.playing_state) + self._player.transition_to(self.playing_state) def enter(self) -> None: self.in_state = True @@ -88,7 +87,7 @@ def exit(self) -> None: self.in_state = False -class Stopped(state_machine.State): +class Stopped: LABEL = STOPPED ALLOWED = { PLAYING, @@ -98,13 +97,13 @@ class Stopped(state_machine.State): is_terminal = False def __init__(self, player): - self.state_machine = player + self._player = player def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing(self.state_machine, track=track)) + self._player.transition_to(Playing(self._player, track=track)) def enter(self) -> None: self.in_state = True