Skip to content

Commit

Permalink
Decreased serialization size of events.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Jun 27, 2024
1 parent 47ac61f commit 3c0023d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
36 changes: 32 additions & 4 deletions src/random_events/product_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,31 @@ def marginal(self, variables: VariableSet) -> SimpleEvent:
def non_empty_to_string(self) -> str:
return "{" + ", ".join(f"{variable.name} = {assignment}" for variable, assignment in self.items()) + "}"

def variables_to_json(self) -> List:
return [variable.to_json() for variable in self.keys()]

def assignments_to_json(self) -> List:
return [assignment.to_json() for assignment in self.values()]

def to_json(self) -> Dict[str, Any]:
return {**super().to_json(),
"assignments": [(variable.to_json(), assignment.to_json()) for variable, assignment in self.items()]}
"variables": self.variables_to_json(),
"assignments": self.assignments_to_json()}

def to_json_assignments_only(self) -> Dict[str, Any]:
return {**super().to_json(),
"assignments": self.assignments_to_json()}

@classmethod
def _from_json(cls, data: Dict[str, Any]) -> Self:
return cls(
{Variable.from_json(variable): AbstractCompositeSet.from_json(assignment) for variable, assignment in
data["assignments"]})
variables = [Variable.from_json(variable) for variable in data["variables"]]
assignments = [AbstractCompositeSet.from_json(assignment) for assignment in data["assignments"]]
return cls({variable: assignment for variable, assignment in zip(variables, assignments)})

@classmethod
def from_json_given_variables(cls, data: Dict[str, Any], variables: List[Variable]) -> Self:
assignments = [AbstractCompositeSet.from_json(assignment) for assignment in data["assignments"]]
return cls({variable: assignment for variable, assignment in zip(variables, assignments)})

def plot(self) -> Union[List[go.Scatter], List[go.Mesh3d]]:
"""
Expand Down Expand Up @@ -446,6 +462,18 @@ def add_simple_set(self, simple_set: AbstractSimpleSet):
super().add_simple_set(simple_set)
self.fill_missing_variables()

def to_json(self) -> Dict[str, Any]:
variables = [variable.to_json() for variable in self.all_variables]
simple_sets = [simple_set.to_json_assignments_only() for simple_set in self.simple_sets]
return {**SubclassJSONSerializer.to_json(self),
"variables": variables, "simple_sets": simple_sets}

@classmethod
def _from_json(cls, data: Dict[str, Any]) -> Self:
variables = [Variable.from_json(variable) for variable in data["variables"]]
simple_sets = [SimpleEvent.from_json_given_variables(simple_set, variables) for simple_set in data["simple_sets"]]
return cls(*simple_sets)


# Type definitions
if TYPE_CHECKING:
Expand Down
7 changes: 6 additions & 1 deletion test/test_product_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_union(self):
second_event = SimpleEvent({self.a: Set(TestEnum.A, TestEnum.B), self.x: open(1, 4)}).as_composite_set()
union = event | second_event
result = Event(SimpleEvent({self.a: TestEnum.A, self.x: open(-float("inf"), 4)}),
SimpleEvent({self.a: TestEnum.B, self.x: open(1, 4)}))
SimpleEvent({self.a: TestEnum.B, self.x: open(1, 4)}))
self.assertEqual(union, result)

def test_marginal_event(self):
Expand All @@ -109,6 +109,11 @@ def test_marginal_event(self):
fig = go.Figure(marginal.plot())
# fig.show()

def test_to_json_multiple_events(self):
event = SimpleEvent({self.x: closed(0, 1), self.y: SimpleInterval(3, 5),
self.a: Set(TestEnum.A, TestEnum.B)}).as_composite_set()
event_ = AbstractSimpleSet.from_json(event.to_json())
self.assertEqual(event_, event)

if __name__ == '__main__':
unittest.main()

0 comments on commit 3c0023d

Please sign in to comment.