Skip to content

Commit

Permalink
create_state refact
Browse files Browse the repository at this point in the history
Hashable initialized + parameters passed to Hashable
  • Loading branch information
unkcpz committed Dec 4, 2024
1 parent 9beea25 commit b88a0e2
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 74 deletions.
27 changes: 9 additions & 18 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ def interrupt(self, reason: Exception) -> None: ...

@runtime_checkable
class Proceedable(Protocol):

def execute(self) -> State | None:
"""
Execute the state, performing the actions that this state is responsible for.
Expand All @@ -156,6 +155,14 @@ def execute(self) -> State | None:
...


def create_state(st: state_machine.StateMachine, state_label: Hashable, *args, **kwargs: Any) -> state_machine.State:
if state_label not in st.get_states_map():
raise ValueError(f'{state_label} is not a valid state')

state_cls = st.get_states_map()[state_label]
return state_cls(*args, **kwargs)


class StateEventHook(enum.Enum):
"""
Hooks that can be used to register callback at various points in the state transition
Expand Down Expand Up @@ -298,6 +305,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
The arguments are passed to the state class to create state instance.
(process arg does not need to pass since it will always call with 'self' as process)
"""
print(f'try: {self._state} -> {new_state}')
assert not self._transitioning, 'Cannot call transition_to when already transitioning state'

if new_state is None:
Expand Down Expand Up @@ -354,17 +362,6 @@ def get_debug(self) -> bool:
def set_debug(self, enabled: bool) -> None:
self._debug: bool = enabled

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
# XXX: this method create state from label, which is duplicate as _create_state_instance and less generic
# because the label is defined after the state and required to be know before calling this function.
# This method should be replaced by `_create_state_instance`.
# aiida-core using this method for its Waiting state override.
try:
state_cls = self.get_states_map()[state_label]
return state_cls(self, *args, **kwargs)
except KeyError:
raise ValueError(f'{state_label} is not a valid state')

def _exit_current_state(self, next_state: State) -> None:
"""Exit the given state"""

Expand All @@ -387,9 +384,3 @@ def _enter_next_state(self, next_state: State) -> None:
next_state.enter()
self._state = next_state
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State:
if state_cls.LABEL not in self.get_states_map():
raise ValueError(f'{state_cls.LABEL} is not a valid state')

return state_cls(self, **kwargs)
82 changes: 49 additions & 33 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import traceback
from enum import Enum
from types import TracebackType
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast, final
from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional, Tuple, Type, Union, cast, final

import yaml
from yaml.loader import Loader
Expand All @@ -20,7 +20,7 @@
_HAS_TBLIB = False

from . import exceptions, futures, persistence, utils
from .base import state_machine
from .base import state_machine as st
from .lang import NULL
from .persistence import auto_persist
from .utils import SAVED_STATE_TYPE
Expand Down Expand Up @@ -138,6 +138,7 @@ class ProcessState(Enum):
The possible states that a :class:`~plumpy.processes.Process` can be in.
"""

# FIXME: see LSP error of return a exception, the type is Literal[str] which is invariant, tricky
CREATED = 'created'
RUNNING = 'running'
WAITING = 'waiting'
Expand Down Expand Up @@ -172,8 +173,10 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi

self.run_fn = getattr(self.process, saved_state[self.RUN_FN])

async def execute(self) -> state_machine.State:
return self.process.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs)
def execute(self) -> st.State:
return st.create_state(
self.process, ProcessState.RUNNING, process=self.process, run_fn=self.run_fn, *self.args, **self.kwargs
)

def enter(self) -> None: ...

Expand Down Expand Up @@ -227,7 +230,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi
def interrupt(self, reason: Any) -> None:
pass

async def execute(self) -> state_machine.State:
def execute(self) -> st.State:
if self._command is not None:
command = self._command
else:
Expand All @@ -241,8 +244,10 @@ async def execute(self) -> state_machine.State:
# Let this bubble up to the caller
raise
except Exception:
excepted = self.process.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:])
return cast(state_machine.State, excepted)
_, exception, traceback = sys.exc_info()
# excepted = state_cls(exception=exception, traceback=traceback)
excepted = Excepted(exception=exception, traceback=traceback)
return excepted
else:
if not isinstance(result, Command):
if isinstance(result, exceptions.UnsuccessfulResult):
Expand All @@ -256,21 +261,37 @@ async def execute(self) -> state_machine.State:
next_state = self._action_command(command)
return next_state

def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_machine.State:
def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> st.State:
if isinstance(command, Kill):
state = self.process.create_state(ProcessState.KILLED, command.msg)
state = st.create_state(self.process, ProcessState.KILLED, msg=command.msg)
# elif isinstance(command, Pause):
# self.pause()
elif isinstance(command, Stop):
state = self.process.create_state(ProcessState.FINISHED, command.result, command.successful)
state = st.create_state(
self.process, ProcessState.FINISHED, result=command.result, successful=command.successful
)
elif isinstance(command, Wait):
state = self.process.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data)
state = st.create_state(
self.process,
ProcessState.WAITING,
process=self.process,
done_callback=command.continue_fn,
msg=command.msg,
data=command.data,
)
elif isinstance(command, Continue):
state = self.process.create_state(ProcessState.RUNNING, command.continue_fn, *command.args)
state = st.create_state(
self.process,
ProcessState.RUNNING,
process=self.process,
run_fn=command.continue_fn,
*command.args,
**command.kwargs,
)
else:
raise ValueError('Unrecognised command')

return cast(state_machine.State, state) # casting from base.State to process.State
return cast(st.State, state) # casting from base.State to process.State

def enter(self) -> None: ...

Expand Down Expand Up @@ -334,7 +355,7 @@ def interrupt(self, reason: Exception) -> None:
# This will cause the future in execute() to raise the exception
self._waiting_future.set_exception(reason)

async def execute(self) -> state_machine.State: # type: ignore
async def execute(self) -> st.State: # type: ignore
try:
result = await self._waiting_future
except Interruption:
Expand All @@ -345,11 +366,15 @@ async def execute(self) -> state_machine.State: # type: ignore
raise

if result == NULL:
next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback)
next_state = st.create_state(
self.process, ProcessState.RUNNING, process=self.process, run_fn=self.done_callback
)
else:
next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback, result)
next_state = st.create_state(
self.process, ProcessState.RUNNING, process=self.process, done_callback=self.done_callback, *result
)

return cast(state_machine.State, next_state) # casting from base.State to process.State
return cast(st.State, next_state) # casting from base.State to process.State

def resume(self, value: Any = NULL) -> None:
assert self._waiting_future is not None, 'Not yet waiting'
Expand All @@ -367,10 +392,10 @@ def exit(self) -> None: ...
@final
class Excepted(persistence.Savable):
"""
Excepted state, can optionally provide exception and trace_back
Excepted state, can optionally provide exception and traceback
:param exception: The exception instance
:param trace_back: An optional exception traceback
:param traceback: An optional exception traceback
"""

LABEL = ProcessState.EXCEPTED
Expand All @@ -383,18 +408,15 @@ class Excepted(persistence.Savable):

def __init__(
self,
process: 'Process',
exception: Optional[BaseException],
trace_back: Optional[TracebackType] = None,
traceback: Optional[TracebackType] = None,
):
"""
:param process: The associated process
:param exception: The exception instance
:param trace_back: An optional exception traceback
:param traceback: An optional exception traceback
"""
self.process = process
self.exception = exception
self.traceback = trace_back
self.traceback = traceback

def __str__(self) -> str:
exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0]
Expand All @@ -408,7 +430,6 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.process = load_context.process

self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader)
if _HAS_TBLIB:
Expand Down Expand Up @@ -450,14 +471,12 @@ class Finished(persistence.Savable):

is_terminal = True

def __init__(self, process: 'Process', result: Any, successful: bool) -> None:
self.process = process
def __init__(self, result: Any, successful: bool) -> None:
self.result = result
self.successful = successful

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.process = load_context.process

def enter(self) -> None: ...

Expand All @@ -481,17 +500,14 @@ class Killed(persistence.Savable):

is_terminal = True

def __init__(self, process: 'Process', msg: Optional[MessageType]):
def __init__(self, msg: Optional[MessageType]):
"""
:param process: The associated process
:param msg: Optional kill message
"""
self.process = process
self.msg = msg

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.process = load_context.process

def enter(self) -> None: ...

Expand Down
28 changes: 16 additions & 12 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
StateMachineError,
TransitionFailed,
event,
create_state,
)
from .base.utils import call_with_super_check, super_check
from .event_helper import EventHelper
Expand Down Expand Up @@ -876,7 +877,7 @@ def on_finish(self, result: Any, successful: bool) -> None:
validation_error = self.spec().outputs.validate(self.outputs)
if validation_error:
state_cls = self.get_states_map()[process_states.ProcessState.FINISHED]
finished_state = state_cls(self, result=result, successful=False)
finished_state = state_cls(result=result, successful=False)
raise StateEntryFailed(finished_state)

self.future().set_result(self.outputs)
Expand Down Expand Up @@ -1074,8 +1075,8 @@ def transition_failed(
if final_state == process_states.ProcessState.CREATED:
raise exception.with_traceback(trace)

state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED]
new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace)
# state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED]
new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace)
self.transition_to(new_state)

def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]:
Expand Down Expand Up @@ -1148,10 +1149,11 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu

def do_kill(_next_state: state_machine.State) -> Any:
try:
state_class = self.get_states_map()[process_states.ProcessState.KILLED]
new_state = self._create_state_instance(state_class, msg=exception.msg)
new_state = create_state(self, process_states.ProcessState.KILLED, msg=exception.msg)
self.transition_to(new_state)
return True
# FIXME: if try block except, will hit deadlock in event loop
# need to know how to debug it, and where to set a timeout.
finally:
self._killing = None

Expand Down Expand Up @@ -1196,14 +1198,14 @@ def resume(self, *args: Any) -> None:
return self._state.resume(*args) # type: ignore

@event(to_states=process_states.Excepted)
def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None:
def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
"""
Fail the process in response to an exception
:param exception: The exception that caused the failure
:param trace_back: Optional exception traceback
:param traceback: Optional exception traceback
"""
state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED]
new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace_back)
# state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED]
new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace_back)
self.transition_to(new_state)

def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
Expand Down Expand Up @@ -1232,8 +1234,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

state_class = self.get_states_map()[process_states.ProcessState.KILLED]
new_state = self._create_state_instance(state_class, msg=msg)
new_state = create_state(self, process_states.ProcessState.KILLED, msg=msg)
self.transition_to(new_state)
return True

Expand Down Expand Up @@ -1325,7 +1326,10 @@ async def step(self) -> None:
raise
except Exception:
# Overwrite the next state to go to excepted directly
next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:])
_, exception, traceback = sys.exc_info()
next_state = create_state(
self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback
)
self._set_interrupt_action(None)

if self._interrupt_action:
Expand Down
6 changes: 3 additions & 3 deletions src/plumpy/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ def __init__(
process: 'WorkChain',
done_callback: Optional[Callable[..., Any]],
msg: Optional[str] = None,
awaiting: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None,
data: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None,
) -> None:
super().__init__(process, done_callback, msg, awaiting)
super().__init__(process, done_callback, msg, data)
self._awaiting: Dict[asyncio.Future, str] = {}
for awaitable, key in (awaiting or {}).items():
for awaitable, key in (data or {}).items():
resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable
self._awaiting[resolved_awaitable] = key

Expand Down
Loading

0 comments on commit b88a0e2

Please sign in to comment.