Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Notjustrmq #9

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ py:class kiwipy.communications.Communicator

# unavailable forward references
py:class plumpy.process_states.Command
py:class plumpy.process_states.State
py:class plumpy.state_machine.State
py:class plumpy.base.state_machine.State
py:class State
py:class Process
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@
" def continue_fn(self):\n",
" print('continuing')\n",
" # message is stored in the process status\n",
" return plumpy.Kill('I was killed')\n",
" return plumpy.Kill(plumpy.KillMessage.build('I was killed'))\n",
"\n",
"\n",
"process = ContinueProcess()\n",
Expand Down Expand Up @@ -1118,7 +1118,7 @@
"\n",
"process = SimpleProcess(communicator=communicator)\n",
"\n",
"pprint(communicator.rpc_send(str(process.pid), plumpy.STATUS_MSG).result())"
"pprint(communicator.rpc_send(str(process.pid), plumpy.StatusMessage.build()).result())"
]
},
{
Expand Down
157 changes: 71 additions & 86 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
# -*- coding: utf-8 -*-
"""The state machine for processes"""

from __future__ import annotations

import enum
import functools
import inspect
import logging
import os
import sys
from types import TracebackType
from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Sequence, Set, Type, Union, cast
from typing import (
Any,
Callable,
ClassVar,
Dict,
Hashable,
Iterable,
List,
Optional,
Protocol,
Sequence,
Type,
Union,
runtime_checkable,
)

from plumpy.futures import Future

Expand All @@ -18,7 +34,6 @@

_LOGGER = logging.getLogger(__name__)

LABEL_TYPE = Union[None, enum.Enum, str]
EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None]


Expand All @@ -31,7 +46,7 @@ class StateEntryFailed(Exception): # noqa: N818
Failed to enter a state, can provide the next state to go to via this exception
"""

def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None:
def __init__(self, state: State, *args: Any, **kwargs: Any) -> None:
super().__init__('failed to enter state')
self.state = state
self.args = args
Expand Down Expand Up @@ -74,12 +89,12 @@ def event(
if from_states != '*':
if inspect.isclass(from_states):
from_states = (from_states,)
if not all(issubclass(state, State) for state in from_states): # type: ignore
if not all(isinstance(state, State) for state in from_states): # type: ignore
raise TypeError(f'from_states: {from_states}')
if to_states != '*':
if inspect.isclass(to_states):
to_states = (to_states,)
if not all(issubclass(state, State) for state in to_states): # type: ignore
if not all(isinstance(state, State) for state in to_states): # type: ignore
raise TypeError(f'to_states: {to_states}')

def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]:
Expand Down Expand Up @@ -113,57 +128,40 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:
return wrapper


class State:
LABEL: LABEL_TYPE = None
# A set containing the labels of states that can be entered
# from this one
ALLOWED: Set[LABEL_TYPE] = set()
@runtime_checkable
class State(Protocol):
LABEL: ClassVar[Any]
ALLOWED: ClassVar[set[Any]]
is_terminal: ClassVar[bool]

@classmethod
def is_terminal(cls) -> bool:
return not cls.ALLOWED
def __init__(self, *args: Any, **kwargs: Any): ...

def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any):
"""
:param state_machine: The process this state belongs to
"""
self.state_machine = state_machine
self.in_state: bool = False
def enter(self) -> None: ...

def __str__(self) -> str:
return str(self.LABEL)
def exit(self) -> None: ...

@property
def label(self) -> LABEL_TYPE:
"""Convenience property to get the state label"""
return self.LABEL

@super_check
def enter(self) -> None:
"""Entering the state"""
@runtime_checkable
class Interruptable(Protocol):
def interrupt(self, reason: Exception) -> None: ...


def execute(self) -> Optional['State']:
@runtime_checkable
class Proceedable(Protocol):
def execute(self) -> State | None:
"""
Execute the state, performing the actions that this state is responsible for.
:returns: a state to transition to or None if finished.
"""
...

@super_check
def exit(self) -> None:
"""Exiting the state"""
if self.is_terminal():
raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State':
return self.state_machine.create_state(state_label, *args, **kwargs)
def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
if state_label not in st.get_states_map():
raise ValueError(f'{state_label} is not a valid state')

def do_enter(self) -> None:
call_with_super_check(self.enter)
self.in_state = True

def do_exit(self) -> None:
call_with_super_check(self.exit)
self.in_state = False
state_cls = st.get_states_map()[state_label]
return state_cls(*args, **kwargs)


class StateEventHook(enum.Enum):
Expand All @@ -187,7 +185,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine':
:param kwargs: Any keyword arguments to be passed to the constructor
:return: An instance of the state machine
"""
inst = super().__call__(*args, **kwargs)
inst: StateMachine = super().__call__(*args, **kwargs)
inst.transition_to(inst.create_initial_state())
call_with_super_check(inst.init)
return inst
Expand All @@ -214,13 +212,13 @@ def get_states(cls) -> Sequence[Type[State]]:
raise RuntimeError('States not defined')

@classmethod
def initial_state_label(cls) -> LABEL_TYPE:
def initial_state_label(cls) -> Any:
cls.__ensure_built()
assert cls.STATES is not None
return cls.STATES[0].LABEL

@classmethod
def get_state_class(cls, label: LABEL_TYPE) -> Type[State]:
def get_state_class(cls, label: Any) -> Type[State]:
cls.__ensure_built()
assert cls._STATES_MAP is not None
return cls._STATES_MAP[label]
Expand All @@ -240,7 +238,7 @@ def __ensure_built(cls) -> None:
# Build the states map
cls._STATES_MAP = {}
for state_cls in cls.STATES:
assert issubclass(state_cls, State)
assert isinstance(state_cls, State)
label = state_cls.LABEL
assert label not in cls._STATES_MAP, f"Duplicate label '{label}'"
cls._STATES_MAP[label] = state_cls
Expand All @@ -264,11 +262,11 @@ def init(self) -> None:
def __str__(self) -> str:
return f'<{self.__class__.__name__}> ({self.state})'

def create_initial_state(self) -> State:
return self.get_state_class(self.initial_state_label())(self)
def create_initial_state(self, *args: Any, **kwargs: Any) -> State:
return self.get_state_class(self.initial_state_label())(self, *args, **kwargs)

@property
def state(self) -> Optional[LABEL_TYPE]:
def state(self) -> Any:
if self._state is None:
return None
return self._state.LABEL
Expand Down Expand Up @@ -300,16 +298,24 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None:
def on_terminated(self) -> None:
"""Called when a terminal state is entered"""

def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None:
def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
"""Transite to the new state.

The new target state will be create lazily when the state is not yet instantiated,
which will happened for states not in the expect path such as pause and kill.
The arguments are passed to the state class to create state instance.
(process arg does not need to pass since it will always call with 'self' as process)
"""
print(f'try: {self._state} -> {new_state}')
assert not self._transitioning, 'Cannot call transition_to when already transitioning state'

if new_state is None:
return None

initial_state_label = self._state.LABEL if self._state is not None else None
label = None
try:
self._transitioning = True

# Make sure we have a state instance
new_state = self._create_state_instance(new_state, *args, **kwargs)
label = new_state.LABEL

# If the previous transition failed, do not try to exit it but go straight to next state
Expand All @@ -319,13 +325,12 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A
try:
self._enter_next_state(new_state)
except StateEntryFailed as exception:
# Make sure we have a state instance
new_state = self._create_state_instance(exception.state, *exception.args, **exception.kwargs)
new_state = exception.state
label = new_state.LABEL
self._exit_current_state(new_state)
self._enter_next_state(new_state)

if self._state is not None and self._state.is_terminal():
if self._state is not None and self._state.is_terminal:
call_with_super_check(self.on_terminated)
except Exception:
self._transitioning = False
Expand All @@ -338,7 +343,11 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A
self._transitioning = False

def transition_failed(
self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType
self,
initial_state: Hashable,
final_state: Hashable,
exception: Exception,
trace: TracebackType,
) -> None:
"""Called when a state transitions fails.

Expand All @@ -354,49 +363,25 @@ def get_debug(self) -> bool:
def set_debug(self, enabled: bool) -> None:
self._debug: bool = enabled

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
try:
return self.get_states_map()[state_label](self, *args, **kwargs)
except KeyError:
raise ValueError(f'{state_label} is not a valid state')

def _exit_current_state(self, next_state: State) -> None:
"""Exit the given state"""

# If we're just being constructed we may not have a state yet to exit,
# in which case check the new state is the initial state
if self._state is None:
if next_state.label != self.initial_state_label():
if next_state.LABEL != self.initial_state_label():
raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state")
return # Nothing to exit

if next_state.LABEL not in self._state.ALLOWED:
raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}')
raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.LABEL}')
self._fire_state_event(StateEventHook.EXITING_STATE, next_state)
self._state.do_exit()
self._state.exit()

def _enter_next_state(self, next_state: State) -> None:
last_state = self._state
self._fire_state_event(StateEventHook.ENTERING_STATE, next_state)
# Enter the new state
next_state.do_enter()
next_state.enter()
self._state = next_state
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

def _create_state_instance(self, state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> State:
if isinstance(state, State):
# It's already a state instance
return state

# OK, have to create it
state_cls = self._ensure_state_class(state)
return state_cls(self, *args, **kwargs)

def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State]:
if inspect.isclass(state) and issubclass(state, State):
return state

try:
return self.get_states_map()[cast(Hashable, state)]
except KeyError:
raise ValueError(f'{state} is not a valid state')
Loading
Loading