From 29e3979fb0e8b1fc6299818c7b7037c5e79edd64 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Thu, 23 Jan 2025 18:46:17 +0100 Subject: [PATCH] processes.py --- src/plumpy/processes.py | 114 +++++++++++++++++++--------------------- 1 file changed, 53 insertions(+), 61 deletions(-) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 9d985344..6e3610a2 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -16,22 +16,14 @@ import time import uuid import warnings +from collections.abc import Awaitable, Generator, Sequence from types import TracebackType from typing import ( Any, - Awaitable, Callable, ClassVar, - Dict, - Generator, Hashable, - List, - Optional, - Sequence, - Tuple, - Type, TypeVar, - Union, cast, ) @@ -161,18 +153,18 @@ class Process(StateMachine, metaclass=ProcessStateMachineMeta): _spec_class = ProcessSpec # Default placeholders, will be populated in init() _stepping = False - _pausing: Optional[CancellableAction] = None - _paused: Optional[persistence.SavableFuture] = None - _killing: Optional[CancellableAction] = None - _interrupt_action: Optional[CancellableAction] = None + _pausing: CancellableAction | None = None + _paused: persistence.SavableFuture | None = None + _killing: CancellableAction | None = None + _interrupt_action: CancellableAction | None = None _closed = False - _cleanups: Optional[List[Callable[[], None]]] = None + _cleanups: list[Callable[[], None]] | None = None __called: bool = False _auto_persist: ClassVar[set[str]] @classmethod - def current(cls) -> Optional['Process']: + def current(cls) -> 'Process | None': """ Get the currently running process i.e. the one at the top of the stack @@ -185,7 +177,7 @@ def current(cls) -> Optional['Process']: return None @classmethod - def get_states(cls) -> Sequence[Type[state_machine.State]]: + def get_states(cls) -> Sequence[type[state_machine.State]]: """Return all allowed states of the process.""" state_classes = cls.get_state_classes() return ( @@ -194,7 +186,7 @@ def get_states(cls) -> Sequence[Type[state_machine.State]]: ) @classmethod - def get_state_classes(cls) -> dict[process_states.ProcessState, Type[state_machine.State]]: + def get_state_classes(cls) -> dict[process_states.ProcessState, type[state_machine.State]]: # A mapping of the State constants to the corresponding state class return { process_states.ProcessState.CREATED: process_states.Created, @@ -238,14 +230,14 @@ def define(cls, _spec: ProcessSpec) -> None: cls.__called = True @classmethod - def get_description(cls) -> Dict[str, Any]: + def get_description(cls) -> dict[str, Any]: """ Get a human readable description of what this :class:`Process` does. :return: The description. """ - description: Dict[str, Any] = {} + description: dict[str, Any] = {} if cls.__doc__: description['description'] = cls.__doc__.strip() @@ -260,7 +252,7 @@ def get_description(cls) -> Dict[str, Any]: def recreate_from( cls, saved_state: SAVED_STATE_TYPE, - load_context: Optional[persistence.LoadSaveContext] = None, + load_context: persistence.LoadSaveContext | None = None, ) -> Self: """Recreate a process from a saved state, passing any positional @@ -329,11 +321,11 @@ def recreate_from( def __init__( self, - inputs: Optional[dict] = None, - pid: Optional[PID_TYPE] = None, - logger: Optional[logging.Logger] = None, - loop: Optional[asyncio.AbstractEventLoop] = None, - coordinator: Optional[Coordinator] = None, + inputs: dict | None = None, + pid: PID_TYPE | None = None, + logger: logging.Logger | None = None, + loop: asyncio.AbstractEventLoop | None = None, + coordinator: Coordinator | None = None, ) -> None: """ The signature of the constructor should not be changed by subclassing processes. @@ -354,8 +346,8 @@ def __init__( self._setup_event_hooks() - self._status: Optional[str] = None # May hold a current status message - self._pre_paused_status: Optional[str] = ( + self._status: str | None = None # May hold a current status message + self._pre_paused_status: str | None = ( None # Save status when a pause message replaces it, such that it can be restored ) self._paused = None @@ -363,10 +355,10 @@ def __init__( # Input/output self._raw_inputs = None if inputs is None else utils.AttributesFrozendict(inputs) self._pid = pid - self._parsed_inputs: Optional[utils.AttributesFrozendict] = None - self._outputs: Dict[str, Any] = {} - self._uuid: Optional[uuid.UUID] = None - self._creation_time: Optional[float] = None + self._parsed_inputs: utils.AttributesFrozendict | None = None + self._outputs: dict[str, Any] = {} + self._uuid: uuid.UUID | None = None + self._creation_time: float | None = None # Runtime variables self._future = persistence.SavableFuture(loop=self._loop) @@ -421,7 +413,7 @@ def _setup_event_hooks(self) -> None: cast(state_machine.State, state) ), state_machine.StateEventHook.ENTERED_STATE: lambda _s, _h, from_state: self.on_entered( - cast(Optional[state_machine.State], from_state) + cast(state_machine.State | None, from_state) ), state_machine.StateEventHook.EXITING_STATE: lambda _s, _h, _state: self.on_exiting(), } @@ -429,7 +421,7 @@ def _setup_event_hooks(self) -> None: self.add_state_event_callback(hook, callback) @property - def creation_time(self) -> Optional[float]: + def creation_time(self) -> float | None: """ The creation time of this Process as returned by time.time() when instantiated :return: The creation time @@ -437,27 +429,27 @@ def creation_time(self) -> Optional[float]: return self._creation_time @property - def pid(self) -> Optional[PID_TYPE]: + def pid(self) -> PID_TYPE | None: """Return the pid of the process.""" return self._pid @property - def uuid(self) -> Optional[uuid.UUID]: + def uuid(self) -> uuid.UUID | None: """Return the UUID of the process""" return self._uuid @property - def raw_inputs(self) -> Optional[utils.AttributesFrozendict]: + def raw_inputs(self) -> utils.AttributesFrozendict | None: """The `AttributesFrozendict` of inputs (if not None).""" return self._raw_inputs @property - def inputs(self) -> Optional[utils.AttributesFrozendict]: + def inputs(self) -> utils.AttributesFrozendict | None: """Return the parsed inputs.""" return self._parsed_inputs @property - def outputs(self) -> Dict[str, Any]: + def outputs(self) -> dict[str, Any]: """ Get the current outputs emitted by the Process. These may grow over time as the process runs. @@ -482,11 +474,11 @@ def logger(self) -> logging.Logger: return _LOGGER @property - def status(self) -> Optional[str]: + def status(self) -> str | None: """Return the status massage of the process.""" return self._status - def set_status(self, status: Optional[str]) -> None: + def set_status(self, status: str | None) -> None: """Set the status message of the process.""" self._status = status @@ -505,10 +497,10 @@ def future(self) -> persistence.SavableFuture: @ensure_not_closed def launch( self, - process_class: Type['Process'], - inputs: Optional[dict] = None, - pid: Optional[PID_TYPE] = None, - logger: Optional[logging.Logger] = None, + process_class: type['Process'], + inputs: dict | None = None, + pid: PID_TYPE | None = None, + logger: logging.Logger | None = None, ) -> 'Process': """Start running the nested process. @@ -574,14 +566,14 @@ def killed(self) -> bool: """Return whether the process is killed.""" return self.state_label == process_states.ProcessState.KILLED - def killed_msg(self) -> Optional[MessageType]: + def killed_msg(self) -> MessageType | None: """Return the killed message.""" if isinstance(self.state, process_states.Killed): return self.state.msg raise exceptions.InvalidStateError('Has not been killed') - def exception(self) -> Optional[BaseException]: + def exception(self) -> BaseException | None: """Return exception, if the process is terminated in excepted state.""" if isinstance(self.state, process_states.Excepted): return self.state.exception @@ -628,8 +620,8 @@ def call_soon(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> def callback_excepted( self, _callback: Callable[..., Any], - exception: Optional[BaseException], - trace: Optional[TracebackType], + exception: BaseException | None, + trace: TracebackType | None, ) -> None: if self.state_label != process_states.ProcessState.EXCEPTED: self.fail(exception, trace) @@ -741,7 +733,7 @@ def on_entering(self, state: state_machine.State) -> None: elif state_label == process_states.ProcessState.EXCEPTED: call_with_super_check(self.on_except, state.get_exc_info()) # type: ignore - def on_entered(self, from_state: Optional[state_machine.State]) -> None: + def on_entered(self, from_state: state_machine.State | None) -> None: # Map these onto direct functions that the subclass can implement state_label = self.state_label if state_label == process_states.ProcessState.RUNNING: @@ -841,11 +833,11 @@ def on_waiting(self) -> None: self._fire_event(ProcessListener.on_process_waiting) @super_check - def on_pausing(self, msg: Optional[str] = None) -> None: + def on_pausing(self, msg: str | None = None) -> None: """The process is being paused.""" @super_check - def on_paused(self, msg: Optional[str] = None) -> None: + def on_paused(self, msg: str | None = None) -> None: """The process was paused.""" self._pausing = None @@ -890,7 +882,7 @@ def on_finished(self) -> None: self._fire_event(ProcessListener.on_process_finished, self.future().result()) @super_check - def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None: + def on_except(self, exc_info: tuple[Any, Exception, TracebackType]) -> None: """Entering the EXCEPTED state.""" exception = exc_info[1] exception.__traceback__ = exc_info[2] @@ -909,7 +901,7 @@ def on_excepted(self) -> None: self._fire_event(ProcessListener.on_process_excepted, str(self.future().exception())) @super_check - def on_kill(self, msg: Optional[MessageType]) -> None: + def on_kill(self, msg: MessageType | None) -> None: """Entering the KILLED state.""" if msg is None: msg_txt = '' @@ -979,7 +971,7 @@ def message_receive(self, _comm: Coordinator, msg: MessageType) -> Any: if intent == message.Intent.KILL: return self._schedule_rpc(self.kill, msg_text=msg.get(MESSAGE_TEXT_KEY, None)) if intent == message.Intent.STATUS: - status_info: Dict[str, Any] = {} + status_info: dict[str, Any] = {} self.get_status_info(status_info) return status_info @@ -988,7 +980,7 @@ def message_receive(self, _comm: Coordinator, msg: MessageType) -> Any: def broadcast_receive( self, _comm: Coordinator, msg: MessageType, sender: Any, subject: Any, correlation_id: Any - ) -> Optional[concurrent.futures.Future]: + ) -> concurrent.futures.Future | None: """ Coroutine called when the process receives a message from the communicator @@ -1115,7 +1107,7 @@ def transition_failed( new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace) self.transition_to(new_state) - def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]: + def pause(self, msg_text: str | None = None) -> bool | CancellableAction: """Pause the process. :param msg: an optional message to set as the status. The current status will be saved in the private @@ -1158,7 +1150,7 @@ def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]: def _interrupt(state: Interruptable, reason: Exception) -> None: state.interrupt(reason) - def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[state_machine.State] = None) -> bool: + def _do_pause(self, state_msg: MessageType | None, next_state: state_machine.State | None = None) -> bool: """Carry out the pause procedure, optionally transitioning to the next state first""" try: if next_state is not None: @@ -1204,7 +1196,7 @@ def do_kill(_next_state: state_machine.State) -> Any: raise ValueError(f"Got unknown interruption type '{type(exception)}'") - def _set_interrupt_action(self, new_action: Optional[CancellableAction]) -> None: + def _set_interrupt_action(self, new_action: CancellableAction | None) -> None: """ Set the interrupt action cancelling the current one if it exists :param new_action: The new interrupt action to set @@ -1241,7 +1233,7 @@ def resume(self, *args: Any) -> None: return self.state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) - def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None: + def fail(self, exception: BaseException | None, traceback: TracebackType | None) -> None: """ Fail the process in response to an exception :param exception: The exception that caused the failure @@ -1251,7 +1243,7 @@ def fail(self, exception: Optional[BaseException], traceback: Optional[Traceback new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback) self.transition_to(new_state) - def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: + def kill(self, msg_text: str | None = None) -> bool | asyncio.Future: """ Kill the process :param msg: An optional kill message @@ -1318,7 +1310,7 @@ async def run(self) -> Any: """ @ensure_not_closed - def execute(self) -> Optional[Dict[str, Any]]: + def execute(self) -> dict[str, Any] | None: """ Execute the process. This will return if the process terminates or is paused.