Skip to content

Commit

Permalink
Correct typing in workchain.py (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz authored Jan 31, 2025
1 parent 0730394 commit c6d5b8f
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions src/plumpy/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import abc
import asyncio
import collections
import inspect
import logging
import re
Expand Down Expand Up @@ -157,6 +156,7 @@ def save(self, loader: ObjectLoader | None = None) -> SAVED_STATE_TYPE:
"""
out_state: SAVED_STATE_TYPE = auto_save(self, loader)

# TODO: revert condition and raise if not savable
if isinstance(self._state, persistence.Savable):
out_state['_state'] = self._state.save()

Expand Down Expand Up @@ -466,25 +466,24 @@ def __str__(self) -> str:
return str(self._pos) + ':' + str(self._child_stepper)


class _Block(_Instruction, collections.abc.Sequence):
class _Block(_Instruction, Sequence[_Instruction]):
"""
Represents a block of instructions i.e. a sequential list of instructions.
"""

# XXX: swap workchain and instructions
def __init__(self, instructions: Sequence[_Instruction | WC_COMMAND_TYPE]) -> None:
# Build up the list of commands
comms: MutableSequence[_Instruction | _FunctionCall] = []
comms: MutableSequence[_Instruction] = []
for instruction in instructions:
if not isinstance(instruction, _Instruction):
# Assume it's a function call
comms.append(_FunctionCall(instruction))
else:
comms.append(instruction)

self._instruction: MutableSequence[_Instruction | _FunctionCall] = comms
self._instruction: MutableSequence[_Instruction] = comms

def __getitem__(self, index: int) -> _Instruction | _FunctionCall: # type: ignore
def __getitem__(self, index: int) -> _Instruction: # type: ignore
return self._instruction[index]

def __len__(self) -> int:
Expand Down Expand Up @@ -625,7 +624,7 @@ def __str__(self) -> str:
return string


class _If(_Instruction, collections.abc.Sequence):
class _If(_Instruction, Sequence[_Conditional]):
def __init__(self, condition: PREDICATE_TYPE) -> None:
super().__init__()
self._ifs: list[_Conditional] = [_Conditional(self, condition, label=if_.__name__)]
Expand Down Expand Up @@ -668,6 +667,8 @@ def recreate_stepper(self, saved_state: SAVED_STATE_TYPE, workchain: 'WorkChain'
return _IfStepper.recreate_from(saved_state, load_context)

def get_description(self) -> Mapping[str, Any]:
import collections

description = collections.OrderedDict()

description[f'if({self._ifs[0].predicate.__name__})'] = self._ifs[0].body.get_description()
Expand Down Expand Up @@ -737,7 +738,7 @@ def __str__(self) -> str:
return string


class _While(_Conditional, _Instruction, collections.abc.Sequence):
class _While(_Conditional, _Instruction, Sequence[_Conditional]):
def __init__(self, predicate: PREDICATE_TYPE) -> None:
super().__init__(self, predicate, label=while_.__name__)

Expand Down

0 comments on commit c6d5b8f

Please sign in to comment.