Skip to content

Commit

Permalink
fix(serialization): Fix enable/disable defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
gMatas committed Dec 29, 2024
1 parent 6713b28 commit 2b574da
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ezserialization"
version = "0.2.10"
version = "0.2.11"
description = "Simple, easy to use & transparent python objects serialization & deserialization."
authors = ["Matas Gumbinas <[email protected]>"]
repository = "https://github.com/gMatas/ezserialization"
Expand Down
41 changes: 23 additions & 18 deletions src/ezserialization/_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,42 +46,41 @@ def _is_serializable_subclass(cls: Type) -> bool:
return hasattr(cls, "from_dict") and hasattr(cls, "to_dict")


_T = TypeVar("_T", bound=Serializable)
"""
Serializable object type.
"""

_thread_local = threading.local()
_thread_local.enabled = (_SERIALIZATION_ENABLED_DEFAULT := True)
"""
Thread-safe serialization enabling/disabling flag.
"""


def using_serialization() -> bool:
def _get_serialization_enabled() -> bool:
if not hasattr(_thread_local, "enabled"):
_thread_local.enabled = _SERIALIZATION_ENABLED_DEFAULT
_thread_local.enabled = True
return cast(bool, _thread_local.enabled)


def _set_serialization_enabled(enabled: bool) -> None:
_thread_local.enabled = enabled


def using_serialization() -> bool:
return _get_serialization_enabled()


@contextlib.contextmanager
def use_serialization() -> Iterator[None]:
prev = using_serialization()
prev = _get_serialization_enabled()
try:
_thread_local.enabled = _SERIALIZATION_ENABLED_DEFAULT
_set_serialization_enabled(True)
yield
finally:
_thread_local.enabled = prev
_set_serialization_enabled(prev)


@contextlib.contextmanager
def no_serialization() -> Iterator[None]:
prev = using_serialization()
prev = _get_serialization_enabled()
try:
_thread_local.enabled = not _SERIALIZATION_ENABLED_DEFAULT
_set_serialization_enabled(False)
yield
finally:
_thread_local.enabled = prev
_set_serialization_enabled(prev)


_types_: Dict[str, Type[Serializable]] = {}
Expand Down Expand Up @@ -125,6 +124,12 @@ def _is_same_type_by_qualname(a: Type, b: Type) -> bool:
return _abs_qualname(a) == _abs_qualname(b)


_T = TypeVar("_T", bound=Serializable)
"""
Serializable object type.
"""


def serializable(cls: Optional[Type[_T]] = None, *, name: Optional[str] = None):
def wrapper(cls_: Type[_T]) -> Type[_T]:
nonlocal name
Expand All @@ -147,7 +152,7 @@ def to_dict_wrapper(obj: Serializable) -> Mapping:
# Wrap object with serialization metadata.
if TYPE_FIELD_NAME in data:
raise KeyError(f"Key '{TYPE_FIELD_NAME}' already exist in the serialized data mapping!")
if using_serialization():
if _get_serialization_enabled():
typename = _typenames_[type(obj)]
return {TYPE_FIELD_NAME: typename, **data}
return copy(data)
Expand Down

0 comments on commit 2b574da

Please sign in to comment.