Skip to content

Commit

Permalink
Added 1d plotting and marginalization of events.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Jun 13, 2024
1 parent a5fc1a6 commit fa21af3
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 2 deletions.
46 changes: 44 additions & 2 deletions src/random_events/product_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def __init__(self, *args, **kwargs):
for key, value in self.items():
self[key] = value


def as_composite_set(self) -> Event:
return Event(self)

Expand Down Expand Up @@ -166,6 +165,18 @@ def __lt__(self, other: Self):
else:
return self[variable] < other[variable]

def marginal(self, variables: VariableSet) -> SimpleEvent:
"""
Create the marginal event, that only contains the variables given..
:param variables: The variables to contain in the marginal event
:return: The marginal event
"""
result = self.__class__()
for variable in variables:
result[variable] = self[variable]
return result

def non_empty_to_string(self) -> str:
return "{" + ", ".join(f"{variable.name} = {assignment}" for variable, assignment in self.items()) + "}"

Expand All @@ -185,13 +196,30 @@ def plot(self) -> Union[List[go.Scatter], List[go.Mesh3d]]:
"""
assert all(isinstance(variable, Continuous) for variable in self.keys()), \
"Plotting is only supported for events that consist of only continuous variables."
if len(self.keys()) == 1:
return self.plot_1d()
if len(self.keys()) == 2:
return self.plot_2d()
elif len(self.keys()) == 3:
return self.plot_3d()
else:
raise NotImplementedError("Plotting is only supported for two and three dimensional events")

def plot_1d(self) -> List[go.Scatter]:
"""
Plot the event in 1D.
"""
xs = []
ys = []

interval: Interval = list(self.values())[0]
for simple_interval in interval.simple_sets:
simple_interval: SimpleInterval
xs.extend([simple_interval.lower, simple_interval.upper, None])
ys.extend([0, 0, None])

return [go.Scatter(x=xs, y=ys, mode="lines", name="Event", fill="toself")]

def plot_2d(self) -> List[go.Scatter]:
"""
Plot the event in 2D.
Expand Down Expand Up @@ -253,7 +281,9 @@ def plotly_layout(self) -> Dict:
"""
Create a layout for the plotly plot.
"""
if len(self.variables) == 2:
if len(self.variables) == 1:
result = {"xaxis_title": self.variables[0].name}
elif len(self.variables) == 2:
result = {"xaxis_title": self.variables[0].name,
"yaxis_title": self.variables[1].name}
elif len(self.variables) == 3:
Expand Down Expand Up @@ -369,6 +399,18 @@ def new_empty_set(self) -> Self:
def complement_if_empty(self) -> Self:
raise NotImplementedError("Complement of an empty Event is not yet supported.")

def marginal(self, variables: VariableSet) -> Event:
"""
Create the marginal event, that only contains the variables given..
:param variables: The variables to contain in the marginal event
:return: The marginal event
"""
result = self.__class__()
for simple_set in self.simple_sets:
result.add_simple_set(simple_set.marginal(variables))
return result.make_disjoint()

def plot(self, color="#636EFA") -> Union[List[go.Scatter], List[go.Mesh3d]]:
"""
Plot the complex event.
Expand Down
3 changes: 3 additions & 0 deletions src/random_events/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def __init__(self, name: str, domain: Union[Type[SetElement], SetElement]):
else:
super().__init__(name, domain)

def domain_type(self) -> Type[SetElement]:
return self.domain.simple_sets[0].all_elements


class Integer(Variable):
"""
Expand Down
10 changes: 10 additions & 0 deletions test/test_product_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,16 @@ def test_union(self):
SimpleEvent({self.a: TestEnum.B, self.x: open(1, 4)}))
self.assertEqual(union, result)

def test_marginal_event(self):
event_1 = SimpleEvent({self.x: closed(0, 1), self.y: SimpleInterval(0, 1)})
event_2 = SimpleEvent({self.x: closed(1, 2), self.y: Interval(SimpleInterval(3, 4))})
event_3 = SimpleEvent({self.x: closed(5, 6), self.y: Interval(SimpleInterval(5, 6))})
event = Event(event_1, event_2, event_3)
marginal = event.marginal(SortedSet([self.x]))
self.assertEqual(marginal, SimpleEvent({self.x: closed(0, 2) | closed(5, 6)}).as_composite_set())
fig = go.Figure(marginal.plot())
# fig.show()


if __name__ == '__main__':
unittest.main()

0 comments on commit fa21af3

Please sign in to comment.