Skip to content

Commit

Permalink
Create/Launch/Continue body into builder (#26)
Browse files Browse the repository at this point in the history
Using the same interface for creating all message type, create/launch/continue, play/pause/kill/status
It revert the message API did by using MessageBuilder, it cannot handle `MessageBuilder.continue`, it is not a good API design.
  • Loading branch information
unkcpz authored Jan 31, 2025
1 parent 4edd4df commit 0730394
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 149 deletions.
28 changes: 9 additions & 19 deletions src/plumpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from .futures import CancellableAction, Future, capture_exceptions, create_task
from .loaders import DefaultObjectLoader, ObjectLoader, get_object_loader, set_object_loader
from .message import MessageBuilder, ProcessLauncher, create_continue_body, create_launch_body
from .message import Message, MsgContinue, MsgCreate, MsgKill, MsgLaunch, MsgPause, MsgPlay, MsgStatus, ProcessLauncher
from .persistence import (
Bundle,
InMemoryPersister,
Expand Down Expand Up @@ -64,26 +64,17 @@
from .workchains import ToContext, WorkChain, WorkChainSpec, if_, return_, while_

__all__ = (
# ports
'UNSPECIFIED',
# utils
'AttributesDict',
# persistence
'Bundle',
# processes
'BundleKeys',
# futures
'CancellableAction',
# exceptions
'ClosedError',
# process_states/States
'Continue',
# coordinator
'Coordinator',
'CoordinatorConnectionError',
'CoordinatorTimeoutError',
'Created',
# loaders
'DefaultObjectLoader',
'Excepted',
'Finished',
Expand All @@ -92,39 +83,40 @@
'InputPort',
'Interruption',
'InvalidStateError',
# process_states/Commands
'Kill',
'KillInterruption',
'Killed',
'KilledError',
'LoadSaveContext',
# message
'MessageBuilder',
'Message',
'MsgContinue',
'MsgCreate',
'MsgKill',
'MsgLaunch',
'MsgPause',
'MsgPlay',
'MsgStatus',
'ObjectLoader',
'OutputPort',
'PauseInterruption',
'PersistedCheckpoint',
'PersistenceError',
'Persister',
'PicklePersister',
# event
'PlumpyEventLoopPolicy',
'Port',
'PortNamespace',
'PortValidationError',
'Process',
# controller
'ProcessController',
'ProcessLauncher',
# process_listener
'ProcessListener',
'ProcessSpec',
'ProcessState',
'Running',
'Savable',
'SavableFuture',
'Stop',
# workchain
'ToContext',
'TransitionFailed',
'UnsuccessfulResult',
Expand All @@ -134,8 +126,6 @@
'WorkChainSpec',
'auto_persist',
'capture_exceptions',
'create_continue_body',
'create_launch_body',
'create_task',
'get_event_loop',
'get_object_loader',
Expand Down
184 changes: 103 additions & 81 deletions src/plumpy/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,124 +48,146 @@ class Intent:
LOGGER = logging.getLogger(__name__)

MessageType = dict[str, Any]
Message = dict[str, Any]


class MessageBuilder:
"""MessageBuilder will construct different messages that can passing over coordinator."""

class MsgPlay:
@classmethod
def play(cls, text: str | None = None) -> MessageType:
def new(cls, text: str | None = None) -> Message:
"""The play message send over coordinator."""
return {
INTENT_KEY: Intent.PLAY,
MESSAGE_TEXT_KEY: text,
}


class MsgPause:
"""
The 'pause' message sent over a coordinator.
"""

@classmethod
def pause(cls, text: str | None = None) -> MessageType:
"""The pause message send over coordinator."""
def new(cls, text: str | None = None) -> MessageType:
return {
INTENT_KEY: Intent.PAUSE,
MESSAGE_TEXT_KEY: text,
}


class MsgKill:
"""
The 'kill' message sent over a coordinator.
"""

@classmethod
def kill(cls, text: str | None = None, force_kill: bool = False) -> MessageType:
"""The kill message send over coordinator."""
def new(cls, text: str | None = None, force_kill: bool = False) -> MessageType:
return {
INTENT_KEY: Intent.KILL,
MESSAGE_TEXT_KEY: text,
FORCE_KILL_KEY: force_kill,
}


class MsgStatus:
"""
The 'status' message sent over a coordinator.
"""

@classmethod
def status(cls, text: str | None = None) -> MessageType:
"""The status message send over coordinator."""
def new(cls, text: str | None = None) -> MessageType:
return {
INTENT_KEY: Intent.STATUS,
MESSAGE_TEXT_KEY: text,
}


def create_launch_body(
process_class: str,
init_args: Sequence[Any] | None = None,
init_kwargs: dict[str, Any] | None = None,
persist: bool = False,
loader: loaders.ObjectLoader | None = None,
nowait: bool = True,
) -> dict[str, Any]:
"""
Create a message body for the launch action
:param process_class: the class of the process to launch
:param init_args: any initialisation positional arguments
:param init_kwargs: any initialisation keyword arguments
:param persist: persist this process if True, otherwise don't
:param loader: the loader to use to load the persisted process
:param nowait: wait for the process to finish before completing the task, otherwise just return the PID
:return: a dictionary with the body of the message to launch the process
:rtype: dict
class MsgLaunch:
"""
if loader is None:
loader = loaders.get_object_loader()

msg_body = {
TASK_KEY: LAUNCH_TASK,
TASK_ARGS: {
PROCESS_CLASS_KEY: loader.identify_object(process_class),
PERSIST_KEY: persist,
NOWAIT_KEY: nowait,
ARGS_KEY: init_args,
KWARGS_KEY: init_kwargs,
},
}
return msg_body


def create_continue_body(pid: 'PID_TYPE', tag: str | None = None, nowait: bool = False) -> dict[str, Any]:
Create the message payload for the launch action.
"""
Create a message body to continue an existing process
:param pid: the pid of the existing process
:param tag: the optional persistence tag
:param nowait: wait for the process to finish before completing the task, otherwise just return the PID
:return: a dictionary with the body of the message to continue the process

"""
msg_body = {TASK_KEY: CONTINUE_TASK, TASK_ARGS: {PID_KEY: pid, NOWAIT_KEY: nowait, TAG_KEY: tag}}
return msg_body
@classmethod
def new(
cls,
process_class: str,
init_args: Sequence[Any] | None = None,
init_kwargs: dict[str, Any] | None = None,
persist: bool = False,
loader: 'loaders.ObjectLoader | None' = None,
nowait: bool = True,
) -> dict[str, Any]:
"""
Create a message body for the launch action
"""
if loader is None:
loader = loaders.get_object_loader()

return {
TASK_KEY: LAUNCH_TASK,
TASK_ARGS: {
PROCESS_CLASS_KEY: loader.identify_object(process_class),
PERSIST_KEY: persist,
NOWAIT_KEY: nowait,
ARGS_KEY: init_args,
KWARGS_KEY: init_kwargs,
},
}


def create_create_body(
process_class: str,
init_args: Sequence[Any] | None = None,
init_kwargs: dict[str, Any] | None = None,
persist: bool = False,
loader: loaders.ObjectLoader | None = None,
) -> dict[str, Any]:
class MsgContinue:
"""
Create the message payload to continue an existing process.
"""
Create a message body to create a new process
:param process_class: the class of the process to launch
:param init_args: any initialisation positional arguments
:param init_kwargs: any initialisation keyword arguments
:param persist: persist this process if True, otherwise don't
:param loader: the loader to use to load the persisted process
:return: a dictionary with the body of the message to launch the process

@classmethod
def new(
cls,
pid: 'PID_TYPE',
tag: str | None = None,
nowait: bool = False,
) -> dict[str, Any]:
"""
Create a message body to continue an existing process.
"""
return {
TASK_KEY: CONTINUE_TASK,
TASK_ARGS: {
PID_KEY: pid,
NOWAIT_KEY: nowait,
TAG_KEY: tag,
},
}


class MsgCreate:
"""
Create the message payload to create a new process.
"""
if loader is None:
loader = loaders.get_object_loader()

msg_body = {
TASK_KEY: CREATE_TASK,
TASK_ARGS: {
PROCESS_CLASS_KEY: loader.identify_object(process_class),
PERSIST_KEY: persist,
ARGS_KEY: init_args,
KWARGS_KEY: init_kwargs,
},
}
return msg_body

@classmethod
def new(
cls,
process_class: str,
init_args: Sequence[Any] | None = None,
init_kwargs: dict[str, Any] | None = None,
persist: bool = False,
loader: 'loaders.ObjectLoader | None' = None,
) -> dict[str, Any]:
"""
Create a message body to create a new process.
"""
if loader is None:
loader = loaders.get_object_loader()

return {
TASK_KEY: CREATE_TASK,
TASK_ARGS: {
PROCESS_CLASS_KEY: loader.identify_object(process_class),
PERSIST_KEY: persist,
ARGS_KEY: init_args,
KWARGS_KEY: init_kwargs,
},
}


class ProcessLauncher:
Expand Down
14 changes: 7 additions & 7 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing_extensions import Self, override

from plumpy.loaders import ObjectLoader
from plumpy.message import MessageBuilder, MessageType
from plumpy.message import Message, MsgKill, MsgPause
from plumpy.persistence import ensure_object_loader

try:
Expand Down Expand Up @@ -64,17 +64,17 @@ class Interruption(Exception): # noqa: N818
class KillInterruption(Interruption):
def __init__(self, msg_text: str | None):
super().__init__()
msg = MessageBuilder.kill(text=msg_text)
msg = MsgKill.new(text=msg_text)

self.msg: MessageType = msg
self.msg: Message = msg


class PauseInterruption(Interruption):
def __init__(self, msg_text: str | None):
super().__init__()
msg = MessageBuilder.pause(text=msg_text)
msg = MsgPause.new(text=msg_text)

self.msg: MessageType = msg
self.msg: Message = msg


# region Commands
Expand Down Expand Up @@ -104,7 +104,7 @@ def save(self, loader: ObjectLoader | None = None) -> SAVED_STATE_TYPE:

@auto_persist('msg')
class Kill(Command):
def __init__(self, msg: MessageType | None = None):
def __init__(self, msg: Message | None = None):
super().__init__()
self.msg = msg

Expand Down Expand Up @@ -633,7 +633,7 @@ class Killed:

is_terminal: ClassVar[bool] = True

def __init__(self, msg: MessageType | None):
def __init__(self, msg: Message | None):
"""
:param msg: Optional kill message
"""
Expand Down
Loading

0 comments on commit 0730394

Please sign in to comment.