From 94937792caab27a889c90cb781fc921d3651aa27 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 23:14:28 +0100 Subject: [PATCH] forming Savable protocol - remove persist_config flag of savable --- src/plumpy/event_helper.py | 12 +- src/plumpy/persistence.py | 245 ++++++++++++++++----------------- src/plumpy/process_listener.py | 16 ++- src/plumpy/process_states.py | 71 ++++++---- src/plumpy/processes.py | 7 +- src/plumpy/workchains.py | 75 ++++------ tests/test_persistence.py | 32 +++-- tests/test_processes.py | 4 + tests/test_workchains.py | 2 + 9 files changed, 248 insertions(+), 216 deletions(-) diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index e20dae3f..abc2b24b 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -2,20 +2,21 @@ import logging from typing import TYPE_CHECKING, Any, Callable, Optional +from plumpy.persistence import LoadSaveContext, Savable, auto_load, auto_save, ensure_object_loader from plumpy.utils import SAVED_STATE_TYPE from . import persistence -from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load if TYPE_CHECKING: from typing import Set, Type + from .process_listener import ProcessListener _LOGGER = logging.getLogger(__name__) @persistence.auto_persist('_listeners', '_listener_type') -class EventHelper(persistence.Savable): +class EventHelper: def __init__(self, listener_type: 'Type[ProcessListener]'): assert listener_type is not None, 'Must provide valid listener type' @@ -43,11 +44,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @property def listeners(self) -> 'Set[ProcessListener]': return self._listeners diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 0963445e..afe82439 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -9,12 +9,24 @@ import os import pickle from types import MethodType -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Set, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Generator, + Iterable, + List, + Optional, + Protocol, + cast, + runtime_checkable, +) import yaml from . import futures, loaders, utils -from .base.utils import call_with_super_check, super_check from .utils import PID_TYPE, SAVED_STATE_TYPE PersistedCheckpoint = collections.namedtuple('PersistedCheckpoint', ['pid', 'tag']) @@ -88,10 +100,10 @@ def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = N :return: The loaded Savable instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) assert load_context.loader is not None # required for type checking try: - class_name = Savable._get_class_name(saved_state) + class_name = SaveUtil.get_class_name(saved_state) load_cls: Savable = load_context.loader.load_object(class_name) except KeyError: raise ValueError('Class name not found in saved state') @@ -380,22 +392,7 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: del self._checkpoints[pid] -SavableClsType = TypeVar('SavableClsType', bound='type[Savable]') - - -def auto_persist(*members: str) -> Callable[[SavableClsType], SavableClsType]: - def wrapped(savable: SavableClsType) -> SavableClsType: - if savable._auto_persist is None: - savable._auto_persist = set() - else: - savable._auto_persist = set(savable._auto_persist) - savable.auto_persist(*members) - return savable - - return wrapped - - -def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': +def ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': """ Given a LoadSaveContext this method will ensure that it has a valid class loader using the following priorities: @@ -417,7 +414,7 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV # 2) Try getting from saved_state default_loader = loaders.get_object_loader() try: - loader_identifier = Savable.get_custom_meta(saved_state, META__OBJECT_LOADER) + loader_identifier = SaveUtil.get_custom_meta(saved_state, META__OBJECT_LOADER) except ValueError: # 3) Fall back to default loader = default_loader @@ -436,45 +433,10 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV META__TYPE__SAVABLE: str = 'S' -class Savable: - CLASS_NAME: str = 'class_name' - - _auto_persist: Optional[Set[str]] = None - _persist_configured = False - - @classmethod - def auto_persist(cls, *members: str) -> None: - if cls._auto_persist is None: - cls._auto_persist = set() - cls._auto_persist.update(members) - - @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': - """ - Recreate a :class:`Savable` from a saved state using an optional load context. - - :param saved_state: The saved state - :param load_context: An optional load context - - :return: The recreated instance - - """ - ... - - def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: - out_state: SAVED_STATE_TYPE = auto_save(self, save_context) - - return out_state - - def _ensure_persist_configured(self) -> None: - if not self._persist_configured: - self._persist_configured = True - - # region Metadata getter/setters - +class SaveUtil: @staticmethod def set_custom_meta(out_state: SAVED_STATE_TYPE, name: str, value: Any) -> None: - user_dict = Savable._get_create_meta(out_state).setdefault(META__USER, {}) + user_dict = SaveUtil.get_create_meta(out_state).setdefault(META__USER, {}) user_dict[name] = value @staticmethod @@ -485,47 +447,127 @@ def get_custom_meta(saved_state: SAVED_STATE_TYPE, name: str) -> Any: raise ValueError(f"Unknown meta key '{name}'") @staticmethod - def _get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: + def get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: return out_state.setdefault(META, {}) @staticmethod - def _set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None: - Savable._get_create_meta(out_state)[META__CLASS_NAME] = name + def set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None: + SaveUtil.get_create_meta(out_state)[META__CLASS_NAME] = name @staticmethod - def _get_class_name(saved_state: SAVED_STATE_TYPE) -> str: - return Savable._get_create_meta(saved_state)[META__CLASS_NAME] + def get_class_name(saved_state: SAVED_STATE_TYPE) -> str: + return SaveUtil.get_create_meta(saved_state)[META__CLASS_NAME] @staticmethod - def _set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None: - type_dict = Savable._get_create_meta(out_state).setdefault(META__TYPES, {}) + def set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None: + type_dict = SaveUtil.get_create_meta(out_state).setdefault(META__TYPES, {}) type_dict[name] = type_spec @staticmethod - def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: + def get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: try: return saved_state[META][META__TYPES][name] except KeyError: pass - # endregion - def _get_value( - self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext] - ) -> Union[MethodType, 'Savable']: - value = saved_state[name] +@runtime_checkable +class Savable(Protocol): + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + ... + + def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... + + +@runtime_checkable +class SavableWithAutoPersist(Savable, Protocol): + _auto_persist: ClassVar[set[str]] = set() + + +def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = {} + + if save_context is None: + save_context = LoadSaveContext() + + utils.type_check(save_context, LoadSaveContext) + + default_loader = loaders.get_object_loader() + # If the user has specified a class loader, then save it in the saved state + if save_context.loader is not None: + loader_class = default_loader.identify_object(save_context.loader.__class__) + SaveUtil.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) + loader = save_context.loader + else: + loader = default_loader + + SaveUtil.set_class_name(out_state, loader.identify_object(obj.__class__)) + + if isinstance(obj, SavableWithAutoPersist): + for member in obj._auto_persist: + value = getattr(obj, member) + if inspect.ismethod(value): + if value.__self__ is not obj: + raise TypeError('Cannot persist methods of other classes') + SaveUtil.set_meta_type(out_state, member, META__TYPE__METHOD) + value = value.__name__ + elif isinstance(value, Savable) and not isinstance(value, type): + # persist for a savable obj, call `save` method of obj. + SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE) + value = value.save() + else: + value = copy.deepcopy(value) + out_state[member] = value + + return out_state + + +def auto_load(obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: + for member in obj._auto_persist: + setattr(obj, member, _get_value(obj, saved_state, member, load_context)) + - typ = Savable._get_meta_type(saved_state, name) - if typ == META__TYPE__METHOD: - value = getattr(self, value) - elif typ == META__TYPE__SAVABLE: - value = load(value, load_context) +def _get_value( + obj: Any, saved_state: SAVED_STATE_TYPE, name: str, load_context: LoadSaveContext | None +) -> MethodType | Savable: + value = saved_state[name] - return value + typ = SaveUtil.get_meta_type(saved_state, name) + if typ == META__TYPE__METHOD: + value = getattr(obj, value) + elif typ == META__TYPE__SAVABLE: + value = load(value, load_context) + + return value + + +def auto_persist(*members: str) -> Callable[..., Savable]: + def wrapped(savable_cls: type) -> Savable: + if not hasattr(savable_cls, '_auto_persist') or savable_cls._auto_persist is None: + savable_cls._auto_persist = set() # type: ignore[attr-defined] + else: + savable_cls._auto_persist = set(savable_cls._auto_persist) + savable_cls._auto_persist.update(members) # type: ignore[attr-defined] + # XXX: validate on `save` and `recreate_from` method?? + return cast(Savable, savable_cls) + return wrapped + + +# FIXME: move me to another module? savablefuture.py? @auto_persist('_state', '_result') -class SavableFuture(futures.Future, Savable): +class SavableFuture(futures.Future): """ A savable future. @@ -550,7 +592,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) try: loop = load_context.loop @@ -586,48 +628,3 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa # ## UNTILHERE XXX: return obj - - -def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: - out_state: SAVED_STATE_TYPE = {} - - if save_context is None: - save_context = LoadSaveContext() - - utils.type_check(save_context, LoadSaveContext) - - default_loader = loaders.get_object_loader() - # If the user has specified a class loader, then save it in the saved state - if save_context.loader is not None: - loader_class = default_loader.identify_object(save_context.loader.__class__) - Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) - loader = save_context.loader - else: - loader = default_loader - - Savable._set_class_name(out_state, loader.identify_object(obj.__class__)) - - obj._ensure_persist_configured() - if obj._auto_persist is not None: - for member in obj._auto_persist: - value = getattr(obj, member) - if inspect.ismethod(value): - if value.__self__ is not obj: - raise TypeError('Cannot persist methods of other classes') - Savable._set_meta_type(out_state, member, META__TYPE__METHOD) - value = value.__name__ - elif isinstance(value, Savable): - Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) - value = value.save() - else: - value = copy.deepcopy(value) - out_state[member] = value - - return out_state - - -def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - obj._ensure_persist_configured() - if obj._auto_persist is not None: - for member in obj._auto_persist: - setattr(obj, member, obj._get_value(saved_state, member, load_context)) diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index e84b504d..8e9673bb 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -2,16 +2,21 @@ import abc from typing import TYPE_CHECKING, Any, Dict, Optional +from plumpy.persistence import LoadSaveContext, auto_save, ensure_object_loader + from . import persistence from .utils import SAVED_STATE_TYPE -from plumpy.persistence import LoadSaveContext, _ensure_object_loader if TYPE_CHECKING: + from plumpy.persistence import Savable + from .processes import Process +# FIXME: test any process listener is a savable + @persistence.auto_persist('_params') -class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta): +class ProcessListener(metaclass=abc.ABCMeta): # region Persistence methods def __init__(self) -> None: @@ -32,11 +37,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) obj.init(**saved_state['_params']) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + # endregion def on_process_created(self, process: 'Process') -> None: diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 3a1d128f..0659d1da 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import annotations -import copy -import inspect import sys import traceback from enum import Enum @@ -13,13 +11,12 @@ Callable, ClassVar, Optional, - Protocol, Tuple, Type, Union, cast, final, - runtime_checkable, + override, ) import yaml @@ -69,8 +66,26 @@ def __init__(self, msg_text: str | None): # region Commands -class Command(persistence.Savable): - pass +class Command: + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state @auto_persist('msg') @@ -116,12 +131,14 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs + @override def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ return out_state + @override @classmethod def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': """ @@ -133,7 +150,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -164,14 +181,9 @@ class ProcessState(Enum): KILLED = 'killed' -# @runtime_checkable -# class Savable(Protocol): -# def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... - - @final @auto_persist('args', 'kwargs') -class Created(persistence.Savable): +class Created: LABEL: ClassVar = ProcessState.CREATED ALLOWED: ClassVar = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} @@ -202,7 +214,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -225,7 +237,7 @@ def exit(self) -> None: ... @final @auto_persist('args', 'kwargs') -class Running(persistence.Savable): +class Running: LABEL: ClassVar = ProcessState.RUNNING ALLOWED: ClassVar = { ProcessState.RUNNING, @@ -273,7 +285,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -357,7 +369,7 @@ def exit(self) -> None: ... @auto_persist('msg', 'data') -class Waiting(persistence.Savable): +class Waiting: LABEL: ClassVar = ProcessState.WAITING ALLOWED: ClassVar = { ProcessState.RUNNING, @@ -411,7 +423,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -464,7 +476,8 @@ def exit(self) -> None: ... @final -class Excepted(persistence.Savable): +@auto_persist() +class Excepted: """ Excepted state, can optionally provide exception and traceback @@ -516,7 +529,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -549,7 +562,7 @@ def exit(self) -> None: ... @final @auto_persist('result', 'successful') -class Finished(persistence.Savable): +class Finished: """State for process is finished. :param result: The result of process @@ -576,11 +589,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + def enter(self) -> None: ... def exit(self) -> None: ... @@ -588,7 +606,7 @@ def exit(self) -> None: ... @final @auto_persist('msg') -class Killed(persistence.Savable): +class Killed: """ Represents a state where a process has been killed. @@ -620,11 +638,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + def enter(self) -> None: ... def exit(self) -> None: ... diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 6894d0ef..e8444bc5 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -116,7 +116,7 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: '_pre_paused_status', '_event_helper', ) -class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): +class Process(StateMachine, metaclass=ProcessStateMachineMeta): """ The Process class is the base for any unit of work in plumpy. @@ -265,7 +265,7 @@ def recreate_from( :return: An instance of the object with its state loaded from the save state. """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) proc = cls.__new__(cls) # XXX: load_instance_state @@ -673,8 +673,7 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA """ out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - # FIXME: the combined ProcessState protocol should cover the case - if isinstance(self._state, process_states.Savable): + if isinstance(self._state, persistence.Savable): out_state['_state'] = self._state.save() # Inputs/outputs diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 074edc3e..348be7d1 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import annotations -import copy import abc import asyncio import collections @@ -9,6 +8,7 @@ import logging import re from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -26,16 +26,17 @@ from plumpy.coordinator import Coordinator +from plumpy import utils from plumpy.base import state_machine from plumpy.base.utils import call_with_super_check from plumpy.event_helper import EventHelper from plumpy.exceptions import InvalidStateError +from plumpy.persistence import LoadSaveContext, auto_persist, auto_save, ensure_object_loader, Savable from plumpy.process_listener import ProcessListener from . import lang, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE, AttributesDict -from plumpy import loaders, utils -from plumpy.persistence import _ensure_object_loader + ToContext = dict @@ -162,41 +163,9 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA :param out_state: A bundle to save the state to :param save_context: The save context """ - out_state: SAVED_STATE_TYPE = {} - - if save_context is None: - save_context = persistence.LoadSaveContext() + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) - utils.type_check(save_context, persistence.LoadSaveContext) - - default_loader = loaders.get_object_loader() - # If the user has specified a class loader, then save it in the saved state - if save_context.loader is not None: - loader_class = default_loader.identify_object(save_context.loader.__class__) - persistence.Savable.set_custom_meta(out_state, persistence.META__OBJECT_LOADER, loader_class) - loader = save_context.loader - else: - loader = default_loader - - persistence.Savable._set_class_name(out_state, loader.identify_object(self.__class__)) - - self._ensure_persist_configured() - if self._auto_persist is not None: - for member in self._auto_persist: - value = getattr(self, member) - if inspect.ismethod(value): - if value.__self__ is not self: - raise TypeError('Cannot persist methods of other classes') - persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__METHOD) - value = value.__name__ - elif isinstance(value, persistence.Savable): - persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__SAVABLE) - value = value.save() - else: - value = copy.deepcopy(value) - out_state[member] = value - - if isinstance(self._state, process_states.Savable): + if isinstance(self._state, persistence.Savable): out_state['_state'] = self._state.save() # Inputs/outputs @@ -210,7 +179,7 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA out_state[processes.BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) # Ask the stepper to save itself - if self._stepper is not None: + if self._stepper is not None and isinstance(self._stepper, Savable): out_state[self._STEPPER_STATE] = self._stepper.save() if self._context is not None: @@ -232,7 +201,7 @@ def recreate_from( """ ### FIXME: dup from process.create_from - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) proc = cls.__new__(cls) # XXX: load_instance_state @@ -375,7 +344,8 @@ def get_description(self) -> Any: """ -class _FunctionStepper(persistence.Savable): +@auto_persist() +class _FunctionStepper: def __init__(self, workchain: 'WorkChain', fn: WC_COMMAND_TYPE): self._workchain = workchain self._fn = fn @@ -387,7 +357,9 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + def recreate_from( + cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None + ) -> 'Savable': """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -397,7 +369,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -443,7 +415,7 @@ def get_description(self) -> str: @persistence.auto_persist('_pos') -class _BlockStepper(persistence.Savable): +class _BlockStepper: def __init__(self, block: Sequence[_Instruction], workchain: 'WorkChain') -> None: self._workchain = workchain self._block = block @@ -488,7 +460,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -591,7 +563,7 @@ def __str__(self) -> str: @persistence.auto_persist('_pos') -class _IfStepper(persistence.Savable): +class _IfStepper: def __init__(self, if_instruction: '_If', workchain: 'WorkChain') -> None: self._workchain = workchain self._if_instruction = if_instruction @@ -643,7 +615,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -714,7 +686,7 @@ def get_description(self) -> Mapping[str, Any]: return description -class _WhileStepper(persistence.Savable): +class _WhileStepper: def __init__(self, while_instruction: '_While', workchain: 'WorkChain') -> None: self._workchain = workchain self._while_instruction = while_instruction @@ -744,7 +716,9 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + def recreate_from( + cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None + ) -> 'Savable': """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -754,7 +728,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -801,7 +775,8 @@ def __init__(self, exit_code: Optional[EXIT_CODE_TYPE]) -> None: self.exit_code = exit_code -class _ReturnStepper(persistence.Savable): +@persistence.auto_persist() +class _ReturnStepper: def __init__(self, return_instruction: '_Return', workchain: 'WorkChain') -> None: self._workchain = workchain self._return_instruction = return_instruction diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 65ef3226..4ec4c1a5 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,16 +5,17 @@ import yaml import plumpy -from plumpy.persistence import auto_load +from plumpy.persistence import auto_load, auto_persist, auto_save +from plumpy.utils import SAVED_STATE_TYPE from . import utils -class SaveEmpty(plumpy.Savable): - pass +@auto_persist() +class SaveEmpty: @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -28,9 +29,14 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @plumpy.auto_persist('test', 'test_method') -class Save1(plumpy.Savable): +class Save1: def __init__(self): self.test = 'sup yp' self.test_method = self.m @@ -39,7 +45,7 @@ def m(): pass @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -53,14 +59,19 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @plumpy.auto_persist('test') -class Save(plumpy.Savable): +class Save: def __init__(self): self.test = Save1() @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -74,6 +85,11 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + class TestSavable(unittest.TestCase): def test_empty_savable(self): diff --git a/tests/test_processes.py b/tests/test_processes.py index 6fc0730a..e2b0f640 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -17,6 +17,10 @@ from plumpy.utils import AttributesFrozendict from . import utils +# FIXME: after deabstract on savable into a protocol, test that all state are savable +# FIXME: also that any process is savable +# FIXME: any process listener is savable +# FIXME: any process control commands are savable class ForgetToCallParent(plumpy.Process): def __init__(self, forget_on): diff --git a/tests/test_workchains.py b/tests/test_workchains.py index 08c7317a..4e34d2b4 100644 --- a/tests/test_workchains.py +++ b/tests/test_workchains.py @@ -11,6 +11,8 @@ from . import utils +# FIXME: after deabstract on savable into a protocol, test that all stepper are savable +# FIXME: workchani itself is savable class Wf(WorkChain): # Keep track of which steps were completed by the workflow