diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index e20dae3f..a7108947 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -5,7 +5,7 @@ from plumpy.utils import SAVED_STATE_TYPE from . import persistence -from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load +from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load, auto_save if TYPE_CHECKING: from typing import Set, Type @@ -48,6 +48,11 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @property def listeners(self) -> 'Set[ProcessListener]': return self._listeners diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 2367c759..1100c086 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -103,7 +103,7 @@ def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = N load_context = _ensure_object_loader(load_context, saved_state) assert load_context.loader is not None # required for type checking try: - class_name = Savable._get_class_name(saved_state) + class_name = SaveUtil._get_class_name(saved_state) load_cls: Savable = load_context.loader.load_object(class_name) except KeyError: raise ValueError('Class name not found in saved state') @@ -429,7 +429,7 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV # 2) Try getting from saved_state default_loader = loaders.get_object_loader() try: - loader_identifier = Savable.get_custom_meta(saved_state, META__OBJECT_LOADER) + loader_identifier = SaveUtil.get_custom_meta(saved_state, META__OBJECT_LOADER) except ValueError: # 3) Fall back to default loader = default_loader @@ -448,45 +448,10 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV META__TYPE__SAVABLE: str = 'S' -class Savable: - CLASS_NAME: str = 'class_name' - - _auto_persist: Optional[Set[str]] = None - _persist_configured = False - - @classmethod - def auto_persist(cls, *members: str) -> None: - if cls._auto_persist is None: - cls._auto_persist = set() - cls._auto_persist.update(members) - - @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': - """ - Recreate a :class:`Savable` from a saved state using an optional load context. - - :param saved_state: The saved state - :param load_context: An optional load context - - :return: The recreated instance - - """ - ... - - def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: - out_state: SAVED_STATE_TYPE = auto_save(self, save_context) - - return out_state - - def _ensure_persist_configured(self) -> None: - if not self._persist_configured: - self._persist_configured = True - - # region Metadata getter/setters - +class SaveUtil: @staticmethod def set_custom_meta(out_state: SAVED_STATE_TYPE, name: str, value: Any) -> None: - user_dict = Savable._get_create_meta(out_state).setdefault(META__USER, {}) + user_dict = SaveUtil._get_create_meta(out_state).setdefault(META__USER, {}) user_dict[name] = value @staticmethod @@ -502,15 +467,15 @@ def _get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: @staticmethod def _set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None: - Savable._get_create_meta(out_state)[META__CLASS_NAME] = name + SaveUtil._get_create_meta(out_state)[META__CLASS_NAME] = name @staticmethod def _get_class_name(saved_state: SAVED_STATE_TYPE) -> str: - return Savable._get_create_meta(saved_state)[META__CLASS_NAME] + return SaveUtil._get_create_meta(saved_state)[META__CLASS_NAME] @staticmethod def _set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None: - type_dict = Savable._get_create_meta(out_state).setdefault(META__TYPES, {}) + type_dict = SaveUtil._get_create_meta(out_state).setdefault(META__TYPES, {}) type_dict[name] = type_spec @staticmethod @@ -520,21 +485,32 @@ def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: except KeyError: pass - # endregion - def _get_value( - self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext] - ) -> Union[MethodType, 'Savable']: - value = saved_state[name] +class Savable: + CLASS_NAME: str = 'class_name' - typ = Savable._get_meta_type(saved_state, name) - if typ == META__TYPE__METHOD: - value = getattr(self, value) - elif typ == META__TYPE__SAVABLE: - value = load(value, load_context) + _auto_persist: Optional[Set[str]] = None - return value + @classmethod + def auto_persist(cls, *members: str) -> None: + if cls._auto_persist is None: + cls._auto_persist = set() + cls._auto_persist.update(members) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + ... + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: ... @auto_persist('_state', '_result') class SavableFuture(futures.Future, Savable): @@ -612,24 +588,23 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S # If the user has specified a class loader, then save it in the saved state if save_context.loader is not None: loader_class = default_loader.identify_object(save_context.loader.__class__) - Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) + SaveUtil.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) loader = save_context.loader else: loader = default_loader - Savable._set_class_name(out_state, loader.identify_object(obj.__class__)) + SaveUtil._set_class_name(out_state, loader.identify_object(obj.__class__)) - obj._ensure_persist_configured() if obj._auto_persist is not None: for member in obj._auto_persist: value = getattr(obj, member) if inspect.ismethod(value): if value.__self__ is not obj: raise TypeError('Cannot persist methods of other classes') - Savable._set_meta_type(out_state, member, META__TYPE__METHOD) + SaveUtil._set_meta_type(out_state, member, META__TYPE__METHOD) value = value.__name__ elif isinstance(value, Savable): - Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) + SaveUtil._set_meta_type(out_state, member, META__TYPE__SAVABLE) value = value.save() else: value = copy.deepcopy(value) @@ -639,7 +614,19 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - obj._ensure_persist_configured() if obj._auto_persist is not None: for member in obj._auto_persist: - setattr(obj, member, obj._get_value(saved_state, member, load_context)) + setattr(obj, member, _get_value(obj, saved_state, member, load_context)) + +def _get_value( + obj, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext] +) -> Union[MethodType, 'Savable']: + value = saved_state[name] + + typ = SaveUtil._get_meta_type(saved_state, name) + if typ == META__TYPE__METHOD: + value = getattr(obj, value) + elif typ == META__TYPE__SAVABLE: + value = load(value, load_context) + + return value diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 1d7f2350..8c5f0601 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import annotations -import copy -import inspect import sys import traceback from enum import Enum @@ -13,19 +11,17 @@ Callable, ClassVar, Optional, - Protocol, Tuple, Type, Union, cast, final, - runtime_checkable, + override, ) import yaml from yaml.loader import Loader -from plumpy import loaders from plumpy.process_comms import KillMessage, MessageType from plumpy.persistence import _ensure_object_loader @@ -40,9 +36,6 @@ from .base import state_machine as st from .lang import NULL from .persistence import ( - META__OBJECT_LOADER, - META__TYPE__METHOD, - META__TYPE__SAVABLE, LoadSaveContext, Savable, auto_load, @@ -94,8 +87,26 @@ class PauseInterruption(Interruption): class Command(persistence.Savable): - pass + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state @auto_persist('msg') class Kill(Command): @@ -140,12 +151,14 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs + @override def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ return out_state + @override @classmethod def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': """ @@ -605,6 +618,11 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + def enter(self) -> None: ... def exit(self) -> None: ... @@ -649,6 +667,11 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + def enter(self) -> None: ... def exit(self) -> None: ... diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index cf7ad81f..4a766c79 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -176,24 +176,23 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA # If the user has specified a class loader, then save it in the saved state if save_context.loader is not None: loader_class = default_loader.identify_object(save_context.loader.__class__) - persistence.Savable.set_custom_meta(out_state, persistence.META__OBJECT_LOADER, loader_class) + persistence.SaveUtil.set_custom_meta(out_state, persistence.META__OBJECT_LOADER, loader_class) loader = save_context.loader else: loader = default_loader - persistence.Savable._set_class_name(out_state, loader.identify_object(self.__class__)) + persistence.SaveUtil._set_class_name(out_state, loader.identify_object(self.__class__)) - self._ensure_persist_configured() if self._auto_persist is not None: for member in self._auto_persist: value = getattr(self, member) if inspect.ismethod(value): if value.__self__ is not self: raise TypeError('Cannot persist methods of other classes') - persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__METHOD) + persistence.SaveUtil._set_meta_type(out_state, member, persistence.META__TYPE__METHOD) value = value.__name__ elif isinstance(value, persistence.Savable): - persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__SAVABLE) + persistence.SaveUtil._set_meta_type(out_state, member, persistence.META__TYPE__SAVABLE) value = value.save() else: value = copy.deepcopy(value) diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 65ef3226..b4100391 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,7 +5,8 @@ import yaml import plumpy -from plumpy.persistence import auto_load +from plumpy.persistence import auto_load, auto_save +from plumpy.utils import SAVED_STATE_TYPE from . import utils @@ -14,7 +15,7 @@ class SaveEmpty(plumpy.Savable): pass @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -28,6 +29,11 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @plumpy.auto_persist('test', 'test_method') class Save1(plumpy.Savable): @@ -39,7 +45,7 @@ def m(): pass @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -53,6 +59,11 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @plumpy.auto_persist('test') class Save(plumpy.Savable): @@ -60,7 +71,7 @@ def __init__(self): self.test = Save1() @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -74,6 +85,11 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + class TestSavable(unittest.TestCase): def test_empty_savable(self):