Skip to content

Commit

Permalink
>3.9 typing
Browse files Browse the repository at this point in the history
- state_machine.py
- message.py
- persistence.py
- port.py
  • Loading branch information
unkcpz committed Jan 23, 2025
1 parent cb9d872 commit adb908f
Show file tree
Hide file tree
Showing 12 changed files with 163 additions and 167 deletions.
103 changes: 49 additions & 54 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,21 @@
# -*- coding: utf-8 -*-
"""The state machine for processes"""

from __future__ import annotations

import enum
import functools
import inspect
import logging
import os
import sys
from collections.abc import Iterable, Sequence
from types import TracebackType
from typing import (
Any,
Callable,
ClassVar,
Dict,
Hashable,
Iterable,
List,
Optional,
Protocol,
Sequence,
Type,
Union,
final,
runtime_checkable,
)

Expand All @@ -34,19 +27,48 @@

_LOGGER = logging.getLogger(__name__)

EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None]
EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, 'State | None'], None]


@runtime_checkable
class State(Protocol):
LABEL: ClassVar[Any]
ALLOWED: ClassVar[set[Any]]
is_terminal: ClassVar[bool]

def __init__(self, *args: Any, **kwargs: Any): ...

def enter(self) -> None: ...

def exit(self) -> None: ...


@runtime_checkable
class Interruptable(Protocol):
def interrupt(self, reason: Exception) -> None: ...


@runtime_checkable
class Proceedable(Protocol):
def execute(self) -> State | None:
"""
Execute the state, performing the actions that this state is responsible for.
:returns: a state to transition to or None if finished.
"""
...


class StateMachineError(Exception):
"""Base class for state machine errors"""


@final
class StateEntryFailed(Exception): # noqa: N818
"""
Failed to enter a state, can provide the next state to go to via this exception
"""

def __init__(self, state: State, *args: Any, **kwargs: Any) -> None:
def __init__(self, state: 'State', *args: Any, **kwargs: Any) -> None:
super().__init__('failed to enter state')
self.state = state
self.args = args
Expand All @@ -63,11 +85,12 @@ def __init__(self, evt: str, msg: str):
self.event = evt


@final
class TransitionFailed(Exception): # noqa: N818
"""A state transition failed"""

def __init__(
self, initial_state: 'State', final_state: Optional['State'] = None, traceback_str: Optional[str] = None
self, initial_state: 'State', final_state: 'State | None' = None, traceback_str: str | None = None
) -> None:
self.initial_state = initial_state
self.final_state = final_state
Expand All @@ -82,8 +105,8 @@ def _format_msg(self) -> str:


def event(
from_states: Union[str, Type['State'], Iterable[Type['State']]] = '*',
to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*',
from_states: str | type['State'] | Iterable[type['State']] = '*',
to_states: str | type['State'] | Iterable[type['State']] = '*',
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""A decorator to check for correct transitions, raising ``EventError`` on invalid transitions."""
if from_states != '*':
Expand Down Expand Up @@ -115,7 +138,7 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:

raise EventError(
evt_label,
'Event produced invalid state transition from ' f'{initial.LABEL} to {self._state.LABEL}',
f'Event produced invalid state transition from {initial.LABEL} to {self._state.LABEL}',
)

return result
Expand All @@ -128,35 +151,7 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:
return wrapper


@runtime_checkable
class State(Protocol):
LABEL: ClassVar[Any]
ALLOWED: ClassVar[set[Any]]
is_terminal: ClassVar[bool]

def __init__(self, *args: Any, **kwargs: Any): ...

def enter(self) -> None: ...

def exit(self) -> None: ...


@runtime_checkable
class Interruptable(Protocol):
def interrupt(self, reason: Exception) -> None: ...


@runtime_checkable
class Proceedable(Protocol):
def execute(self) -> State | None:
"""
Execute the state, performing the actions that this state is responsible for.
:returns: a state to transition to or None if finished.
"""
...


def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
def create_state(st: 'StateMachine', state_label: Hashable, *args: Any, **kwargs: Any) -> State:
if state_label not in st.get_states_map():
raise ValueError(f'{state_label} is not a valid state')

Expand Down Expand Up @@ -192,20 +187,20 @@ def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine':


class StateMachine(metaclass=StateMachineMeta):
STATES: Optional[Sequence[Type[State]]] = None
_STATES_MAP: Optional[Dict[Hashable, Type[State]]] = None
STATES: Sequence[type[State]] | None = None
_STATES_MAP: dict[Hashable, type[State]] | None = None

_transitioning = False
_transition_failing = False
_transitioning: bool = False
_transition_failing: bool = False

@classmethod
def get_states_map(cls) -> Dict[Hashable, Type[State]]:
def get_states_map(cls) -> dict[Hashable, type[State]]:
cls.__ensure_built()
assert cls._STATES_MAP is not None # required for type checking
return cls._STATES_MAP

@classmethod
def get_states(cls) -> Sequence[Type[State]]:
def get_states(cls) -> Sequence[type[State]]:
if cls.STATES is not None:
return cls.STATES

Expand All @@ -218,7 +213,7 @@ def initial_state_label(cls) -> Any:
return cls.STATES[0].LABEL

@classmethod
def get_state_class(cls, label: Any) -> Type[State]:
def get_state_class(cls, label: Any) -> type[State]:
cls.__ensure_built()
assert cls._STATES_MAP is not None
return cls._STATES_MAP[label]
Expand Down Expand Up @@ -249,11 +244,11 @@ def __ensure_built(cls) -> None:
def __init__(self) -> None:
super().__init__()
self.__ensure_built()
self._state: Optional[State] = None
self._state: State | None = None
self._exception_handler = None # Note this appears to never be used
self.set_debug((not sys.flags.ignore_environment and bool(os.environ.get('PYTHONSMDEBUG'))))
self._transitioning = False
self._event_callbacks: Dict[Hashable, List[EVENT_CALLBACK_TYPE]] = {}
self._event_callbacks: dict[Hashable, list[EVENT_CALLBACK_TYPE]] = {}

@super_check
def init(self) -> None:
Expand Down Expand Up @@ -298,7 +293,7 @@ def remove_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_T
except (KeyError, ValueError):
raise ValueError(f"Callback not set for hook '{hook}'")

def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None:
def _fire_state_event(self, hook: Hashable, state: State | None) -> None:
for callback in self._event_callbacks.get(hook, []):
callback(self, hook, state)

Expand Down
7 changes: 4 additions & 3 deletions src/plumpy/coordinator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Hashable, Pattern, Protocol
import re
from typing import TYPE_CHECKING, Any, Callable, Hashable, Protocol

if TYPE_CHECKING:
# identifiers for subscribers
Expand All @@ -23,8 +24,8 @@ def add_rpc_subscriber(self, subscriber: 'RpcSubscriber', identifier: 'ID_TYPE |
def add_broadcast_subscriber(
self,
subscriber: 'BroadcastSubscriber',
subject_filters: list[Hashable | Pattern[str]] | None = None,
sender_filters: list[Hashable | Pattern[str]] | None = None,
subject_filters: list[Hashable | re.Pattern[str]] | None = None,
sender_filters: list[Hashable | re.Pattern[str]] | None = None,
identifier: 'ID_TYPE | None' = None,
) -> Any: ...

Expand Down
4 changes: 2 additions & 2 deletions src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Callable, Optional, final
from typing import TYPE_CHECKING, Any, Callable, final

from typing_extensions import Self

Expand Down Expand Up @@ -38,7 +38,7 @@ def remove_all_listeners(self) -> None:
self._listeners.clear()

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Self:
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> Self:
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
Expand Down
6 changes: 3 additions & 3 deletions src/plumpy/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import asyncio
import sys
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence
from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence

if TYPE_CHECKING:
from .processes import Process
Expand All @@ -22,7 +22,7 @@ def new_event_loop(*args: Any, **kwargs: Any) -> asyncio.AbstractEventLoop:
class PlumpyEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
"""Custom event policy that always returns the same event loop that is made reentrant by ``nest_asyncio``."""

_loop: Optional[asyncio.AbstractEventLoop] = None
_loop: asyncio.AbstractEventLoop | None = None

def get_event_loop(self) -> asyncio.AbstractEventLoop:
"""Return the patched event loop."""
Expand Down Expand Up @@ -55,7 +55,7 @@ def reset_event_loop_policy() -> None:
asyncio.set_event_loop_policy(None)


def run_until_complete(future: asyncio.Future, loop: Optional[asyncio.AbstractEventLoop] = None) -> Any:
def run_until_complete(future: asyncio.Future, loop: asyncio.AbstractEventLoop | None = None) -> Any:
loop = loop or get_event_loop()
return loop.run_until_complete(future)

Expand Down
7 changes: 5 additions & 2 deletions src/plumpy/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
from typing import Optional


from typing import final


class KilledError(Exception):
Expand All @@ -12,10 +14,11 @@ class InvalidStateError(Exception):
"""


@final
class UnsuccessfulResult:
"""The result of the process was unsuccessful"""

def __init__(self, result: Optional[int] = None):
def __init__(self, result: int | None = None):
"""Initialise.
:param result: the exit code of the process
Expand Down
6 changes: 4 additions & 2 deletions src/plumpy/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

import asyncio
import contextlib
from typing import Any, Awaitable, Callable, Generator, Optional
from collections.abc import Awaitable, Generator
from typing import Any, Callable, final


class InvalidFutureError(Exception):
Expand All @@ -33,6 +34,7 @@ def capture_exceptions(future, ignore: tuple[type[BaseException], ...] = ()) ->
future.set_exception(exception)


@final
class CancellableAction(Future):
"""
An action that can be launched and potentially cancelled
Expand Down Expand Up @@ -64,7 +66,7 @@ def run(self, *args: Any, **kwargs: Any) -> None:
self._action = None # type: ignore


def create_task(coro: Callable[[], Awaitable[Any]], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future:
def create_task(coro: Callable[[], Awaitable[Any]], loop: asyncio.AbstractEventLoop | None = None) -> Future:
"""
Schedule a call to a coro in the event loop and wrap the outcome
in a future.
Expand Down
6 changes: 3 additions & 3 deletions src/plumpy/loaders.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import abc
import importlib
from typing import Any, Optional
from typing import Any


class ObjectLoader(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -62,7 +62,7 @@ def identify_object(self, obj: Any) -> str:
return identifier


OBJECT_LOADER: Optional[ObjectLoader] = None
OBJECT_LOADER: ObjectLoader | None = None


def get_object_loader() -> ObjectLoader:
Expand All @@ -78,7 +78,7 @@ def get_object_loader() -> ObjectLoader:
return OBJECT_LOADER


def set_object_loader(loader: Optional[ObjectLoader]) -> None:
def set_object_loader(loader: ObjectLoader | None) -> None:
"""
Set the plumpy global object loader
Expand Down
Loading

0 comments on commit adb908f

Please sign in to comment.