Skip to content

Commit

Permalink
Make auto_load symmetry with auto_save and state/state_label distinguish
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Jan 18, 2025
1 parent 9493779 commit bd997c9
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 135 deletions.
10 changes: 8 additions & 2 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,13 @@ def create_initial_state(self, *args: Any, **kwargs: Any) -> State:
return self.get_state_class(self.initial_state_label())(self, *args, **kwargs)

@property
def state(self) -> Any:
def state(self) -> State | None:
if self._state is None:
return None
return self._state

@property
def state_label(self) -> Any:
if self._state is None:
return None
return self._state.LABEL
Expand Down Expand Up @@ -326,7 +332,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
# it can happened when transit from terminal state
return None

initial_state_label = self._state.LABEL if self._state is not None else None
initial_state_label = self.state_label
label = None
try:
self._transitioning = True
Expand Down
3 changes: 1 addition & 2 deletions src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
obj = auto_load(cls, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
Expand Down
19 changes: 18 additions & 1 deletion src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
List,
Optional,
Protocol,
TypeVar,
cast,
runtime_checkable,
)
Expand Down Expand Up @@ -523,6 +524,8 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S
value = value.__name__
elif isinstance(value, Savable) and not isinstance(value, type):
# persist for a savable obj, call `save` method of obj.
# the rhs branch is for when value is a Savable class, it is true runtime check
# of lhs condition.
SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE)
value = value.save()
else:
Expand All @@ -532,11 +535,25 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S
return out_state


def auto_load(obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None:
def load_auto_persist_params(
obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None
) -> None:
for member in obj._auto_persist:
setattr(obj, member, _get_value(obj, saved_state, member, load_context))


T = TypeVar('T', bound=Savable)


def auto_load(cls: type[T], saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None) -> T:
obj = cls.__new__(cls)

if isinstance(obj, SavableWithAutoPersist):
load_auto_persist_params(obj, saved_state, load_context)

return obj


def _get_value(
obj: Any, saved_state: SAVED_STATE_TYPE, name: str, load_context: LoadSaveContext | None
) -> MethodType | Savable:
Expand Down
40 changes: 14 additions & 26 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
:return: The recreated instance
"""
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
load_context = ensure_object_loader(load_context, saved_state)
obj = auto_load(cls, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
Expand Down Expand Up @@ -151,15 +151,15 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
obj = auto_load(cls, saved_state, load_context)

obj.state_machine = load_context.process
try:
obj.continue_fn = utils.load_function(saved_state[obj.CONTINUE_FN])
except ValueError:
process = load_context.process
obj.continue_fn = getattr(process, saved_state[obj.CONTINUE_FN])
if load_context is not None:
obj.continue_fn = getattr(load_context.proc, saved_state[obj.CONTINUE_FN])
else:
raise
return obj


Expand Down Expand Up @@ -215,12 +215,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)

auto_load(obj, saved_state, load_context)

obj = auto_load(cls, saved_state, load_context)
obj.process = load_context.process

obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN])

return obj
Expand Down Expand Up @@ -286,15 +282,12 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)

obj = auto_load(cls, saved_state, load_context)
obj.process = load_context.process

obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN])
if obj.COMMAND in saved_state:
# FIXME: typing
obj._command = persistence.load(saved_state[obj.COMMAND], load_context) # type: ignore

return obj

def interrupt(self, reason: Any) -> None:
Expand Down Expand Up @@ -424,9 +417,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)

obj = auto_load(cls, saved_state, load_context)
obj.process = load_context.process

callback_name = saved_state.get(obj.DONE_CALLBACK, None)
Expand Down Expand Up @@ -530,8 +521,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
obj = auto_load(cls, saved_state, load_context)

obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=Loader)
if _HAS_TBLIB:
Expand Down Expand Up @@ -590,8 +580,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
obj = auto_load(cls, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
Expand Down Expand Up @@ -639,8 +628,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
"""
load_context = ensure_object_loader(load_context, saved_state)
obj = cls.__new__(cls)
auto_load(obj, saved_state, load_context)
obj = auto_load(cls, saved_state, load_context)
return obj

def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE:
Expand Down
Loading

0 comments on commit bd997c9

Please sign in to comment.