diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index da94299b..cd372e48 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -274,7 +274,13 @@ def create_initial_state(self, *args: Any, **kwargs: Any) -> State: return self.get_state_class(self.initial_state_label())(self, *args, **kwargs) @property - def state(self) -> Any: + def state(self) -> State | None: + if self._state is None: + return None + return self._state + + @property + def state_label(self) -> Any: if self._state is None: return None return self._state.LABEL @@ -326,7 +332,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None: # it can happened when transit from terminal state return None - initial_state_label = self._state.LABEL if self._state is not None else None + initial_state_label = self.state_label label = None try: self._transitioning = True diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index abc2b24b..9262f856 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -45,8 +45,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index afe82439..31bbc67c 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -20,6 +20,7 @@ List, Optional, Protocol, + TypeVar, cast, runtime_checkable, ) @@ -523,6 +524,8 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S value = value.__name__ elif isinstance(value, Savable) and not isinstance(value, type): # persist for a savable obj, call `save` method of obj. + # the rhs branch is for when value is a Savable class, it is true runtime check + # of lhs condition. SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE) value = value.save() else: @@ -532,11 +535,25 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S return out_state -def auto_load(obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: +def load_auto_persist_params( + obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None +) -> None: for member in obj._auto_persist: setattr(obj, member, _get_value(obj, saved_state, member, load_context)) +T = TypeVar('T', bound=Savable) + + +def auto_load(cls: type[T], saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None) -> T: + obj = cls.__new__(cls) + + if isinstance(obj, SavableWithAutoPersist): + load_auto_persist_params(obj, saved_state, load_context) + + return obj + + def _get_value( obj: Any, saved_state: SAVED_STATE_TYPE, name: str, load_context: LoadSaveContext | None ) -> MethodType | Savable: diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 0659d1da..abca268c 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -78,8 +78,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: @@ -151,15 +151,15 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) - obj.state_machine = load_context.process try: obj.continue_fn = utils.load_function(saved_state[obj.CONTINUE_FN]) except ValueError: - process = load_context.process - obj.continue_fn = getattr(process, saved_state[obj.CONTINUE_FN]) + if load_context is not None: + obj.continue_fn = getattr(load_context.proc, saved_state[obj.CONTINUE_FN]) + else: + raise return obj @@ -215,12 +215,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - - auto_load(obj, saved_state, load_context) - + obj = auto_load(cls, saved_state, load_context) obj.process = load_context.process - obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN]) return obj @@ -286,15 +282,12 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) - + obj = auto_load(cls, saved_state, load_context) obj.process = load_context.process - obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN]) if obj.COMMAND in saved_state: - # FIXME: typing obj._command = persistence.load(saved_state[obj.COMMAND], load_context) # type: ignore + return obj def interrupt(self, reason: Any) -> None: @@ -424,9 +417,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) - + obj = auto_load(cls, saved_state, load_context) obj.process = load_context.process callback_name = saved_state.get(obj.DONE_CALLBACK, None) @@ -530,8 +521,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -590,8 +580,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: @@ -639,8 +628,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index e8444bc5..a6d0dc13 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -21,6 +21,7 @@ Any, Awaitable, Callable, + ClassVar, Dict, Generator, Hashable, @@ -166,6 +167,7 @@ class Process(StateMachine, metaclass=ProcessStateMachineMeta): _cleanups: Optional[List[Callable[[], None]]] = None __called: bool = False + _auto_persist: ClassVar[set[str]] @classmethod def current(cls) -> Optional['Process']: @@ -285,7 +287,7 @@ def recreate_from( else: proc._loop = asyncio.get_event_loop() - proc._state: state_machine.State = proc.recreate_state(saved_state['_state']) + proc._state = proc.recreate_state(saved_state['_state']) if 'communicator' in load_context: proc._communicator = load_context.communicator @@ -294,7 +296,7 @@ def recreate_from( proc._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above - persistence.auto_load(proc, saved_state, load_context) + persistence.load_auto_persist_params(proc, saved_state, load_context) # Inputs/outputs try: @@ -519,7 +521,9 @@ def launch( def has_terminated(self) -> bool: """Return whether the process was terminated.""" - return self._state.is_terminal + if self.state is None: + raise exceptions.InvalidStateError('process is not in state None that is invalid') + return self.state.is_terminal def result(self) -> Any: """ @@ -529,12 +533,12 @@ def result(self) -> Any: If in any other state this will raise an InvalidStateError. :return: The result of the process """ - if isinstance(self._state, process_states.Finished): - return self._state.result - if isinstance(self._state, process_states.Killed): - raise exceptions.KilledError(self._state.msg) - if isinstance(self._state, process_states.Excepted): - raise (self._state.exception or Exception('process excepted')) + if isinstance(self.state, process_states.Finished): + return self.state.result + if isinstance(self.state, process_states.Killed): + raise exceptions.KilledError(self.state.msg) + if isinstance(self.state, process_states.Excepted): + raise (self.state.exception or Exception('process excepted')) raise exceptions.InvalidStateError @@ -544,7 +548,7 @@ def successful(self) -> bool: Will raise if the process is not in the FINISHED state """ try: - return self._state.successful # type: ignore + return self.state.successful # type: ignore except AttributeError as exception: raise exceptions.InvalidStateError('process is not in the finished state') from exception @@ -555,25 +559,25 @@ def is_successful(self) -> bool: :return: boolean, True if the process is in `Finished` state with `successful` attribute set to `True` """ try: - return self._state.successful # type: ignore + return self.state.successful # type: ignore except AttributeError: return False def killed(self) -> bool: """Return whether the process is killed.""" - return self.state == process_states.ProcessState.KILLED + return self.state_label == process_states.ProcessState.KILLED def killed_msg(self) -> Optional[MessageType]: """Return the killed message.""" - if isinstance(self._state, process_states.Killed): - return self._state.msg + if isinstance(self.state, process_states.Killed): + return self.state.msg raise exceptions.InvalidStateError('Has not been killed') def exception(self) -> Optional[BaseException]: """Return exception, if the process is terminated in excepted state.""" - if isinstance(self._state, process_states.Excepted): - return self._state.exception + if isinstance(self.state, process_states.Excepted): + return self.state.exception return None @@ -583,7 +587,7 @@ def is_excepted(self) -> bool: :return: boolean, True if the process is in ``EXCEPTED`` state. """ - return self.state == process_states.ProcessState.EXCEPTED + return self.state_label == process_states.ProcessState.EXCEPTED def done(self) -> bool: """Return True if the call was successfully killed or finished running. @@ -592,7 +596,7 @@ def done(self) -> bool: Use the `has_terminated` method instead """ warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning) - return self._state.is_terminal + return self.has_terminated() # endregion @@ -620,7 +624,7 @@ def callback_excepted( exception: Optional[BaseException], trace: Optional[TracebackType], ) -> None: - if self.state != process_states.ProcessState.EXCEPTED: + if self.state_label != process_states.ProcessState.EXCEPTED: self.fail(exception, trace) @contextlib.contextmanager @@ -673,8 +677,8 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA """ out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - if isinstance(self._state, persistence.Savable): - out_state['_state'] = self._state.save() + if isinstance(self.state, persistence.Savable): + out_state['_state'] = self.state.save() # Inputs/outputs if self.raw_inputs is not None: @@ -732,7 +736,7 @@ def on_entering(self, state: state_machine.State) -> None: def on_entered(self, from_state: Optional[state_machine.State]) -> None: # Map these onto direct functions that the subclass can implement - state_label = self._state.LABEL + state_label = self.state_label if state_label == process_states.ProcessState.RUNNING: call_with_super_check(self.on_running) elif state_label == process_states.ProcessState.WAITING: @@ -746,7 +750,7 @@ def on_entered(self, from_state: Optional[state_machine.State]) -> None: if self._coordinator and isinstance(self.state, enum.Enum): from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None - subject = f'state_changed.{from_label}.{self.state.value}' + subject = f'state_changed.{from_label}.{self.state_label.value}' self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) try: self._coordinator.broadcast_send(body=None, sender=self.pid, subject=subject) @@ -759,7 +763,7 @@ def on_entered(self, from_state: Optional[state_machine.State]) -> None: raise def on_exiting(self) -> None: - state = self.state + state = self.state_label if state == process_states.ProcessState.WAITING: call_with_super_check(self.on_exit_waiting) elif state == process_states.ProcessState.RUNNING: @@ -1126,9 +1130,9 @@ def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]: return self._pausing if self._stepping: - if not isinstance(self._state, Interruptable): + if not isinstance(self.state, Interruptable): raise exceptions.InvalidStateError( - f'cannot interrupt {self._state.__class__}, method `interrupt` not implement' + f'cannot interrupt {self.state.__class__}, method `interrupt` not implement' ) # Ask the step function to pause by setting this flag and giving the @@ -1227,7 +1231,7 @@ def play(self) -> bool: @event(from_states=process_states.Waiting) def resume(self, *args: Any) -> None: """Start running the process again.""" - return self._state.resume(*args) # type: ignore + return self.state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None: @@ -1246,7 +1250,7 @@ def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: Kill the process :param msg: An optional kill message """ - if self.state == process_states.ProcessState.KILLED: + if self.state_label == process_states.ProcessState.KILLED: # Already killed return True @@ -1258,7 +1262,7 @@ def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: # Already killing return self._killing - if self._stepping and isinstance(self._state, Interruptable): + if self._stepping and isinstance(self.state, Interruptable): # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.KillInterruption(msg_text) @@ -1334,8 +1338,8 @@ async def step(self) -> None: if self.paused and self._paused is not None: await self._paused - if not isinstance(self._state, Proceedable): - raise StateMachineError(f'cannot step from {self._state.__class__}, async method `execute` not implemented') + if not isinstance(self.state, Proceedable): + raise StateMachineError(f'cannot step from {self.state.__class__}, async method `execute` not implemented') try: self._stepping = True diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 348be7d1..6609f9cc 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -8,7 +8,6 @@ import logging import re from typing import ( - TYPE_CHECKING, Any, Callable, Dict, @@ -31,7 +30,7 @@ from plumpy.base.utils import call_with_super_check from plumpy.event_helper import EventHelper from plumpy.exceptions import InvalidStateError -from plumpy.persistence import LoadSaveContext, auto_persist, auto_save, ensure_object_loader, Savable +from plumpy.persistence import LoadSaveContext, Savable, auto_persist, auto_save, ensure_object_loader from plumpy.process_listener import ProcessListener from . import lang, persistence, process_states, processes @@ -221,7 +220,7 @@ def recreate_from( else: proc._loop = asyncio.get_event_loop() - proc._state: state_machine.State = proc.recreate_state(saved_state['_state']) + proc._state = proc.recreate_state(saved_state['_state']) if 'communicator' in load_context: proc._communicator = load_context.communicator @@ -230,7 +229,7 @@ def recreate_from( proc._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above - persistence.auto_load(proc, saved_state, load_context) + persistence.load_auto_persist_params(proc, saved_state, load_context) # Inputs/outputs try: @@ -370,8 +369,7 @@ def recreate_from( """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._fn = getattr(obj._workchain.__class__, saved_state['_fn']) @@ -444,7 +442,7 @@ def finished(self) -> bool: def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - if self._child_stepper is not None: + if self._child_stepper is not None and isinstance(self._child_stepper, Savable): out_state[STEPPER_STATE] = self._child_stepper.save() return out_state @@ -461,8 +459,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._block = load_context.block_instruction stepper_state = saved_state.get(STEPPER_STATE, None) @@ -599,7 +596,7 @@ def finished(self) -> bool: def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - if self._child_stepper is not None: + if self._child_stepper is not None and isinstance(self._child_stepper, Savable): out_state[STEPPER_STATE] = self._child_stepper.save() return out_state @@ -616,8 +613,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._if_instruction = load_context.if_instruction stepper_state = saved_state.get(STEPPER_STATE, None) @@ -729,8 +725,7 @@ def recreate_from( """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._while_instruction = load_context.while_instruction stepper_state = saved_state.get(STEPPER_STATE, None) diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index f046aaa8..44a084d4 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -150,22 +150,22 @@ def stop(self): class TestStateMachine(unittest.TestCase): def test_basic(self): cd_player = CdPlayer() - self.assertEqual(cd_player.state, STOPPED) + self.assertEqual(cd_player.state_label, STOPPED) cd_player.play('Eminem - The Real Slim Shady') - self.assertEqual(cd_player.state, PLAYING) + self.assertEqual(cd_player.state_label, PLAYING) time.sleep(1.0) cd_player.pause() - self.assertEqual(cd_player.state, PAUSED) + self.assertEqual(cd_player.state_label, PAUSED) cd_player.play() - self.assertEqual(cd_player.state, PLAYING) + self.assertEqual(cd_player.state_label, PLAYING) self.assertEqual(cd_player.play(), False) cd_player.stop() - self.assertEqual(cd_player.state, STOPPED) + self.assertEqual(cd_player.state_label, STOPPED) def test_invalid_event(self): cd_player = CdPlayer() diff --git a/tests/rmq/test_process_control.py b/tests/rmq/test_process_control.py index 79a98ba3..7c3b431c 100644 --- a/tests/rmq/test_process_control.py +++ b/tests/rmq/test_process_control.py @@ -68,7 +68,7 @@ async def test_play(self, _coordinator, async_controller): # Check that all is as we expect assert result - assert proc.state == plumpy.ProcessState.WAITING + assert proc.state_label == plumpy.ProcessState.WAITING # if not close the background process will raise exception # make sure proc reach the final state @@ -85,7 +85,7 @@ async def test_kill(self, _coordinator, async_controller): # Check the outcome assert result - assert proc.state == plumpy.ProcessState.KILLED + assert proc.state_label == plumpy.ProcessState.KILLED @pytest.mark.asyncio async def test_status(self, _coordinator, async_controller): @@ -173,7 +173,7 @@ async def test_play(self, _coordinator, sync_controller): # Check that all is as we expect assert result - assert proc.state == plumpy.ProcessState.CREATED + assert proc.state_label == plumpy.ProcessState.CREATED @pytest.mark.asyncio async def test_kill(self, _coordinator, sync_controller): @@ -187,7 +187,7 @@ async def test_kill(self, _coordinator, sync_controller): # Check the outcome assert result # Occasionally fail - assert proc.state == plumpy.ProcessState.KILLED + assert proc.state_label == plumpy.ProcessState.KILLED @pytest.mark.asyncio async def test_kill_all(self, _coordinator, sync_controller): @@ -198,7 +198,7 @@ async def test_kill_all(self, _coordinator, sync_controller): sync_controller.kill_all(msg_text='bang bang, I shot you down') await utils.wait_util(lambda: all([proc.killed() for proc in procs])) - assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs]) + assert all([proc.state_label == plumpy.ProcessState.KILLED for proc in procs]) @pytest.mark.asyncio async def test_status(self, _coordinator, sync_controller): diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 4ec4c1a5..7f616433 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,7 +5,7 @@ import yaml import plumpy -from plumpy.persistence import auto_load, auto_persist, auto_save +from plumpy.persistence import auto_load, auto_persist, auto_save, ensure_object_loader from plumpy.utils import SAVED_STATE_TYPE from . import utils @@ -25,8 +25,8 @@ def recreate_from(cls, saved_state, load_context=None): :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context=None) -> SAVED_STATE_TYPE: @@ -55,8 +55,8 @@ def recreate_from(cls, saved_state, load_context=None): :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context=None) -> SAVED_STATE_TYPE: @@ -81,8 +81,8 @@ def recreate_from(cls, saved_state, load_context=None): :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context=None) -> SAVED_STATE_TYPE: diff --git a/tests/test_processes.py b/tests/test_processes.py index e2b0f640..f989bef7 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -22,6 +22,7 @@ # FIXME: any process listener is savable # FIXME: any process control commands are savable + class ForgetToCallParent(plumpy.Process): def __init__(self, forget_on): super().__init__() @@ -242,7 +243,7 @@ def test_execute(self): proc.execute() self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) self.assertEqual(proc.outputs, {'default': 5}) def test_run_from_class(self): @@ -280,7 +281,7 @@ def test_exception(self): proc = utils.ExceptionProcess() with self.assertRaises(RuntimeError): proc.execute() - self.assertEqual(proc.state, ProcessState.EXCEPTED) + self.assertEqual(proc.state_label, ProcessState.EXCEPTED) def test_run_kill(self): proc = utils.KillProcess() @@ -347,7 +348,7 @@ def test_wait_continue(self): # Check it's done self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) def test_exc_info(self): proc = utils.ExceptionProcess() @@ -371,7 +372,7 @@ def test_wait_pause_play_resume(self): async def async_test(): await utils.run_until_waiting(proc) - self.assertEqual(proc.state, ProcessState.WAITING) + self.assertEqual(proc.state_label, ProcessState.WAITING) result = await proc.pause() self.assertTrue(result) @@ -387,7 +388,7 @@ async def async_test(): # Check it's done self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) @@ -408,7 +409,7 @@ def test_pause_play_status_messaging(self): async def async_test(): await utils.run_until_waiting(proc) - self.assertEqual(proc.state, ProcessState.WAITING) + self.assertEqual(proc.state_label, ProcessState.WAITING) result = await proc.pause(PAUSE_STATUS) self.assertTrue(result) @@ -428,7 +429,7 @@ async def async_test(): loop.run_until_complete(async_test()) self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) def test_kill_in_run(self): class KillProcess(Process): @@ -446,7 +447,7 @@ def run(self, **kwargs): proc.execute() self.assertTrue(proc.after_kill) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_kill_when_paused_in_run(self): class PauseProcess(Process): @@ -458,7 +459,7 @@ def run(self, **kwargs): with self.assertRaises(plumpy.KilledError): proc.execute() - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_kill_when_paused(self): loop = asyncio.get_event_loop() @@ -482,7 +483,7 @@ async def async_test(): loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_run_multiple(self): # Create and play some processes @@ -558,7 +559,7 @@ def run(self): loop.run_forever() self.assertTrue(proc.paused) - self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) + self.assertEqual(proc.state_label, plumpy.ProcessState.FINISHED) def test_pause_play_in_process(self): """Test that we can pause and play that by playing within the process""" @@ -576,7 +577,7 @@ def run(self): proc.execute() self.assertFalse(proc.paused) - self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) + self.assertEqual(proc.state_label, plumpy.ProcessState.FINISHED) def test_process_stack(self): test_case = self @@ -787,7 +788,7 @@ def test_saving_each_step(self): proc = proc_class() saver = utils.ProcessSaver(proc) saver.capture() - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) self.assertTrue(utils.check_process_against_snapshots(loop, proc_class, saver.snapshots)) def test_restart(self): @@ -802,7 +803,7 @@ async def async_test(): # Load a process from the saved state loaded_proc = saved_state.unbundle() - self.assertEqual(loaded_proc.state, ProcessState.WAITING) + self.assertEqual(loaded_proc.state_label, ProcessState.WAITING) # Now resume it loaded_proc.resume() @@ -825,7 +826,7 @@ async def async_test(): # Load a process from the saved state loaded_proc = saved_state.unbundle() - self.assertEqual(loaded_proc.state, ProcessState.WAITING) + self.assertEqual(loaded_proc.state_label, ProcessState.WAITING) # Now resume it twice in succession loaded_proc.resume() @@ -867,7 +868,7 @@ async def async_test(): def test_killed(self): proc = utils.DummyProcess() proc.kill() - self.assertEqual(proc.state, plumpy.ProcessState.KILLED) + self.assertEqual(proc.state_label, plumpy.ProcessState.KILLED) self._check_round_trip(proc) def _check_round_trip(self, proc1): @@ -990,40 +991,40 @@ def run(self): self.out(namespace_nested + '.two', 2) # Run the process in default mode which should not add any outputs and therefore fail - process = DummyDynamicProcess() - process.execute() + proc = DummyDynamicProcess() + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertFalse(process.is_successful) - self.assertDictEqual(process.outputs, {}) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertFalse(proc.is_successful) + self.assertDictEqual(proc.outputs, {}) # Attaching only namespaced ports should fail, because the required port is not added - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.DYNAMIC_PORT_NAMESPACE}) - process.execute() + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.DYNAMIC_PORT_NAMESPACE}) + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertFalse(process.is_successful) - self.assertEqual(process.outputs[namespace]['nested']['one'], 1) - self.assertEqual(process.outputs[namespace]['nested']['two'], 2) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertFalse(proc.is_successful) + self.assertEqual(proc.outputs[namespace]['nested']['one'], 1) + self.assertEqual(proc.outputs[namespace]['nested']['two'], 2) # Attaching only the single required top-level port should be fine - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.SINGLE_REQUIRED_PORT}) - process.execute() + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.SINGLE_REQUIRED_PORT}) + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertTrue(process.is_successful) - self.assertEqual(process.outputs['required_bool'], False) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertTrue(proc.is_successful) + self.assertEqual(proc.outputs['required_bool'], False) # Attaching both the required and namespaced ports should result in a successful termination - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.BOTH_SINGLE_AND_NAMESPACE}) - process.execute() - - self.assertIsNotNone(process.outputs) - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertTrue(process.is_successful) - self.assertEqual(process.outputs['required_bool'], False) - self.assertEqual(process.outputs[namespace]['nested']['one'], 1) - self.assertEqual(process.outputs[namespace]['nested']['two'], 2) + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.BOTH_SINGLE_AND_NAMESPACE}) + proc.execute() + + self.assertIsNotNone(proc.outputs) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertTrue(proc.is_successful) + self.assertEqual(proc.outputs['required_bool'], False) + self.assertEqual(proc.outputs[namespace]['nested']['one'], 1) + self.assertEqual(proc.outputs[namespace]['nested']['two'], 2) class TestProcessEvents(unittest.TestCase): diff --git a/tests/utils.py b/tests/utils.py index 3d4458f4..18082fd4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -596,7 +596,7 @@ def run_until_waiting(proc): listener = plumpy.ProcessListener() in_waiting = asyncio.Future() - if proc.state == ProcessState.WAITING: + if proc.state_label == ProcessState.WAITING: in_waiting.set_result(True) else: