Skip to content

Commit

Permalink
test and typing generic for the state savable types
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Jan 22, 2025
1 parent a90dcff commit 14aa317
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 24 deletions.
24 changes: 12 additions & 12 deletions src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
List,
Optional,
Protocol,
Type,
TypeVar,
cast,
runtime_checkable,
Expand Down Expand Up @@ -474,10 +475,13 @@ def get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any:
pass


T = TypeVar('T', bound='Savable')


@runtime_checkable
class Savable(Protocol):
@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable':
def recreate_from(cls: type[T], saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> T:
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
Expand Down Expand Up @@ -544,9 +548,6 @@ def load_auto_persist_params(
setattr(obj, member, _get_value(obj, saved_state, member, load_context))


T = TypeVar('T', bound=Savable)


def auto_load(cls: type[T], saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None) -> T:
obj = cls.__new__(cls)

Expand All @@ -570,21 +571,20 @@ def _get_value(
return value


def auto_persist(*members: str) -> Callable[..., Savable]:
def wrapped(savable_cls: type) -> Savable:
if not hasattr(savable_cls, '_auto_persist') or savable_cls._auto_persist is None:
savable_cls._auto_persist = set() # type: ignore[attr-defined]
def auto_persist(*members: str) -> Callable[[type[T]], type[T]]:
def wrapped(cls: type[T]) -> type[T]:
if not hasattr(cls, '_auto_persist') or cls._auto_persist is None:
cls._auto_persist = set() # type: ignore[attr-defined]
else:
savable_cls._auto_persist = set(savable_cls._auto_persist)
cls._auto_persist = set(cls._auto_persist)

savable_cls._auto_persist.update(members) # type: ignore[attr-defined]
cls._auto_persist.update(members) # type: ignore[attr-defined]
# XXX: validate on `save` and `recreate_from` method??
return cast(Savable, savable_cls)
return cls

return wrapped


# FIXME: move me to another module? savablefuture.py?
@auto_persist('_state', '_result')
class SavableFuture(futures.Future):
"""
Expand Down
23 changes: 12 additions & 11 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Callable,
ClassVar,
Optional,
Self,
Tuple,
Type,
Union,
Expand Down Expand Up @@ -92,7 +93,7 @@ def __init__(self, msg_text: str | None):

class Command:
@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Self:
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
Expand Down Expand Up @@ -164,7 +165,7 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA

@override
@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Self:
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
Expand Down Expand Up @@ -214,7 +215,7 @@ class Created:
RUN_FN = 'run_fn'
is_terminal: ClassVar[bool] = False

def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
def __init__(self, process: 'st.StateMachine', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
assert run_fn is not None
self.process = process
self.run_fn = run_fn
Expand All @@ -228,7 +229,7 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY
return out_state

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Self:
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
Expand Down Expand Up @@ -278,7 +279,7 @@ class Running:
is_terminal: ClassVar[bool] = False

def __init__(
self, process: 'Process', run_fn: Callable[..., Union[Awaitable[Any], Any]], *args: Any, **kwargs: Any
self, process: 'st.StateMachine', run_fn: Callable[..., Union[Awaitable[Any], Any]], *args: Any, **kwargs: Any
) -> None:
assert run_fn is not None
self.process = process
Expand All @@ -297,7 +298,7 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY
return out_state

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Self:
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
Expand Down Expand Up @@ -413,7 +414,7 @@ def __str__(self) -> str:

def __init__(
self,
process: 'Process',
process: 'st.StateMachine',
done_callback: Optional[Callable[..., Any]],
msg: Optional[str] = None,
data: Optional[Any] = None,
Expand All @@ -433,7 +434,7 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY
return out_state

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Self:
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
Expand Down Expand Up @@ -537,7 +538,7 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY
return out_state

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Self:
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
Expand Down Expand Up @@ -596,7 +597,7 @@ def __init__(self, result: Any, successful: bool) -> None:
self.successful = successful

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Self:
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
Expand Down Expand Up @@ -644,7 +645,7 @@ def __init__(self, msg: Optional[MessageType]):
self.msg = msg

@classmethod
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable':
def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Self:
"""
Recreate a :class:`Savable` from a saved state using an optional load context.
Expand Down
42 changes: 42 additions & 0 deletions tests/test_process_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# FIXME: after deabstract on savable into a protocol, test that all state are savable

import pytest
from plumpy.base.state_machine import StateMachine
from plumpy.message import MessageBuilder
from plumpy.persistence import Savable
from plumpy.process_states import Created, Excepted, Finished, Killed, Running, Waiting
from tests.utils import DummyProcess


@pytest.fixture(scope='function')
def proc() -> 'StateMachine':
return DummyProcess()


def test_create_savable(proc: StateMachine):
state = Created(proc, run_fn=lambda: None)
assert isinstance(state, Savable)


def test_running_savable(proc: StateMachine):
state = Running(proc, run_fn=lambda: None)
assert isinstance(state, Savable)


def test_waiting_savable(proc: StateMachine):
state = Waiting(proc, done_callback=lambda: None)
assert isinstance(state, Savable)


def test_excepted_savable():
state = Excepted(exception=ValueError('dummy'))
assert isinstance(state, Savable)


def test_finished_savable():
state = Finished(result='done', successful=True)
assert isinstance(state, Savable)

def test_killed_savable():
state = Killed(msg=MessageBuilder.kill('kill it'))
assert isinstance(state, Savable)
1 change: 0 additions & 1 deletion tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from plumpy.utils import AttributesFrozendict
from . import utils

# FIXME: after deabstract on savable into a protocol, test that all state are savable
# FIXME: also that any process is savable
# FIXME: any process listener is savable
# FIXME: any process control commands are savable
Expand Down

0 comments on commit 14aa317

Please sign in to comment.