Skip to content

Commit

Permalink
Using msgpack for passing control messages
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Jan 30, 2025
1 parent 498a8d4 commit 4f8a832
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 17 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ keywords = ['workflow', 'multithreaded', 'rabbitmq']
requires-python = '>=3.10'
dependencies = [
'kiwipy[rmq]~=0.8.5',
'msgpack~=1.1',
'nest_asyncio~=1.5,>=1.5.1',
'pyyaml~=6.0',
'typing-extensions~=4.12'
Expand Down Expand Up @@ -131,6 +132,7 @@ module = [
'aiocontextvars.*',
'frozendict.*',
'kiwipy.*',
'msgpack.*',
'nest_asyncio.*',
'tblib.*',
]
Expand Down
14 changes: 0 additions & 14 deletions src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import inspect
import os
import pickle
import uuid
from collections.abc import Generator, Iterable
from typing import (
TYPE_CHECKING,
Expand All @@ -33,19 +32,6 @@
from .processes import Process


def uuid_representer(dumper, data): # type: ignore
return dumper.represent_scalar('!uuid', str(data))


def uuid_constructor(loader, node): # type: ignore
value = loader.construct_scalar(node)
return uuid.UUID(value)


yaml.add_representer(uuid.UUID, uuid_representer)
yaml.add_constructor('!uuid', uuid_constructor)


class LoadSaveContext:
def __init__(self, loader: loaders.ObjectLoader | None = None, **kwargs: Any) -> None:
self._values = dict(**kwargs)
Expand Down
44 changes: 44 additions & 0 deletions src/plumpy/rmq/process_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
"""Module for process level communication functions and classes"""

import asyncio
import functools
import uuid
from typing import Any, Dict, Hashable, Optional, Sequence, Union

import kiwipy
import msgpack
import yaml

from plumpy import loaders
from plumpy.coordinator import Coordinator
Expand All @@ -27,6 +31,46 @@
ProcessStatus = Any


# Define yaml type represender and constructor for UUID type handling in message passing
# NOTE: it is recommend to use msgpack for sending message, the yaml is only here for reference.
def uuid_representer(dumper, data): # type: ignore
return dumper.represent_scalar('!uuid', str(data))


def uuid_constructor(loader, node): # type: ignore
value = loader.construct_scalar(node)
return uuid.UUID(value)


yaml.add_representer(uuid.UUID, uuid_representer)
yaml.add_constructor('!uuid', uuid_constructor)

YAML_ENCODER = functools.partial(yaml.dump, encoding='utf-8')
YAML_DECODER = functools.partial(yaml.load, Loader=yaml.FullLoader)

# Define ext hook for msgpack to handle UUID type in message passing

UUID_EXT_CODE = 42 # Just pick any integer < 128


def default_uuid_ext(obj: Any) -> msgpack.ExtType:
"""Convert UUID objects into a custom msgpack.ExtType."""
if isinstance(obj, uuid.UUID):
return msgpack.ExtType(UUID_EXT_CODE, obj.bytes)
raise TypeError(f'Cannot serialize type {type(obj)}')


def ext_hook(code: Any, data: bytes | None) -> Any:
"""Recreate the object from the custom msgpack.ExtType."""
if code == UUID_EXT_CODE:
return uuid.UUID(bytes=data)
return msgpack.ExtType(code, data)


MSGPACK_ENCODER = functools.partial(msgpack.packb, default=default_uuid_ext)
MSGPACK_DECODER = functools.partial(msgpack.unpackb, ext_hook=ext_hook)


# FIXME: the class not fit typing of ProcessController protocol
class RemoteProcessController:
"""
Expand Down
7 changes: 4 additions & 3 deletions tests/rmq/test_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import uuid
import pytest
import shortuuid
import yaml
# import yaml
import msgpack

from kiwipy.rmq import RmqThreadCommunicator

Expand Down Expand Up @@ -41,8 +42,8 @@ def _coordinator():
message_exchange=message_exchange,
task_exchange=task_exchange,
task_queue=task_queue,
encoder=functools.partial(yaml.dump, encoding='utf-8'),
decoder=functools.partial(yaml.load, Loader=yaml.FullLoader),
encoder=process_control.MSGPACK_ENCODER,
decoder=process_control.MSGPACK_DECODER,
)

loop = asyncio.get_event_loop()
Expand Down
54 changes: 54 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 4f8a832

Please sign in to comment.