Skip to content

Commit

Permalink
Move is_terminal as class attribute required
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 2, 2024
1 parent 26b50fd commit 605b8c3
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 8 deletions.
8 changes: 2 additions & 6 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,6 @@ class State:
# from this one
ALLOWED: Set[LABEL_TYPE] = set()

@classmethod
def is_terminal(cls) -> bool:
return not cls.ALLOWED

def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any):
"""
:param state_machine: The process this state belongs to
Expand Down Expand Up @@ -165,7 +161,7 @@ def execute(self) -> Optional['State']:
@super_check
def exit(self) -> None:
"""Exiting the state"""
if self.is_terminal():
if self.is_terminal:
raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State':
Expand Down Expand Up @@ -352,7 +348,7 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) ->
self._exit_current_state(new_state)
self._enter_next_state(new_state)

if self._state is not None and self._state.is_terminal():
if self._state is not None and self._state.is_terminal:
call_with_super_check(self.on_terminated)
except Exception:
self._transitioning = False
Expand Down
11 changes: 11 additions & 0 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ class Created(state_machine.State, persistence.Savable):
ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED}

RUN_FN = 'run_fn'
is_terminal = False

def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
super().__init__(process)
Expand Down Expand Up @@ -200,6 +201,8 @@ class Running(state_machine.State, persistence.Savable):
_running: bool = False
_run_handle = None

is_terminal = False

def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
super().__init__(process)
assert run_fn is not None
Expand Down Expand Up @@ -293,6 +296,8 @@ class Waiting(state_machine.State, persistence.Savable):

_interruption = None

is_terminal = False

def __str__(self) -> str:
state_info = super().__str__()
if self.msg is not None:
Expand Down Expand Up @@ -379,6 +384,8 @@ class Excepted(state_machine.State, persistence.Savable):
EXC_VALUE = 'ex_value'
TRACEBACK = 'traceback'

is_terminal = True

def __init__(
self,
process: 'Process',
Expand Down Expand Up @@ -447,6 +454,8 @@ class Finished(state_machine.State, persistence.Savable):

LABEL = ProcessState.FINISHED

is_terminal = True

def __init__(self, process: 'Process', result: Any, successful: bool) -> None:
super().__init__(process)
self.result = result
Expand Down Expand Up @@ -477,6 +486,8 @@ class Killed(state_machine.State, persistence.Savable):

LABEL = ProcessState.KILLED

is_terminal = True

def __init__(self, process: 'Process', msg: Optional[MessageType]):
"""
:param process: The associated process
Expand Down
4 changes: 2 additions & 2 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def launch(

def has_terminated(self) -> bool:
"""Return whether the process was terminated."""
return self._state.is_terminal()
return self._state.is_terminal

def result(self) -> Any:
"""
Expand Down Expand Up @@ -540,7 +540,7 @@ def done(self) -> bool:
Use the `has_terminated` method instead
"""
warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning)
return self._state.is_terminal()
return self._state.is_terminal

# endregion

Expand Down
6 changes: 6 additions & 0 deletions tests/base/test_statemachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class Playing(state_machine.State):
ALLOWED = {PAUSED, STOPPED}
TRANSITIONS = {STOP: STOPPED}

is_terminal = False

def __init__(self, player, track):
assert track is not None, 'Must provide a track name'
super().__init__(player)
Expand Down Expand Up @@ -54,6 +56,8 @@ class Paused(state_machine.State):
ALLOWED = {PLAYING, STOPPED}
TRANSITIONS = {STOP: STOPPED}

is_terminal = False

def __init__(self, player, playing_state):
assert isinstance(playing_state, Playing), 'Must provide the playing state to pause'
super().__init__(player)
Expand All @@ -76,6 +80,8 @@ class Stopped(state_machine.State):
}
TRANSITIONS = {PLAY: PLAYING}

is_terminal = False

def __str__(self):
return '[]'

Expand Down

0 comments on commit 605b8c3

Please sign in to comment.