Skip to content

Commit

Permalink
processes.py
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Jan 23, 2025
1 parent 5f44e45 commit 29e3979
Showing 1 changed file with 53 additions and 61 deletions.
114 changes: 53 additions & 61 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,14 @@
import time
import uuid
import warnings
from collections.abc import Awaitable, Generator, Sequence
from types import TracebackType
from typing import (
Any,
Awaitable,
Callable,
ClassVar,
Dict,
Generator,
Hashable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
)

Expand Down Expand Up @@ -161,18 +153,18 @@ class Process(StateMachine, metaclass=ProcessStateMachineMeta):
_spec_class = ProcessSpec
# Default placeholders, will be populated in init()
_stepping = False
_pausing: Optional[CancellableAction] = None
_paused: Optional[persistence.SavableFuture] = None
_killing: Optional[CancellableAction] = None
_interrupt_action: Optional[CancellableAction] = None
_pausing: CancellableAction | None = None
_paused: persistence.SavableFuture | None = None
_killing: CancellableAction | None = None
_interrupt_action: CancellableAction | None = None
_closed = False
_cleanups: Optional[List[Callable[[], None]]] = None
_cleanups: list[Callable[[], None]] | None = None

__called: bool = False
_auto_persist: ClassVar[set[str]]

@classmethod
def current(cls) -> Optional['Process']:
def current(cls) -> 'Process | None':
"""
Get the currently running process i.e. the one at the top of the stack
Expand All @@ -185,7 +177,7 @@ def current(cls) -> Optional['Process']:
return None

@classmethod
def get_states(cls) -> Sequence[Type[state_machine.State]]:
def get_states(cls) -> Sequence[type[state_machine.State]]:
"""Return all allowed states of the process."""
state_classes = cls.get_state_classes()
return (
Expand All @@ -194,7 +186,7 @@ def get_states(cls) -> Sequence[Type[state_machine.State]]:
)

@classmethod
def get_state_classes(cls) -> dict[process_states.ProcessState, Type[state_machine.State]]:
def get_state_classes(cls) -> dict[process_states.ProcessState, type[state_machine.State]]:
# A mapping of the State constants to the corresponding state class
return {
process_states.ProcessState.CREATED: process_states.Created,
Expand Down Expand Up @@ -238,14 +230,14 @@ def define(cls, _spec: ProcessSpec) -> None:
cls.__called = True

@classmethod
def get_description(cls) -> Dict[str, Any]:
def get_description(cls) -> dict[str, Any]:
"""
Get a human readable description of what this :class:`Process` does.
:return: The description.
"""
description: Dict[str, Any] = {}
description: dict[str, Any] = {}

if cls.__doc__:
description['description'] = cls.__doc__.strip()
Expand All @@ -260,7 +252,7 @@ def get_description(cls) -> Dict[str, Any]:
def recreate_from(
cls,
saved_state: SAVED_STATE_TYPE,
load_context: Optional[persistence.LoadSaveContext] = None,
load_context: persistence.LoadSaveContext | None = None,
) -> Self:
"""Recreate a process from a saved state, passing any positional
Expand Down Expand Up @@ -329,11 +321,11 @@ def recreate_from(

def __init__(
self,
inputs: Optional[dict] = None,
pid: Optional[PID_TYPE] = None,
logger: Optional[logging.Logger] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
coordinator: Optional[Coordinator] = None,
inputs: dict | None = None,
pid: PID_TYPE | None = None,
logger: logging.Logger | None = None,
loop: asyncio.AbstractEventLoop | None = None,
coordinator: Coordinator | None = None,
) -> None:
"""
The signature of the constructor should not be changed by subclassing processes.
Expand All @@ -354,19 +346,19 @@ def __init__(

self._setup_event_hooks()

self._status: Optional[str] = None # May hold a current status message
self._pre_paused_status: Optional[str] = (
self._status: str | None = None # May hold a current status message
self._pre_paused_status: str | None = (
None # Save status when a pause message replaces it, such that it can be restored
)
self._paused = None

# Input/output
self._raw_inputs = None if inputs is None else utils.AttributesFrozendict(inputs)
self._pid = pid
self._parsed_inputs: Optional[utils.AttributesFrozendict] = None
self._outputs: Dict[str, Any] = {}
self._uuid: Optional[uuid.UUID] = None
self._creation_time: Optional[float] = None
self._parsed_inputs: utils.AttributesFrozendict | None = None
self._outputs: dict[str, Any] = {}
self._uuid: uuid.UUID | None = None
self._creation_time: float | None = None

# Runtime variables
self._future = persistence.SavableFuture(loop=self._loop)
Expand Down Expand Up @@ -421,43 +413,43 @@ def _setup_event_hooks(self) -> None:
cast(state_machine.State, state)
),
state_machine.StateEventHook.ENTERED_STATE: lambda _s, _h, from_state: self.on_entered(
cast(Optional[state_machine.State], from_state)
cast(state_machine.State | None, from_state)
),
state_machine.StateEventHook.EXITING_STATE: lambda _s, _h, _state: self.on_exiting(),
}
for hook, callback in event_hooks.items():
self.add_state_event_callback(hook, callback)

@property
def creation_time(self) -> Optional[float]:
def creation_time(self) -> float | None:
"""
The creation time of this Process as returned by time.time() when instantiated
:return: The creation time
"""
return self._creation_time

@property
def pid(self) -> Optional[PID_TYPE]:
def pid(self) -> PID_TYPE | None:
"""Return the pid of the process."""
return self._pid

@property
def uuid(self) -> Optional[uuid.UUID]:
def uuid(self) -> uuid.UUID | None:
"""Return the UUID of the process"""
return self._uuid

@property
def raw_inputs(self) -> Optional[utils.AttributesFrozendict]:
def raw_inputs(self) -> utils.AttributesFrozendict | None:
"""The `AttributesFrozendict` of inputs (if not None)."""
return self._raw_inputs

@property
def inputs(self) -> Optional[utils.AttributesFrozendict]:
def inputs(self) -> utils.AttributesFrozendict | None:
"""Return the parsed inputs."""
return self._parsed_inputs

@property
def outputs(self) -> Dict[str, Any]:
def outputs(self) -> dict[str, Any]:
"""
Get the current outputs emitted by the Process. These may grow over
time as the process runs.
Expand All @@ -482,11 +474,11 @@ def logger(self) -> logging.Logger:
return _LOGGER

@property
def status(self) -> Optional[str]:
def status(self) -> str | None:
"""Return the status massage of the process."""
return self._status

def set_status(self, status: Optional[str]) -> None:
def set_status(self, status: str | None) -> None:
"""Set the status message of the process."""
self._status = status

Expand All @@ -505,10 +497,10 @@ def future(self) -> persistence.SavableFuture:
@ensure_not_closed
def launch(
self,
process_class: Type['Process'],
inputs: Optional[dict] = None,
pid: Optional[PID_TYPE] = None,
logger: Optional[logging.Logger] = None,
process_class: type['Process'],
inputs: dict | None = None,
pid: PID_TYPE | None = None,
logger: logging.Logger | None = None,
) -> 'Process':
"""Start running the nested process.
Expand Down Expand Up @@ -574,14 +566,14 @@ def killed(self) -> bool:
"""Return whether the process is killed."""
return self.state_label == process_states.ProcessState.KILLED

def killed_msg(self) -> Optional[MessageType]:
def killed_msg(self) -> MessageType | None:
"""Return the killed message."""
if isinstance(self.state, process_states.Killed):
return self.state.msg

raise exceptions.InvalidStateError('Has not been killed')

def exception(self) -> Optional[BaseException]:
def exception(self) -> BaseException | None:
"""Return exception, if the process is terminated in excepted state."""
if isinstance(self.state, process_states.Excepted):
return self.state.exception
Expand Down Expand Up @@ -628,8 +620,8 @@ def call_soon(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) ->
def callback_excepted(
self,
_callback: Callable[..., Any],
exception: Optional[BaseException],
trace: Optional[TracebackType],
exception: BaseException | None,
trace: TracebackType | None,
) -> None:
if self.state_label != process_states.ProcessState.EXCEPTED:
self.fail(exception, trace)
Expand Down Expand Up @@ -741,7 +733,7 @@ def on_entering(self, state: state_machine.State) -> None:
elif state_label == process_states.ProcessState.EXCEPTED:
call_with_super_check(self.on_except, state.get_exc_info()) # type: ignore

def on_entered(self, from_state: Optional[state_machine.State]) -> None:
def on_entered(self, from_state: state_machine.State | None) -> None:
# Map these onto direct functions that the subclass can implement
state_label = self.state_label
if state_label == process_states.ProcessState.RUNNING:
Expand Down Expand Up @@ -841,11 +833,11 @@ def on_waiting(self) -> None:
self._fire_event(ProcessListener.on_process_waiting)

@super_check
def on_pausing(self, msg: Optional[str] = None) -> None:
def on_pausing(self, msg: str | None = None) -> None:
"""The process is being paused."""

@super_check
def on_paused(self, msg: Optional[str] = None) -> None:
def on_paused(self, msg: str | None = None) -> None:
"""The process was paused."""
self._pausing = None

Expand Down Expand Up @@ -890,7 +882,7 @@ def on_finished(self) -> None:
self._fire_event(ProcessListener.on_process_finished, self.future().result())

@super_check
def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None:
def on_except(self, exc_info: tuple[Any, Exception, TracebackType]) -> None:
"""Entering the EXCEPTED state."""
exception = exc_info[1]
exception.__traceback__ = exc_info[2]
Expand All @@ -909,7 +901,7 @@ def on_excepted(self) -> None:
self._fire_event(ProcessListener.on_process_excepted, str(self.future().exception()))

@super_check
def on_kill(self, msg: Optional[MessageType]) -> None:
def on_kill(self, msg: MessageType | None) -> None:
"""Entering the KILLED state."""
if msg is None:
msg_txt = ''
Expand Down Expand Up @@ -979,7 +971,7 @@ def message_receive(self, _comm: Coordinator, msg: MessageType) -> Any:
if intent == message.Intent.KILL:
return self._schedule_rpc(self.kill, msg_text=msg.get(MESSAGE_TEXT_KEY, None))
if intent == message.Intent.STATUS:
status_info: Dict[str, Any] = {}
status_info: dict[str, Any] = {}
self.get_status_info(status_info)
return status_info

Expand All @@ -988,7 +980,7 @@ def message_receive(self, _comm: Coordinator, msg: MessageType) -> Any:

def broadcast_receive(
self, _comm: Coordinator, msg: MessageType, sender: Any, subject: Any, correlation_id: Any
) -> Optional[concurrent.futures.Future]:
) -> concurrent.futures.Future | None:
"""
Coroutine called when the process receives a message from the communicator
Expand Down Expand Up @@ -1115,7 +1107,7 @@ def transition_failed(
new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace)
self.transition_to(new_state)

def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]:
def pause(self, msg_text: str | None = None) -> bool | CancellableAction:
"""Pause the process.
:param msg: an optional message to set as the status. The current status will be saved in the private
Expand Down Expand Up @@ -1158,7 +1150,7 @@ def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]:
def _interrupt(state: Interruptable, reason: Exception) -> None:
state.interrupt(reason)

def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[state_machine.State] = None) -> bool:
def _do_pause(self, state_msg: MessageType | None, next_state: state_machine.State | None = None) -> bool:
"""Carry out the pause procedure, optionally transitioning to the next state first"""
try:
if next_state is not None:
Expand Down Expand Up @@ -1204,7 +1196,7 @@ def do_kill(_next_state: state_machine.State) -> Any:

raise ValueError(f"Got unknown interruption type '{type(exception)}'")

def _set_interrupt_action(self, new_action: Optional[CancellableAction]) -> None:
def _set_interrupt_action(self, new_action: CancellableAction | None) -> None:
"""
Set the interrupt action cancelling the current one if it exists
:param new_action: The new interrupt action to set
Expand Down Expand Up @@ -1241,7 +1233,7 @@ def resume(self, *args: Any) -> None:
return self.state.resume(*args) # type: ignore

@event(to_states=process_states.Excepted)
def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
def fail(self, exception: BaseException | None, traceback: TracebackType | None) -> None:
"""
Fail the process in response to an exception
:param exception: The exception that caused the failure
Expand All @@ -1251,7 +1243,7 @@ def fail(self, exception: Optional[BaseException], traceback: Optional[Traceback
new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback)
self.transition_to(new_state)

def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]:
def kill(self, msg_text: str | None = None) -> bool | asyncio.Future:
"""
Kill the process
:param msg: An optional kill message
Expand Down Expand Up @@ -1318,7 +1310,7 @@ async def run(self) -> Any:
"""

@ensure_not_closed
def execute(self) -> Optional[Dict[str, Any]]:
def execute(self) -> dict[str, Any] | None:
"""
Execute the process. This will return if the process terminates or is paused.
Expand Down

0 comments on commit 29e3979

Please sign in to comment.