Skip to content

Commit

Permalink
feat(serialization): Implement thread-safe enabling/disabling of seri…
Browse files Browse the repository at this point in the history
…alization for `to_dict()` method
  • Loading branch information
gMatas committed Dec 29, 2024
1 parent c3084b3 commit 5700458
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 12 deletions.
18 changes: 13 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@

## About

EzSerialization is meant to be simple in features and usage. It follows these two ideas:
EzSerialization is meant to be simple in features and usage. It follows these three ideas:

- **Python dicts based**. This package only helps to serialize objects to dicts.
Converting them to JSON, XML, etc. is left to the user.
- **Transparent serialization logic**. It does not have automatic `from_dict()` & `to_dict()` methods that convert class
instances of any kind to dicts. Implementing them is left to the end-user, thus being transparent with what actually
happens with the user data.
happens with this data.
- **Thread-safe**. Serialization, deserialization & its enabling/disabling is thread-safe.

All EzSerialization do is it wraps `to_dict()` & `from_dict()` methods for selected classes to inject, register and
use class type information for deserialization.
Expand Down Expand Up @@ -53,7 +54,7 @@ Here's an example:
```python
from pprint import pprint
from typing import Mapping
from ezserialization import serializable, deserialize
from ezserialization import serializable, deserialize, no_serialization

@serializable
class Example:
Expand All @@ -69,14 +70,21 @@ class Example:


obj = Example("wow")
obj_dict = obj.to_dict()

# Serialization without ability to automatically deserialize:
with no_serialization():
raw_obj_dict = obj.to_dict()
pprint(raw_obj_dict, indent=2)
# Output:
# {'some_value': 'wow'}

# Serialization with ability automatic deserialization:
obj_dict = obj.to_dict()
pprint(obj_dict, indent=2)
# Output:
# {'_type_': '__main__.Example', 'some_value': 'wow'}

obj2 = deserialize(obj_dict)

print(obj.value == obj2.value)
# Output:
# True
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ezserialization"
version = "0.2.8"
version = "0.2.9"
description = "Simple, easy to use & transparent python objects serialization & deserialization."
authors = ["Matas Gumbinas <[email protected]>"]
repository = "https://github.com/gMatas/ezserialization"
Expand Down
49 changes: 45 additions & 4 deletions src/ezserialization/_serialization.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import contextlib
import functools
import importlib
import threading
from abc import abstractmethod
from typing import Callable, Dict, Mapping, Optional, Protocol, Type, TypeVar
from copy import copy
from typing import Callable, Dict, Iterator, Mapping, Optional, Protocol, Type, TypeVar, cast

__all__ = [
"TYPE_FIELD_NAME",
"using_serialization",
"use_serialization",
"no_serialization",
"Serializable",
"serializable",
"deserialize",
Expand Down Expand Up @@ -45,6 +51,39 @@ def _is_serializable_subclass(cls: Type) -> bool:
Serializable object type.
"""

_thread_local = threading.local()
_thread_local.enabled = (_SERIALIZATION_ENABLED_DEFAULT := True)
"""
Thread-safe serialization enabling/disabling flag.
"""


def using_serialization() -> bool:
if not hasattr(_thread_local, "enabled"):
_thread_local.enabled = _SERIALIZATION_ENABLED_DEFAULT
return cast(bool, _thread_local.enabled)


@contextlib.contextmanager
def use_serialization() -> Iterator[None]:
prev = using_serialization()
try:
_thread_local.enabled = _SERIALIZATION_ENABLED_DEFAULT
yield
finally:
_thread_local.enabled = prev


@contextlib.contextmanager
def no_serialization() -> Iterator[None]:
prev = using_serialization()
try:
_thread_local.enabled = not _SERIALIZATION_ENABLED_DEFAULT
yield
finally:
_thread_local.enabled = prev


_types_: Dict[str, Type[Serializable]] = {}
_typenames_: Dict[Type[Serializable], str] = {}
_typename_aliases_: Dict[str, str] = {}
Expand Down Expand Up @@ -108,8 +147,10 @@ def to_dict_wrapper(obj: Serializable) -> Mapping:
# Wrap object with serialization metadata.
if TYPE_FIELD_NAME in data:
raise KeyError(f"Key '{TYPE_FIELD_NAME}' already exist in the serialized data mapping!")
typename = _typenames_[type(obj)]
return {TYPE_FIELD_NAME: typename, **data}
if using_serialization():
typename = _typenames_[type(obj)]
return {TYPE_FIELD_NAME: typename, **data}
return copy(data)

return to_dict_wrapper

Expand All @@ -123,7 +164,7 @@ def from_dict_wrapper(*args) -> Serializable:
src = args[1] if len(args) == 2 else args[0]
# Remove deserialization metadata.
src = dict(src)
del src[TYPE_FIELD_NAME]
src.pop(TYPE_FIELD_NAME, None)
# Deserialize as-is.
return method(src)

Expand Down
72 changes: 70 additions & 2 deletions tests/ezserialization_tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import threading
import time
from typing import Mapping, cast

from ezserialization import Serializable, deserialize, serializable
from ezserialization import (
Serializable,
deserialize,
no_serialization,
serializable,
use_serialization,
using_serialization,
)


@serializable # <- valid for serialization
Expand All @@ -10,9 +19,12 @@ class _CaseAUsingAutoName(Serializable):
def __init__(self, value: str):
self.value = value

def to_dict(self) -> Mapping:
def to_raw_dict(self) -> dict:
return {"value": self.value}

def to_dict(self) -> Mapping:
return self.to_raw_dict()

@classmethod
def from_dict(cls, src: Mapping):
return cls(value=src["value"])
Expand Down Expand Up @@ -53,3 +65,59 @@ def test_serialization_typenames_order():
data = b.to_dict()
assert data["_type_"] == "B"
assert b.value == cast(_CaseBUsingNameAlias, deserialize(data)).value


def test_threadsafe_serialization_enabling_and_disabling():
a = _CaseAUsingAutoName("foo")

assert using_serialization(), "By default, serialization must be enabled!"

a_dict = a.to_dict()
raw_a_dict = a.to_raw_dict()
assert a_dict != raw_a_dict, "Bad test setup."

thread = _TestThread()
with no_serialization():
with use_serialization():
assert a.to_dict() == a_dict
assert a.to_dict() == raw_a_dict

thread.start()
while not thread.serialization_explicitly_enabled:
time.sleep(0.1)

assert not using_serialization()

assert using_serialization()
thread.should_stop = True
while not thread.finished:
time.sleep(0.1)

if thread.exception is not None:
raise thread.exception

assert using_serialization()


class _TestThread(threading.Thread):
def __init__(self):
self.exception = None
self.finished = False
self.should_stop = False
self.serialization_explicitly_enabled = False
super().__init__(target=self._fun, daemon=True)

def _fun(self):
try:
assert using_serialization()
with use_serialization():
assert using_serialization()
self.serialization_explicitly_enabled = True
while not self.should_stop:
time.sleep(0.1)
except Exception as e:
self.exception = e
finally:
self.finished = True
self.should_stop = True
self.serialization_explicitly_enabled = True

0 comments on commit 5700458

Please sign in to comment.