Skip to content

Commit

Permalink
Just the interface duck typing
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 2, 2024
1 parent af11b1f commit 3e8fd20
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 49 deletions.
1 change: 1 addition & 0 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:
@runtime_checkable
class State(Protocol):
LABEL: ClassVar[LABEL_TYPE]
ALLOWED: ClassVar[set[str]]
is_terminal: ClassVar[bool]

def enter(self) -> None: ...
Expand Down
78 changes: 29 additions & 49 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ class ProcessState(Enum):


@final
@auto_persist('args', 'kwargs', 'in_state')
class Created(state_machine.State, persistence.Savable):
@auto_persist('args', 'kwargs')
class Created(persistence.Savable):
LABEL = ProcessState.CREATED
ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED}

Expand All @@ -161,7 +161,6 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, *
self.run_fn = run_fn
self.args = args
self.kwargs = kwargs
self.in_state = True

def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None:
super().save_instance_state(out_state, save_context)
Expand All @@ -177,18 +176,15 @@ async def execute(self) -> state_machine.State:
return self.process.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs)

def enter(self) -> None:
self.in_state = True
...

def exit(self) -> None:
if self.is_terminal:
raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')

self.in_state = False
...


@final
@auto_persist('args', 'kwargs', 'in_state')
class Running(state_machine.State, persistence.Savable):
@auto_persist('args', 'kwargs')
class Running(persistence.Savable):
LABEL = ProcessState.RUNNING
ALLOWED = {
ProcessState.RUNNING,
Expand All @@ -215,7 +211,6 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, *
self.args = args
self.kwargs = kwargs
self._run_handle = None
self.in_state = False

def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None:
super().save_instance_state(out_state, save_context)
Expand Down Expand Up @@ -280,18 +275,15 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_m
return cast(state_machine.State, state) # casting from base.State to process.State

def enter(self) -> None:
self.in_state = True
...

def exit(self) -> None:
if self.is_terminal:
raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')

self.in_state = False
...


@final
@auto_persist('msg', 'data', 'in_state')
class Waiting(state_machine.State, persistence.Savable):
@auto_persist('msg', 'data')
class Waiting(persistence.Savable):
LABEL = ProcessState.WAITING
ALLOWED = {
ProcessState.RUNNING,
Expand Down Expand Up @@ -325,7 +317,6 @@ def __init__(
self.msg = msg
self.data = data
self._waiting_future: futures.Future = futures.Future()
self.in_state = False

def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None:
super().save_instance_state(out_state, save_context)
Expand Down Expand Up @@ -373,25 +364,21 @@ def resume(self, value: Any = NULL) -> None:
self._waiting_future.set_result(value)

def enter(self) -> None:
self.in_state = True
...

def exit(self) -> None:
if self.is_terminal:
raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')

self.in_state = False
...


@auto_persist('in_state')
class Excepted(state_machine.State, persistence.Savable):
@final
class Excepted(persistence.Savable):
"""
Excepted state, can optionally provide exception and trace_back
:param exception: The exception instance
:param trace_back: An optional exception traceback
"""

LABEL = ProcessState.EXCEPTED
ALLOWED: set[str] = set()

EXC_VALUE = 'ex_value'
TRACEBACK = 'traceback'
Expand All @@ -412,7 +399,6 @@ def __init__(
self.process = process
self.exception = exception
self.traceback = trace_back
self.in_state = False

def __str__(self) -> str:
exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0]
Expand Down Expand Up @@ -450,49 +436,44 @@ def get_exc_info(
)

def enter(self) -> None:
self.in_state = True
...

def exit(self) -> None:
if self.is_terminal:
raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')

self.in_state = False
...


@auto_persist('result', 'successful', 'in_state')
class Finished(state_machine.State, persistence.Savable):
@final
@auto_persist('result', 'successful')
class Finished(persistence.Savable):
"""State for process is finished.
:param result: The result of process
:param successful: Boolean for the exit code is ``0`` the process is successful.
"""

LABEL = ProcessState.FINISHED
ALLOWED: set[str] = set()

is_terminal = True

def __init__(self, process: 'Process', result: Any, successful: bool) -> None:
self.process = process
self.result = result
self.successful = successful
self.in_state = False

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.process = load_context.process

def enter(self) -> None:
self.in_state = True
...

def exit(self) -> None:
if self.is_terminal:
raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')
...

self.in_state = False


@auto_persist('msg', 'in_state')
class Killed(state_machine.State, persistence.Savable):
@final
@auto_persist('msg')
class Killed(persistence.Savable):
"""
Represents a state where a process has been killed.
Expand All @@ -503,6 +484,7 @@ class Killed(state_machine.State, persistence.Savable):
"""

LABEL = ProcessState.KILLED
ALLOWED: set[str] = set()

is_terminal = True

Expand All @@ -519,13 +501,11 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi
self.process = load_context.process

def enter(self) -> None:
self.in_state = True
...

def exit(self) -> None:
if self.is_terminal:
raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')
...

self.in_state = False


# endregion

0 comments on commit 3e8fd20

Please sign in to comment.