diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index a12981a0..2eaee534 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -1,28 +1,21 @@ # -*- coding: utf-8 -*- """The state machine for processes""" -from __future__ import annotations - import enum import functools import inspect import logging import os import sys +from collections.abc import Iterable, Sequence from types import TracebackType from typing import ( Any, Callable, ClassVar, - Dict, Hashable, - Iterable, - List, - Optional, Protocol, - Sequence, - Type, - Union, + final, runtime_checkable, ) @@ -34,19 +27,48 @@ _LOGGER = logging.getLogger(__name__) -EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None] +EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, 'State | None'], None] + + +@runtime_checkable +class State(Protocol): + LABEL: ClassVar[Any] + ALLOWED: ClassVar[set[Any]] + is_terminal: ClassVar[bool] + + def __init__(self, *args: Any, **kwargs: Any): ... + + def enter(self) -> None: ... + + def exit(self) -> None: ... + + +@runtime_checkable +class Interruptable(Protocol): + def interrupt(self, reason: Exception) -> None: ... + + +@runtime_checkable +class Proceedable(Protocol): + def execute(self) -> State | None: + """ + Execute the state, performing the actions that this state is responsible for. + :returns: a state to transition to or None if finished. + """ + ... class StateMachineError(Exception): """Base class for state machine errors""" +@final class StateEntryFailed(Exception): # noqa: N818 """ Failed to enter a state, can provide the next state to go to via this exception """ - def __init__(self, state: State, *args: Any, **kwargs: Any) -> None: + def __init__(self, state: 'State', *args: Any, **kwargs: Any) -> None: super().__init__('failed to enter state') self.state = state self.args = args @@ -63,11 +85,12 @@ def __init__(self, evt: str, msg: str): self.event = evt +@final class TransitionFailed(Exception): # noqa: N818 """A state transition failed""" def __init__( - self, initial_state: 'State', final_state: Optional['State'] = None, traceback_str: Optional[str] = None + self, initial_state: 'State', final_state: 'State | None' = None, traceback_str: str | None = None ) -> None: self.initial_state = initial_state self.final_state = final_state @@ -82,8 +105,8 @@ def _format_msg(self) -> str: def event( - from_states: Union[str, Type['State'], Iterable[Type['State']]] = '*', - to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*', + from_states: str | type['State'] | Iterable[type['State']] = '*', + to_states: str | type['State'] | Iterable[type['State']] = '*', ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """A decorator to check for correct transitions, raising ``EventError`` on invalid transitions.""" if from_states != '*': @@ -115,7 +138,7 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: raise EventError( evt_label, - 'Event produced invalid state transition from ' f'{initial.LABEL} to {self._state.LABEL}', + f'Event produced invalid state transition from {initial.LABEL} to {self._state.LABEL}', ) return result @@ -128,35 +151,7 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: return wrapper -@runtime_checkable -class State(Protocol): - LABEL: ClassVar[Any] - ALLOWED: ClassVar[set[Any]] - is_terminal: ClassVar[bool] - - def __init__(self, *args: Any, **kwargs: Any): ... - - def enter(self) -> None: ... - - def exit(self) -> None: ... - - -@runtime_checkable -class Interruptable(Protocol): - def interrupt(self, reason: Exception) -> None: ... - - -@runtime_checkable -class Proceedable(Protocol): - def execute(self) -> State | None: - """ - Execute the state, performing the actions that this state is responsible for. - :returns: a state to transition to or None if finished. - """ - ... - - -def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State: +def create_state(st: 'StateMachine', state_label: Hashable, *args: Any, **kwargs: Any) -> State: if state_label not in st.get_states_map(): raise ValueError(f'{state_label} is not a valid state') @@ -192,20 +187,20 @@ def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine': class StateMachine(metaclass=StateMachineMeta): - STATES: Optional[Sequence[Type[State]]] = None - _STATES_MAP: Optional[Dict[Hashable, Type[State]]] = None + STATES: Sequence[type[State]] | None = None + _STATES_MAP: dict[Hashable, type[State]] | None = None - _transitioning = False - _transition_failing = False + _transitioning: bool = False + _transition_failing: bool = False @classmethod - def get_states_map(cls) -> Dict[Hashable, Type[State]]: + def get_states_map(cls) -> dict[Hashable, type[State]]: cls.__ensure_built() assert cls._STATES_MAP is not None # required for type checking return cls._STATES_MAP @classmethod - def get_states(cls) -> Sequence[Type[State]]: + def get_states(cls) -> Sequence[type[State]]: if cls.STATES is not None: return cls.STATES @@ -218,7 +213,7 @@ def initial_state_label(cls) -> Any: return cls.STATES[0].LABEL @classmethod - def get_state_class(cls, label: Any) -> Type[State]: + def get_state_class(cls, label: Any) -> type[State]: cls.__ensure_built() assert cls._STATES_MAP is not None return cls._STATES_MAP[label] @@ -249,11 +244,11 @@ def __ensure_built(cls) -> None: def __init__(self) -> None: super().__init__() self.__ensure_built() - self._state: Optional[State] = None + self._state: State | None = None self._exception_handler = None # Note this appears to never be used self.set_debug((not sys.flags.ignore_environment and bool(os.environ.get('PYTHONSMDEBUG')))) self._transitioning = False - self._event_callbacks: Dict[Hashable, List[EVENT_CALLBACK_TYPE]] = {} + self._event_callbacks: dict[Hashable, list[EVENT_CALLBACK_TYPE]] = {} @super_check def init(self) -> None: @@ -298,7 +293,7 @@ def remove_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_T except (KeyError, ValueError): raise ValueError(f"Callback not set for hook '{hook}'") - def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: + def _fire_state_event(self, hook: Hashable, state: State | None) -> None: for callback in self._event_callbacks.get(hook, []): callback(self, hook, state) diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py index 702ea5f5..cffd08b5 100644 --- a/src/plumpy/coordinator.py +++ b/src/plumpy/coordinator.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Hashable, Pattern, Protocol +import re +from typing import TYPE_CHECKING, Any, Callable, Hashable, Protocol if TYPE_CHECKING: # identifiers for subscribers @@ -23,8 +24,8 @@ def add_rpc_subscriber(self, subscriber: 'RpcSubscriber', identifier: 'ID_TYPE | def add_broadcast_subscriber( self, subscriber: 'BroadcastSubscriber', - subject_filters: list[Hashable | Pattern[str]] | None = None, - sender_filters: list[Hashable | Pattern[str]] | None = None, + subject_filters: list[Hashable | re.Pattern[str]] | None = None, + sender_filters: list[Hashable | re.Pattern[str]] | None = None, identifier: 'ID_TYPE | None' = None, ) -> Any: ... diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index 47188031..1a55939c 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -2,7 +2,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Callable, Optional, final +from typing import TYPE_CHECKING, Any, Callable, final from typing_extensions import Self @@ -38,7 +38,7 @@ def remove_all_listeners(self) -> None: self._listeners.clear() @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Self: + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> Self: """ Recreate a :class:`Savable` from a saved state using an optional load context. diff --git a/src/plumpy/events.py b/src/plumpy/events.py index a6e62529..7379a6b6 100644 --- a/src/plumpy/events.py +++ b/src/plumpy/events.py @@ -3,7 +3,7 @@ import asyncio import sys -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence +from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence if TYPE_CHECKING: from .processes import Process @@ -22,7 +22,7 @@ def new_event_loop(*args: Any, **kwargs: Any) -> asyncio.AbstractEventLoop: class PlumpyEventLoopPolicy(asyncio.DefaultEventLoopPolicy): """Custom event policy that always returns the same event loop that is made reentrant by ``nest_asyncio``.""" - _loop: Optional[asyncio.AbstractEventLoop] = None + _loop: asyncio.AbstractEventLoop | None = None def get_event_loop(self) -> asyncio.AbstractEventLoop: """Return the patched event loop.""" @@ -55,7 +55,7 @@ def reset_event_loop_policy() -> None: asyncio.set_event_loop_policy(None) -def run_until_complete(future: asyncio.Future, loop: Optional[asyncio.AbstractEventLoop] = None) -> Any: +def run_until_complete(future: asyncio.Future, loop: asyncio.AbstractEventLoop | None = None) -> Any: loop = loop or get_event_loop() return loop.run_until_complete(future) diff --git a/src/plumpy/exceptions.py b/src/plumpy/exceptions.py index b4358770..7a559b07 100644 --- a/src/plumpy/exceptions.py +++ b/src/plumpy/exceptions.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- -from typing import Optional + + +from typing import final class KilledError(Exception): @@ -12,10 +14,11 @@ class InvalidStateError(Exception): """ +@final class UnsuccessfulResult: """The result of the process was unsuccessful""" - def __init__(self, result: Optional[int] = None): + def __init__(self, result: int | None = None): """Initialise. :param result: the exit code of the process diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index 3a59351d..e9d7e928 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -7,7 +7,8 @@ import asyncio import contextlib -from typing import Any, Awaitable, Callable, Generator, Optional +from collections.abc import Awaitable, Generator +from typing import Any, Callable, final class InvalidFutureError(Exception): @@ -33,6 +34,7 @@ def capture_exceptions(future, ignore: tuple[type[BaseException], ...] = ()) -> future.set_exception(exception) +@final class CancellableAction(Future): """ An action that can be launched and potentially cancelled @@ -64,7 +66,7 @@ def run(self, *args: Any, **kwargs: Any) -> None: self._action = None # type: ignore -def create_task(coro: Callable[[], Awaitable[Any]], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future: +def create_task(coro: Callable[[], Awaitable[Any]], loop: asyncio.AbstractEventLoop | None = None) -> Future: """ Schedule a call to a coro in the event loop and wrap the outcome in a future. diff --git a/src/plumpy/loaders.py b/src/plumpy/loaders.py index bb248d6a..87b3b169 100644 --- a/src/plumpy/loaders.py +++ b/src/plumpy/loaders.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import abc import importlib -from typing import Any, Optional +from typing import Any class ObjectLoader(metaclass=abc.ABCMeta): @@ -62,7 +62,7 @@ def identify_object(self, obj: Any) -> str: return identifier -OBJECT_LOADER: Optional[ObjectLoader] = None +OBJECT_LOADER: ObjectLoader | None = None def get_object_loader() -> ObjectLoader: @@ -78,7 +78,7 @@ def get_object_loader() -> ObjectLoader: return OBJECT_LOADER -def set_object_loader(loader: Optional[ObjectLoader]) -> None: +def set_object_loader(loader: ObjectLoader | None) -> None: """ Set the plumpy global object loader diff --git a/src/plumpy/message.py b/src/plumpy/message.py index 04f03bd9..1ed43845 100644 --- a/src/plumpy/message.py +++ b/src/plumpy/message.py @@ -5,7 +5,8 @@ import asyncio import logging -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, cast from plumpy.coordinator import Coordinator from plumpy.exceptions import PersistenceError, TaskRejectedError @@ -48,7 +49,7 @@ class Intent: LOGGER = logging.getLogger(__name__) -MessageType = Dict[str, Any] +MessageType = dict[str, Any] class MessageBuilder: @@ -90,12 +91,12 @@ def status(cls, text: str | None = None) -> MessageType: def create_launch_body( process_class: str, - init_args: Optional[Sequence[Any]] = None, - init_kwargs: Optional[Dict[str, Any]] = None, + init_args: Sequence[Any] | None = None, + init_kwargs: dict[str, Any] | None = None, persist: bool = False, - loader: Optional[loaders.ObjectLoader] = None, + loader: loaders.ObjectLoader | None = None, nowait: bool = True, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Create a message body for the launch action @@ -124,7 +125,7 @@ def create_launch_body( return msg_body -def create_continue_body(pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False) -> Dict[str, Any]: +def create_continue_body(pid: 'PID_TYPE', tag: str | None = None, nowait: bool = False) -> dict[str, Any]: """ Create a message body to continue an existing process :param pid: the pid of the existing process @@ -139,11 +140,11 @@ def create_continue_body(pid: 'PID_TYPE', tag: Optional[str] = None, nowait: boo def create_create_body( process_class: str, - init_args: Optional[Sequence[Any]] = None, - init_kwargs: Optional[Dict[str, Any]] = None, + init_args: Sequence[Any] | None = None, + init_kwargs: dict[str, Any] | None = None, persist: bool = False, - loader: Optional[loaders.ObjectLoader] = None, -) -> Dict[str, Any]: + loader: loaders.ObjectLoader | None = None, +) -> dict[str, Any]: """ Create a message body to create a new process :param process_class: the class of the process to launch @@ -196,10 +197,10 @@ class ProcessLauncher: def __init__( self, - loop: Optional[asyncio.AbstractEventLoop] = None, - persister: Optional[persistence.Persister] = None, - load_context: Optional[persistence.LoadSaveContext] = None, - loader: Optional[loaders.ObjectLoader] = None, + loop: asyncio.AbstractEventLoop | None = None, + persister: persistence.Persister | None = None, + load_context: persistence.LoadSaveContext | None = None, + loader: loaders.ObjectLoader | None = None, ) -> None: self._loop = loop self._persister = persister @@ -211,7 +212,7 @@ def __init__( else: self._loader = loaders.get_object_loader() - async def __call__(self, coordinator: Coordinator, task: Dict[str, Any]) -> Union[PID_TYPE, Any]: + async def __call__(self, coordinator: Coordinator, task: dict[str, Any]) -> PID_TYPE | Any: """ Receive a task. :param task: The task message @@ -231,9 +232,9 @@ async def _launch( process_class: str, persist: bool, nowait: bool, - init_args: Optional[Sequence[Any]] = None, - init_kwargs: Optional[Dict[str, Any]] = None, - ) -> Union[PID_TYPE, Any]: + init_args: Sequence[Any] | None = None, + init_kwargs: dict[str, Any] | None = None, + ) -> PID_TYPE | Any: """ Launch the process @@ -266,7 +267,7 @@ async def _launch( return proc.future().result() - async def _continue(self, pid: 'PID_TYPE', nowait: bool, tag: Optional[str] = None) -> Union[PID_TYPE, Any]: + async def _continue(self, pid: 'PID_TYPE', nowait: bool, tag: str | None = None) -> PID_TYPE | Any: """ Continue the process @@ -295,8 +296,8 @@ async def _create( self, process_class: str, persist: bool, - init_args: Optional[Sequence[Any]] = None, - init_kwargs: Optional[Dict[str, Any]] = None, + init_args: Sequence[Any] | None = None, + init_kwargs: dict[str, Any] | None = None, ) -> 'PID_TYPE': """ Create the process diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 02b5ff76..fb93ca9f 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -10,16 +10,12 @@ import inspect import os import pickle +from collections.abc import Generator, Iterable from typing import ( TYPE_CHECKING, Any, Callable, ClassVar, - Dict, - Generator, - Iterable, - List, - Optional, Protocol, TypeVar, runtime_checkable, @@ -38,7 +34,7 @@ class LoadSaveContext: - def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None: + def __init__(self, loader: loaders.ObjectLoader | None = None, **kwargs: Any) -> None: self._values = dict(**kwargs) self.loader = loader @@ -134,7 +130,7 @@ def _bundle_constructor(loader: yaml.Loader, data: Any) -> Generator[Bundle, Non class Persister(metaclass=abc.ABCMeta): @abc.abstractmethod - def save_checkpoint(self, process: 'Process', tag: Optional[str] = None) -> None: + def save_checkpoint(self, process: 'Process', tag: str | None = None) -> None: """ Persist a Process instance @@ -145,7 +141,7 @@ def save_checkpoint(self, process: 'Process', tag: Optional[str] = None) -> None """ @abc.abstractmethod - def load_checkpoint(self, pid: PID_TYPE, tag: Optional[str] = None) -> Bundle: + def load_checkpoint(self, pid: PID_TYPE, tag: str | None = None) -> Bundle: """ Load a process from a persisted checkpoint by its process id @@ -158,7 +154,7 @@ def load_checkpoint(self, pid: PID_TYPE, tag: Optional[str] = None) -> Bundle: """ @abc.abstractmethod - def get_checkpoints(self) -> List[PersistedCheckpoint]: + def get_checkpoints(self) -> list[PersistedCheckpoint]: """ Return a list of all the current persisted process checkpoints with each element containing the process id and optional checkpoint tag @@ -167,7 +163,7 @@ def get_checkpoints(self) -> List[PersistedCheckpoint]: """ @abc.abstractmethod - def get_process_checkpoints(self, pid: PID_TYPE) -> List[PersistedCheckpoint]: + def get_process_checkpoints(self, pid: PID_TYPE) -> list[PersistedCheckpoint]: """ Return a list of all the current persisted process checkpoints for the specified process with each element containing the process id and @@ -178,7 +174,7 @@ def get_process_checkpoints(self, pid: PID_TYPE) -> List[PersistedCheckpoint]: """ @abc.abstractmethod - def delete_checkpoint(self, pid: PID_TYPE, tag: Optional[str] = None) -> None: + def delete_checkpoint(self, pid: PID_TYPE, tag: str | None = None) -> None: """ Delete a persisted process checkpoint. No error will be raised if the checkpoint does not exist @@ -251,7 +247,7 @@ def load_pickle(filepath: str) -> 'PersistedPickle': return persisted_pickle @staticmethod - def pickle_filename(pid: PID_TYPE, tag: Optional[str] = None) -> str: + def pickle_filename(pid: PID_TYPE, tag: str | None = None) -> str: """ Returns the relative filepath of the pickle for the given process id and optional checkpoint tag @@ -263,14 +259,14 @@ def pickle_filename(pid: PID_TYPE, tag: Optional[str] = None) -> str: return filename - def _pickle_filepath(self, pid: PID_TYPE, tag: Optional[str] = None) -> str: + def _pickle_filepath(self, pid: PID_TYPE, tag: str | None = None) -> str: """ Returns the full filepath of the pickle for the given process id and optional checkpoint tag """ return os.path.join(self._pickle_directory, PicklePersister.pickle_filename(pid, tag)) - def save_checkpoint(self, process: 'Process', tag: Optional[str] = None) -> None: + def save_checkpoint(self, process: 'Process', tag: str | None = None) -> None: """ Persist a process to a pickle on disk @@ -285,7 +281,7 @@ def save_checkpoint(self, process: 'Process', tag: Optional[str] = None) -> None with open(self._pickle_filepath(process.pid, tag), 'w+b') as handle: pickle.dump(persisted_pickle, handle) - def load_checkpoint(self, pid: PID_TYPE, tag: Optional[str] = None) -> Bundle: + def load_checkpoint(self, pid: PID_TYPE, tag: str | None = None) -> Bundle: """ Load a process from a persisted checkpoint by its process id @@ -300,7 +296,7 @@ def load_checkpoint(self, pid: PID_TYPE, tag: Optional[str] = None) -> Bundle: return checkpoint.bundle - def get_checkpoints(self) -> List[PersistedCheckpoint]: + def get_checkpoints(self) -> list[PersistedCheckpoint]: """ Return a list of all the current persisted process checkpoints with each element containing the process id and optional checkpoint tag @@ -318,7 +314,7 @@ def get_checkpoints(self) -> List[PersistedCheckpoint]: return checkpoints - def get_process_checkpoints(self, pid: PID_TYPE) -> List[PersistedCheckpoint]: + def get_process_checkpoints(self, pid: PID_TYPE) -> list[PersistedCheckpoint]: """ Return a list of all the current persisted process checkpoints for the specified process with each element containing the process id and @@ -329,7 +325,7 @@ def get_process_checkpoints(self, pid: PID_TYPE) -> List[PersistedCheckpoint]: """ return [c for c in self.get_checkpoints() if c.pid == pid] - def delete_checkpoint(self, pid: PID_TYPE, tag: Optional[str] = None) -> None: + def delete_checkpoint(self, pid: PID_TYPE, tag: str | None = None) -> None: """ Delete a persisted process checkpoint. No error will be raised if the checkpoint does not exist @@ -358,26 +354,26 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: class InMemoryPersister(Persister): """Mainly to be used in testing/debugging""" - def __init__(self, loader: Optional[loaders.ObjectLoader] = None) -> None: + def __init__(self, loader: loaders.ObjectLoader | None = None) -> None: super().__init__() - self._checkpoints: Dict[PID_TYPE, Dict[Optional[str], Bundle]] = {} + self._checkpoints: dict[PID_TYPE, dict[str | None, Bundle]] = {} self._save_context = LoadSaveContext(loader=loader) - def save_checkpoint(self, process: 'Process', tag: Optional[str] = None) -> None: + def save_checkpoint(self, process: 'Process', tag: str | None = None) -> None: self._checkpoints.setdefault(process.pid, {})[tag] = Bundle( process, self._save_context.loader, dereference=True ) - def load_checkpoint(self, pid: PID_TYPE, tag: Optional[str] = None) -> Bundle: + def load_checkpoint(self, pid: PID_TYPE, tag: str | None = None) -> Bundle: return self._checkpoints[pid][tag] - def get_checkpoints(self) -> List[PersistedCheckpoint]: + def get_checkpoints(self) -> list[PersistedCheckpoint]: cps = [] for pid in self._checkpoints: cps.extend(self.get_process_checkpoints(pid)) return cps - def get_process_checkpoints(self, pid: PID_TYPE) -> List[PersistedCheckpoint]: + def get_process_checkpoints(self, pid: PID_TYPE) -> list[PersistedCheckpoint]: cps = [] try: for tag, _ in self._checkpoints[pid].items(): @@ -386,7 +382,7 @@ def get_process_checkpoints(self, pid: PID_TYPE) -> List[PersistedCheckpoint]: pass return cps - def delete_checkpoint(self, pid: PID_TYPE, tag: Optional[str] = None) -> None: + def delete_checkpoint(self, pid: PID_TYPE, tag: str | None = None) -> None: try: del self._checkpoints[pid][tag] except KeyError: @@ -397,7 +393,7 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: del self._checkpoints[pid] -def ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': +def ensure_object_loader(context: 'LoadSaveContext' | None, saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': """ Given a LoadSaveContext this method will ensure that it has a valid class loader using the following priorities: @@ -453,7 +449,7 @@ def get_custom_meta(saved_state: SAVED_STATE_TYPE, name: str) -> Any: raise ValueError(f"Unknown meta key '{name}'") @staticmethod - def get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: + def get_create_meta(out_state: SAVED_STATE_TYPE) -> dict[str, Any]: return out_state.setdefault(META, {}) @staticmethod @@ -605,7 +601,7 @@ def save(self, loader: loaders.ObjectLoader | None = None) -> SAVED_STATE_TYPE: return out_state @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Self: + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> Self: """ Recreate a :class:`Savable` from a saved state using an optional load context. diff --git a/src/plumpy/ports.py b/src/plumpy/ports.py index 8522f061..d13ce5ad 100644 --- a/src/plumpy/ports.py +++ b/src/plumpy/ports.py @@ -1,13 +1,13 @@ # -*- coding: utf-8 -*- """Module for process ports""" -import collections import copy import inspect import json import logging import warnings -from typing import Any, Callable, Dict, Iterator, List, Mapping, MutableMapping, Optional, Sequence, Type, Union, cast +from collections.abc import Iterator, Mapping, MutableMapping, Sequence +from typing import Any, Callable, cast from plumpy.utils import AttributesFrozendict, is_mutable_property, type_check @@ -18,7 +18,7 @@ This has been deprecated and the new signature is `validator(value, port)` where the `port` argument will be the port instance to which the validator has been assigned.""" -VALIDATOR_TYPE = Callable[[Any, 'Port'], Optional[str]] +VALIDATOR_TYPE = Callable[[Any, 'Port'], str | None] class PortValidationError(Exception): @@ -64,10 +64,10 @@ class Port: def __init__( self, name: str, - valid_type: Optional[Type[Any]] = None, - help: Optional[str] = None, + valid_type: type[Any] | None = None, + help: str | None = None, required: bool = True, - validator: Optional[VALIDATOR_TYPE] = None, + validator: VALIDATOR_TYPE | None = None, ) -> None: self._name = name self._valid_type = valid_type @@ -83,7 +83,7 @@ def __str__(self) -> str: """ return json.dumps(self.get_description()) - def get_description(self) -> Dict[str, Any]: + def get_description(self) -> dict[str, Any]: """Return a description of the Port, which will be a dictionary of its attributes :returns: a dictionary of the stringified Port attributes @@ -106,7 +106,7 @@ def name(self) -> str: return self._name @property - def valid_type(self) -> Optional[Type[Any]]: + def valid_type(self) -> type[Any] | None: """Get the valid value type for this port if one is specified :return: the value value type @@ -115,7 +115,7 @@ def valid_type(self) -> Optional[Type[Any]]: return self._valid_type @valid_type.setter - def valid_type(self, valid_type: Optional[Type[Any]]) -> None: + def valid_type(self, valid_type: type[Any] | None) -> None: """Set the valid value type for this port :param valid_type: the value valid type @@ -124,7 +124,7 @@ def valid_type(self, valid_type: Optional[Type[Any]]) -> None: self._valid_type = valid_type @property - def help(self) -> Optional[str]: + def help(self) -> str | None: """Get the help string for this port :return: the help string @@ -133,7 +133,7 @@ def help(self) -> Optional[str]: return self._help @help.setter - def help(self, help: Optional[str]) -> None: + def help(self, help: str | None) -> None: """Set the help string for this port :param help: the help string @@ -160,16 +160,15 @@ def required(self, required: bool) -> None: self._required = required @property - def validator(self) -> Optional[VALIDATOR_TYPE]: + def validator(self) -> VALIDATOR_TYPE | None: """Get the validator for this port :return: the validator - :rtype: typing.Callable[[typing.Any], typing.Tuple[bool, typing.Optional[str]]] """ return self._validator @validator.setter - def validator(self, validator: Optional[VALIDATOR_TYPE]) -> None: + def validator(self, validator: VALIDATOR_TYPE | None) -> None: """Set the validator for this port :param validator: a validator function @@ -177,7 +176,7 @@ def validator(self, validator: Optional[VALIDATOR_TYPE]) -> None: """ self._validator = validator - def validate(self, value: Any, breadcrumbs: Sequence[str] = ()) -> Optional[PortValidationError]: + def validate(self, value: Any, breadcrumbs: Sequence[str] = ()) -> PortValidationError | None: """Validate a value to see if it is valid for this port :param value: the value to check @@ -231,11 +230,11 @@ def required_override(required: bool, default: Any) -> bool: def __init__( self, name: str, - valid_type: Optional[Type[Any]] = None, - help: Optional[str] = None, + valid_type: type[Any] | None = None, + help: str | None = None, default: Any = UNSPECIFIED, required: bool = True, - validator: Optional[VALIDATOR_TYPE] = None, + validator: VALIDATOR_TYPE | None = None, ) -> None: super().__init__( name, @@ -273,7 +272,7 @@ def default(self) -> Any: def default(self, default: Any) -> None: self._default = default - def get_description(self) -> Dict[str, str]: + def get_description(self) -> dict[str, str]: """ Return a description of the InputPort, which will be a dictionary of its attributes @@ -291,7 +290,7 @@ class OutputPort(Port): pass -class PortNamespace(collections.abc.MutableMapping, Port): +class PortNamespace(MutableMapping, Port): """ A container for Ports. Effectively it maintains a dictionary whose members are either a Port or yet another PortNamespace. This allows for the nesting of ports @@ -302,10 +301,10 @@ class PortNamespace(collections.abc.MutableMapping, Port): def __init__( self, name: str = '', # Note this was set to None, but that would fail if you tried to compute breadcrumbs - help: Optional[str] = None, + help: str | None = None, required: bool = True, - validator: Optional[VALIDATOR_TYPE] = None, - valid_type: Optional[Type[Any]] = None, + validator: VALIDATOR_TYPE | None = None, + valid_type: type[Any] | None = None, default: Any = UNSPECIFIED, dynamic: bool = False, populate_defaults: bool = True, @@ -326,7 +325,7 @@ def __init__( property is ignored and the population of defaults is always performed. """ super().__init__(name=name, help=help, required=required, validator=validator, valid_type=valid_type) - self._ports: Dict[str, Union[Port, 'PortNamespace']] = {} + self._ports: dict[str, 'Port | PortNamespace'] = {} self.default = default self.populate_defaults = populate_defaults self.valid_type = valid_type @@ -347,16 +346,16 @@ def __len__(self) -> int: def __delitem__(self, key: str) -> None: del self._ports[key] - def __getitem__(self, key: str) -> Union[Port, 'PortNamespace']: + def __getitem__(self, key: str) -> 'Port | PortNamespace': return self._ports[key] - def __setitem__(self, key: str, port: Union[Port, 'PortNamespace']) -> None: + def __setitem__(self, key: str, port: 'Port | PortNamespace') -> None: if not isinstance(port, Port): raise TypeError('port needs to be an instance of Port') self._ports[key] = port @property - def ports(self) -> Dict[str, Union[Port, 'PortNamespace']]: + def ports(self) -> dict[str, 'Port | PortNamespace']: return self._ports def has_default(self) -> bool: @@ -379,11 +378,11 @@ def dynamic(self, dynamic: bool) -> None: self._dynamic = dynamic @property - def valid_type(self) -> Optional[Type[Any]]: + def valid_type(self) -> type[Any] | None: return super().valid_type @valid_type.setter - def valid_type(self, valid_type: Optional[Type[Any]]) -> None: + def valid_type(self, valid_type: type[Any] | None) -> None: """Set the `valid_type` for the `PortNamespace`. If the `valid_type` is None, the `dynamic` property will be set to `False`, in all other cases `dynamic` will be @@ -404,7 +403,7 @@ def populate_defaults(self) -> bool: def populate_defaults(self, populate_defaults: bool) -> None: self._populate_defaults = populate_defaults - def get_description(self) -> Dict[str, Dict[str, Any]]: + def get_description(self) -> dict[str, dict[str, Any]]: """ Return a dictionary with a description of the ports this namespace contains Nested PortNamespaces will be properly recursed and Ports will print their properties in a list @@ -426,7 +425,7 @@ def get_description(self) -> Dict[str, Dict[str, Any]]: return description - def get_port(self, name: str, create_dynamically: bool = False) -> Union[Port, 'PortNamespace']: + def get_port(self, name: str, create_dynamically: bool = False) -> 'Port | PortNamespace': """ Retrieve a (namespaced) port from this PortNamespace. If any of the sub namespaces of the terminal port itself cannot be found, a ValueError will be raised @@ -510,10 +509,10 @@ def create_port_namespace(self, name: str, **kwargs: Any) -> 'PortNamespace': def absorb( self, port_namespace: 'PortNamespace', - exclude: Optional[Sequence[str]] = None, - include: Optional[Sequence[str]] = None, - namespace_options: Optional[Dict[str, Any]] = None, - ) -> List[str]: + exclude: Sequence[str] | None = None, + include: Sequence[str] | None = None, + namespace_options: dict[str, Any] | None = None, + ) -> list[str]: """Absorb another PortNamespace instance into oneself, including all its mutable properties and ports. Mutable properties of self will be overwritten with those of the port namespace that is to be absorbed. @@ -611,8 +610,8 @@ def project(self, port_values: MutableMapping[str, Any]) -> MutableMapping[str, return result def validate( - self, port_values: Optional[Mapping[str, Any]] = None, breadcrumbs: Sequence[str] = () - ) -> Optional[PortValidationError]: + self, port_values: Mapping[str, Any] | None = None, breadcrumbs: Sequence[str] = () + ) -> PortValidationError | None: """ Validate the namespace port itself and subsequently all the port_values it contains @@ -622,12 +621,12 @@ def validate( """ breadcrumbs_local = (*breadcrumbs, self.name) - message: Optional[str] + message: str | None if not port_values: port_values = {} - if not isinstance(port_values, collections.abc.Mapping): + if not isinstance(port_values, Mapping): message = f'specified value is of type {type(port_values)} which is not sub class of `Mapping`' return PortValidationError(message, breadcrumbs_to_port(breadcrumbs_local)) @@ -706,7 +705,7 @@ def pre_process(self, port_values: MutableMapping[str, Any]) -> AttributesFrozen def validate_ports( self, port_values: MutableMapping[str, Any], breadcrumbs: Sequence[str] - ) -> Optional[PortValidationError]: + ) -> PortValidationError | None: """ Validate port values with respect to the explicitly defined ports of the port namespace. Ports values that are matched to an actual Port will be popped from the dictionary @@ -725,7 +724,7 @@ def validate_ports( def validate_dynamic_ports( self, port_values: MutableMapping[str, Any], breadcrumbs: Sequence[str] = () - ) -> Optional[PortValidationError]: + ) -> PortValidationError | None: """ Validate port values with respect to the dynamic properties of the port namespace. It will check if the namespace is actually dynamic and if all values adhere to the valid types of @@ -736,7 +735,7 @@ def validate_dynamic_ports( :param breadcrumbs: a tuple of the path to having reached this point in validation :type breadcrumbs: typing.Tuple[str] :return: if invalid returns a string with the reason for the validation failure, otherwise None - :rtype: typing.Optional[str] + :rtype: str | None """ if port_values and not self.dynamic: msg = f'Unexpected ports {port_values}, for a non dynamic namespace' @@ -757,7 +756,7 @@ def validate_dynamic_ports( return None @staticmethod - def strip_namespace(namespace: str, separator: str, rules: Optional[Sequence[str]] = None) -> Optional[List[str]]: + def strip_namespace(namespace: str, separator: str, rules: Sequence[str] | None = None) -> list[str] | None: """Filter given exclude/include rules staring with namespace and strip the first level. For example if the namespace is `base` and the rules are:: diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index 8bc7c828..5a9098da 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -2,7 +2,7 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any from typing_extensions import Self @@ -22,13 +22,13 @@ class ProcessListener(metaclass=abc.ABCMeta): def __init__(self) -> None: super().__init__() - self._params: Dict[str, Any] = {} + self._params: dict[str, Any] = {} def init(self, **kwargs: Any) -> None: self._params = kwargs @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Self: + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> Self: """ Recreate a :class:`Savable` from a saved state using an optional load context. diff --git a/src/plumpy/utils.py b/src/plumpy/utils.py index cb75f7bd..f7cbcf58 100644 --- a/src/plumpy/utils.py +++ b/src/plumpy/utils.py @@ -17,7 +17,6 @@ MutableMapping, Optional, Tuple, - Type, ) from . import lang @@ -180,7 +179,7 @@ def load_module(fullname: str) -> Tuple[types.ModuleType, deque]: return mod, remainder -def type_check(obj: Any, expected_type: Type) -> None: +def type_check(obj: Any, expected_type: type[Any]) -> None: if not isinstance(obj, expected_type): raise TypeError(f"Got object of type '{type(obj)}' when expecting '{expected_type}'")