From 8f927e35f501eff80e734fb04a6d552e766e3998 Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Thu, 21 Mar 2024 13:12:06 +0100 Subject: [PATCH] Added serialization of events --- src/random_events/__init__.py | 2 +- src/random_events/events.py | 32 ++++++++++++++++++++++++--- src/random_events/utils.py | 40 ++++++++++++++++++++++++++++++++++ src/random_events/variables.py | 40 +++++++++++++++++++++------------- test/test_events.py | 27 +++++++++++++++++++++++ 5 files changed, 122 insertions(+), 19 deletions(-) diff --git a/src/random_events/__init__.py b/src/random_events/__init__.py index 4c354e0..13ce17d 100644 --- a/src/random_events/__init__.py +++ b/src/random_events/__init__.py @@ -1 +1 @@ -__version__ = '2.0.5' +__version__ = '2.0.6' diff --git a/src/random_events/events.py b/src/random_events/events.py index 63b3349..d3233f2 100644 --- a/src/random_events/events.py +++ b/src/random_events/events.py @@ -10,6 +10,7 @@ from typing_extensions import Set, Union, Any, TYPE_CHECKING, Iterable, List, Self, Dict, Tuple from .variables import Variable, Continuous, Discrete +from .utils import SubclassJSONSerializer # Type hinting for Python 3.7 to 3.9 @@ -109,7 +110,7 @@ def is_empty(self) -> bool: raise NotImplementedError -class Event(SupportsSetOperations, EventMapType): +class Event(SupportsSetOperations, EventMapType, SubclassJSONSerializer): """ A map of variables to values of their respective domains. """ @@ -434,6 +435,21 @@ def marginal_event(self, variables: Iterable[Variable]) -> Self: """ return self.__class__({variable: self[variable] for variable in variables if variable in self}) + def to_json(self) -> Dict[str, Any]: + result = super().to_json() + event = [(variable.to_json(), variable.assignment_to_json(assignment)) for variable, assignment in self.items()] + result["event"] = event + return result + + @classmethod + def _from_json(cls, data: Dict[str, Any]) -> Self: + result = cls() + for variable_json, assignment_json in data["event"]: + variable = Variable.from_json(variable_json) + assignment = variable.assignment_from_json(assignment_json) + result[variable] = assignment + return result + class EncodedEvent(Event): """ @@ -483,8 +499,7 @@ def encode(self) -> Self: return self.__copy__() - -class ComplexEvent(SupportsSetOperations): +class ComplexEvent(SupportsSetOperations, SubclassJSONSerializer): """ A complex event is a set of mutually exclusive events. """ @@ -708,5 +723,16 @@ def merge_if_one_dimensional(self) -> Self: value = variable.union_of_assignments(value, event[variable]) return ComplexEvent([Event({variable: value})]) + def to_json(self) -> Dict[str, Any]: + result = super().to_json() + events = [event.to_json() for event in self.events] + result["events"] = events + return result + + @classmethod + def _from_json(cls, data: Dict[str, Any]) -> Self: + events = [Event.from_json(event) for event in data["events"]] + return cls(events) + EventType = Union[Event, EncodedEvent, ComplexEvent] diff --git a/src/random_events/utils.py b/src/random_events/utils.py index 0c08258..ea87afe 100644 --- a/src/random_events/utils.py +++ b/src/random_events/utils.py @@ -1,3 +1,6 @@ +from typing_extensions import Dict, Any, Self + + def get_full_class_name(cls): """ Returns the full name of a class, including the module name. @@ -14,3 +17,40 @@ def recursive_subclasses(cls): :return: A list of the classes subclasses. """ return cls.__subclasses__() + [g for s in cls.__subclasses__() for g in recursive_subclasses(s)] + + +class SubclassJSONSerializer: + """ + Class for automatic (de)serialization of subclasses. + Classes that inherit from this class can be serialized and deserialized automatically by calling this classes + 'from_json' method. + """ + + def to_json(self) -> Dict[str, Any]: + return {"type": get_full_class_name(self.__class__)} + + @classmethod + def _from_json(cls, data: Dict[str, Any]) -> Self: + """ + Create a variable from a json dict. + This method is called from the from_json method after the correct subclass is determined and should be + overwritten by the respective subclass. + + :param data: The json dict + :return: The deserialized object + """ + raise NotImplementedError() + + @classmethod + def from_json(cls, data: Dict[str, Any]) -> Self: + """ + Create the correct instanceof the subclass from a json dict. + + :param data: The json dict + :return: The correct instance of the subclass + """ + for subclass in recursive_subclasses(SubclassJSONSerializer): + if get_full_class_name(subclass) == data["type"]: + return subclass._from_json(data) + + raise ValueError("Unknown type {}".format(data["type"])) \ No newline at end of file diff --git a/src/random_events/variables.py b/src/random_events/variables.py index 78e0230..43741a9 100644 --- a/src/random_events/variables.py +++ b/src/random_events/variables.py @@ -8,7 +8,7 @@ AssignmentType = Union[portion.Interval, Tuple] -class Variable: +class Variable(utils.SubclassJSONSerializer): """ Abstract base class for all variables. """ @@ -101,20 +101,6 @@ def _from_json(cls, data: Dict[str, Any]) -> 'Variable': """ return cls(name=data["name"], domain=data["domain"]) - @classmethod - def from_json(cls, data: Dict[str, Any]) -> 'Variable': - """ - Create the correct instanceof the subclass from a json dict. - - :param data: The json dict - :return: The correct instance of the subclass - """ - for subclass in utils.recursive_subclasses(Variable): - if utils.get_full_class_name(subclass) == data["type"]: - return subclass._from_json(data) - - raise ValueError("Unknown type for variable. Type is {}".format(data["type"])) - def complement_of_assignment(self, assignment: AssignmentType, encoded: bool = False) -> AssignmentType: """ Returns the complement of the assignment for the variable. @@ -153,6 +139,18 @@ def union_of_assignments(assignment1: AssignmentType, """ raise NotImplementedError + def assignment_to_json(self, assignment: AssignmentType) -> Any: + """ + Convert an assignment to a json serializable object. + """ + raise NotImplementedError + + def assignment_from_json(self, data: Any) -> AssignmentType: + """ + Convert an assignment from a json serializable object. + """ + raise NotImplementedError + class Continuous(Variable): """ @@ -187,6 +185,12 @@ def union_of_assignments(assignment1: portion.Interval, encoded: bool = False) -> portion.Interval: return assignment1 | assignment2 + def assignment_to_json(self, assignment: portion.Interval) -> Any: + return portion.to_data(assignment) + + def assignment_from_json(self, data: Any) -> portion.Interval: + return portion.from_data(data) + class Discrete(Variable): """ @@ -252,6 +256,12 @@ def union_of_assignments(assignment1: Tuple, encoded: bool = False) -> Tuple: return tuple(sorted(set(assignment1) | set(assignment2))) + def assignment_to_json(self, assignment: Tuple) -> Tuple: + return assignment + + def assignment_from_json(self, data: Any) -> AssignmentType: + return tuple(data) + class Symbolic(Discrete): """ diff --git a/test/test_events.py b/test/test_events.py index ed25c67..820cc2e 100644 --- a/test/test_events.py +++ b/test/test_events.py @@ -180,6 +180,17 @@ def test_raises_on_operation_with_different_types(self): with self.assertRaises(TypeError): self.event - self.event.encode() + def test_serialization(self): + json = self.event.to_json() + event = Event.from_json(json) + self.assertEqual(event, self.event) + + def test_serialization_with_complex_interval(self): + event = Event({self.real: portion.closed(0, 1) | portion.closed(2, 3)}) + json = event.to_json() + event_ = Event.from_json(json) + self.assertEqual(event_, event) + class EncodedEventTestCase(unittest.TestCase): @@ -244,6 +255,16 @@ def test_intersection_with_empty(self): self.assertIn(self.integer, intersection.keys()) self.assertTrue(intersection.is_empty()) + def test_serialization(self): + event = EncodedEvent() + event[self.integer] = (1, 2) + event[self.symbol] = {1, 0} + event[self.real] = portion.open(0, 1) + + json = event.to_json() + event_ = EncodedEvent.from_json(json) + self.assertEqual(event, event_) + class ComplexEventTestCase(unittest.TestCase): @@ -393,6 +414,12 @@ def test_merge_if_1d(self): self.assertEqual(len(merged.events), 1) self.assertEqual(merged.events[0][self.x], portion.closed(0, 1) | portion.closed(3, 4)) + def test_serialization(self): + event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)}) + complement = event.complement() + json = complement.to_json() + complement_ = ComplexEvent.from_json(json) + self.assertEqual(complement, complement_) class PlottingTestCase(unittest.TestCase):