Skip to content

Commit

Permalink
>3.9 typing
Browse files Browse the repository at this point in the history
- state_machine.py
  • Loading branch information
unkcpz committed Jan 23, 2025
1 parent 510bbe4 commit 2ce7971
Showing 1 changed file with 50 additions and 55 deletions.
105 changes: 50 additions & 55 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-
"""The state machine for processes"""

from __future__ import annotations

import enum
import functools
import inspect
Expand All @@ -14,17 +12,12 @@
Any,
Callable,
ClassVar,
Dict,
Hashable,
Iterable,
List,
Optional,
Protocol,
Sequence,
Type,
Union,
final,
runtime_checkable,
)
from collections.abc import Iterable, Sequence

from plumpy.futures import Future

Expand All @@ -34,19 +27,48 @@

_LOGGER = logging.getLogger(__name__)

EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None]
EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, 'State | None'], None]


@runtime_checkable
class State(Protocol):
LABEL: ClassVar[Any]
ALLOWED: ClassVar[set[Any]]
is_terminal: ClassVar[bool]

def __init__(self, *args: Any, **kwargs: Any): ...

def enter(self) -> None: ...

def exit(self) -> None: ...


@runtime_checkable
class Interruptable(Protocol):
def interrupt(self, reason: Exception) -> None: ...


@runtime_checkable
class Proceedable(Protocol):
def execute(self) -> State | None:
"""
Execute the state, performing the actions that this state is responsible for.
:returns: a state to transition to or None if finished.
"""
...


class StateMachineError(Exception):
"""Base class for state machine errors"""


@final
class StateEntryFailed(Exception): # noqa: N818
"""
Failed to enter a state, can provide the next state to go to via this exception
"""

def __init__(self, state: State, *args: Any, **kwargs: Any) -> None:
def __init__(self, state: 'State', *args: Any, **kwargs: Any) -> None:
super().__init__('failed to enter state')
self.state = state
self.args = args
Expand All @@ -63,11 +85,12 @@ def __init__(self, evt: str, msg: str):
self.event = evt


@final
class TransitionFailed(Exception): # noqa: N818
"""A state transition failed"""

def __init__(
self, initial_state: 'State', final_state: Optional['State'] = None, traceback_str: Optional[str] = None
self, initial_state: 'State', final_state: 'State | None' = None, traceback_str: str | None = None
) -> None:
self.initial_state = initial_state
self.final_state = final_state
Expand All @@ -82,8 +105,8 @@ def _format_msg(self) -> str:


def event(
from_states: Union[str, Type['State'], Iterable[Type['State']]] = '*',
to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*',
from_states: str | type['State'] | Iterable[type['State']] = '*',
to_states: str | type['State'] | Iterable[type['State']] = '*',
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""A decorator to check for correct transitions, raising ``EventError`` on invalid transitions."""
if from_states != '*':
Expand Down Expand Up @@ -115,7 +138,7 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:

raise EventError(
evt_label,
'Event produced invalid state transition from ' f'{initial.LABEL} to {self._state.LABEL}',
f'Event produced invalid state transition from {initial.LABEL} to {self._state.LABEL}',
)

return result
Expand All @@ -128,35 +151,7 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:
return wrapper


@runtime_checkable
class State(Protocol):
LABEL: ClassVar[Any]
ALLOWED: ClassVar[set[Any]]
is_terminal: ClassVar[bool]

def __init__(self, *args: Any, **kwargs: Any): ...

def enter(self) -> None: ...

def exit(self) -> None: ...


@runtime_checkable
class Interruptable(Protocol):
def interrupt(self, reason: Exception) -> None: ...


@runtime_checkable
class Proceedable(Protocol):
def execute(self) -> State | None:
"""
Execute the state, performing the actions that this state is responsible for.
:returns: a state to transition to or None if finished.
"""
...


def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
def create_state(st: 'StateMachine', state_label: Hashable, *args: Any, **kwargs: Any) -> State:
if state_label not in st.get_states_map():
raise ValueError(f'{state_label} is not a valid state')

Expand Down Expand Up @@ -192,20 +187,20 @@ def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine':


class StateMachine(metaclass=StateMachineMeta):
STATES: Optional[Sequence[Type[State]]] = None
_STATES_MAP: Optional[Dict[Hashable, Type[State]]] = None
STATES: Sequence[type[State]] | None = None
_STATES_MAP: dict[Hashable, type[State]] | None = None

_transitioning = False
_transition_failing = False
_transitioning: bool = False
_transition_failing: bool = False

@classmethod
def get_states_map(cls) -> Dict[Hashable, Type[State]]:
def get_states_map(cls) -> dict[Hashable, type[State]]:
cls.__ensure_built()
assert cls._STATES_MAP is not None # required for type checking
return cls._STATES_MAP

@classmethod
def get_states(cls) -> Sequence[Type[State]]:
def get_states(cls) -> Sequence[type[State]]:
if cls.STATES is not None:
return cls.STATES

Expand All @@ -218,7 +213,7 @@ def initial_state_label(cls) -> Any:
return cls.STATES[0].LABEL

@classmethod
def get_state_class(cls, label: Any) -> Type[State]:
def get_state_class(cls, label: Any) -> type[State]:
cls.__ensure_built()
assert cls._STATES_MAP is not None
return cls._STATES_MAP[label]
Expand All @@ -244,16 +239,16 @@ def __ensure_built(cls) -> None:
cls._STATES_MAP[label] = state_cls

# should class initialise sealed = False?
cls.sealed = True # type: ignore
cls.sealed: bool = True

def __init__(self) -> None:
super().__init__()
self.__ensure_built()
self._state: Optional[State] = None
self._state: State | None = None
self._exception_handler = None # Note this appears to never be used
self.set_debug((not sys.flags.ignore_environment and bool(os.environ.get('PYTHONSMDEBUG'))))
self._transitioning = False
self._event_callbacks: Dict[Hashable, List[EVENT_CALLBACK_TYPE]] = {}
self._event_callbacks: dict[Hashable, list[EVENT_CALLBACK_TYPE]] = {}

@super_check
def init(self) -> None:
Expand Down Expand Up @@ -298,7 +293,7 @@ def remove_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_T
except (KeyError, ValueError):
raise ValueError(f"Callback not set for hook '{hook}'")

def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None:
def _fire_state_event(self, hook: Hashable, state: State | None) -> None:
for callback in self._event_callbacks.get(hook, []):
callback(self, hook, state)

Expand Down

0 comments on commit 2ce7971

Please sign in to comment.