Skip to content

Commit

Permalink
Completely working, just have to update the documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed May 31, 2024
1 parent 0ecce90 commit f65ea4a
Show file tree
Hide file tree
Showing 10 changed files with 2,805 additions and 1,833 deletions.
82 changes: 63 additions & 19 deletions examples/door.ipynb

Large diffs are not rendered by default.

430 changes: 239 additions & 191 deletions examples/example.ipynb

Large diffs are not rendered by default.

2,575 changes: 1,945 additions & 630 deletions examples/logo_generation.ipynb

Large diffs are not rendered by default.

333 changes: 180 additions & 153 deletions examples/product_spaces.ipynb

Large diffs are not rendered by default.

877 changes: 65 additions & 812 deletions examples/self_assessment.ipynb

Large diffs are not rendered by default.

62 changes: 62 additions & 0 deletions src/random_events/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ def to_json(self) -> Dict[str, Any]:
def _from_json(cls, data: Dict[str, Any]) -> Self:
return cls(data['lower'], data['upper'], Bound[data['left']], Bound[data['right']])

def center(self) -> float:
"""
:return: The center point of the interval
"""
return ((self.lower + self.upper) / 2) + self.lower


class Interval(sigma_algebra.AbstractCompositeSet):
Expand Down Expand Up @@ -188,3 +193,60 @@ def new_empty_set(self) -> Self:

def complement_if_empty(self) -> Self:
return Interval([SimpleInterval(float('-inf'), float('inf'), Bound.OPEN, Bound.OPEN)])


def open(left: float, right: float) -> Interval:
"""
Creates an open interval.
:param left: The left bound of the interval.
:param right: The right bound of the interval.
:return: The open interval.
"""
return Interval([SimpleInterval(left, right, Bound.OPEN, Bound.OPEN)])


def closed(left: float, right: float) -> Interval:
"""
Creates a closed interval.
:param left: The left bound of the interval.
:param right: The right bound of the interval.
:return: The closed interval.
"""
return Interval([SimpleInterval(left, right, Bound.CLOSED, Bound.CLOSED)])


def open_closed(left: float, right: float) -> Interval:
"""
Creates an open-closed interval.
:param left: The left bound of the interval.
:param right: The right bound of the interval.
:return: The open-closed interval.
"""
return Interval([SimpleInterval(left, right, Bound.OPEN, Bound.CLOSED)])


def closed_open(left: float, right: float) -> Interval:
"""
Creates a closed-open interval.
:param left: The left bound of the interval.
:param right: The right bound of the interval.
:return: The closed-open interval.
"""
return Interval([SimpleInterval(left, right, Bound.CLOSED, Bound.OPEN)])


def singleton(value: float) -> Interval:
"""
Creates a singleton interval.
:param value: The value of the interval.
:return: The singleton interval.
"""
return Interval([SimpleInterval(value, value, Bound.CLOSED, Bound.CLOSED)])


def reals() -> Interval:
"""
Creates the set of real numbers.
:return: The set of real numbers.
"""
return Interval([SimpleInterval(float('-inf'), float('inf'), Bound.OPEN, Bound.OPEN)])
169 changes: 166 additions & 3 deletions src/random_events/product_algebra.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
from sortedcontainers import SortedDict, SortedKeysView, SortedValuesView
from typing_extensions import Union, Any
from typing_extensions import List
import plotly.graph_objects as go

from .variable import *
from .sigma_algebra import *
from .variable import *
from .variable import Variable


Expand Down Expand Up @@ -139,6 +141,115 @@ def __lt__(self, other: Self):
def non_empty_to_string(self) -> str:
return "{" + ", ".join(f"{variable.name} = {assignment}" for variable, assignment in self.items()) + "}"

def to_json(self) -> Dict[str, Any]:
return {**super().to_json(),
"assignments": [(variable.to_json(), assignment.to_json()) for variable, assignment in self.items()]}

@classmethod
def _from_json(cls, data: Dict[str, Any]) -> Self:
return cls(
{Variable.from_json(variable): AbstractCompositeSet.from_json(assignment) for variable, assignment in
data["assignments"]})

def plot(self) -> Union[List[go.Scatter], List[go.Mesh3d]]:
"""
Plot the event.
"""
assert all(isinstance(variable, Continuous) for variable in self.keys()), \
"Plotting is only supported for events that consist of only continuous variables."
if len(self.keys()) == 2:
return self.plot_2d()
elif len(self.keys()) == 3:
return self.plot_3d()
else:
raise NotImplementedError("Plotting is only supported for two and three dimensional events")

def plot_2d(self) -> List[go.Scatter]:
"""
Plot the event in 2D.
"""

# form cartesian product of all intervals
intervals = [value.simple_sets for value in self.values()]
interval_combinations = list(itertools.product(*intervals))

xs = []
ys = []

# for every atomic interval
for interval_combination in interval_combinations:

# plot a rectangle
points = np.asarray(list(itertools.product(*[[axis.lower, axis.upper] for axis in interval_combination])))
y_points = points[:, 1]
y_points[len(y_points) // 2:] = y_points[len(y_points) // 2:][::-1]
xs.extend(points[:, 0].tolist() + [points[0, 0], None])
ys.extend(y_points.tolist()+ [y_points[0], None])

return [go.Scatter(x=xs, y=ys, mode="lines", name="Event", fill="toself")]

def plot_3d(self) -> List[go.Mesh3d]:
"""
Plot the event in 3D.
"""

# form cartesian product of all intervals
intervals = [value.simple_sets for _, value in sorted(self.items())]
simple_events = list(itertools.product(*intervals))
traces = []

# shortcut for the dimensions
x, y, z = 0, 1, 2

# for every atomic interval
for simple_event in simple_events:

# Create a 3D mesh trace for the rectangle
traces.append(go.Mesh3d(
# 8 vertices of a cube
x=[simple_event[x].lower, simple_event[x].lower, simple_event[x].upper, simple_event[x].upper,
simple_event[x].lower, simple_event[x].lower, simple_event[x].upper, simple_event[x].upper],
y=[simple_event[y].lower, simple_event[y].upper, simple_event[y].upper, simple_event[y].lower,
simple_event[y].lower, simple_event[y].upper, simple_event[y].upper, simple_event[y].lower],
z=[simple_event[z].lower, simple_event[z].lower, simple_event[z].lower, simple_event[z].lower,
simple_event[z].upper, simple_event[z].upper, simple_event[z].upper, simple_event[z].upper],
# i, j and k give the vertices of triangles
i=[7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2],
j=[3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3],
k=[0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6],
flatshading=True
))
return traces

def plotly_layout(self) -> Dict:
"""
Create a layout for the plotly plot.
"""
if len(self.variables) == 2:
result = {"xaxis_title": self.variables[0].name,
"yaxis_title": self.variables[1].name}
elif len(self.variables) == 3:
result = dict(scene=dict(
xaxis_title=self.variables[0].name,
yaxis_title=self.variables[1].name,
zaxis_title=self.variables[2].name)
)
else:
raise NotImplementedError("Plotting is only supported for two and three dimensional events")

return result

def fill_missing_variables(self, variables: SortedSet[Variable]):
"""
Fill this with the variables that are not in self but in `variables`.
The variables are mapped to their domain.
:param variables: The variables to fill the event with
"""
for variable in variables:
if variable not in self:
self[variable] = variable.domain


class Event(AbstractCompositeSet):
"""
Expand All @@ -151,6 +262,23 @@ class Event(AbstractCompositeSet):

simple_sets: SortedSet[SimpleEvent]

def __init__(self, simple_sets: Iterable[SimpleEvent]):
super().__init__(simple_sets)
self.fill_missing_variables()

@property
def all_variables(self) -> SortedSet[Variable]:
result = SortedSet()
return result.union(*[SortedSet(simple_set.variables) for simple_set in self.simple_sets])

def fill_missing_variables(self):
"""
Fill all simple sets with the missing variables.
"""
all_variables = self.all_variables
for simple_set in self.simple_sets:
simple_set.fill_missing_variables(all_variables)

def simplify(self) -> Self:
simplified, changed = self.simplify_once()
while changed:
Expand Down Expand Up @@ -208,7 +336,42 @@ def simplify_once(self) -> Tuple[Self, bool]:
return self, False

def new_empty_set(self) -> Self:
return Event()
return Event([])

def complement_if_empty(self) -> Self:
raise NotImplementedError("Complement of an empty Event is not yet supported.")

def plot(self, color="#636EFA") -> Union[List[go.Scatter], List[go.Mesh3d]]:
"""
Plot the complex event.
:param color: The color to use for this event
"""
traces = []
show_legend = True
for index, event in enumerate(self.simple_sets):
event_traces = event.plot()
for event_trace in event_traces:
if len(event.keys()) == 2:
event_trace.update(name="Event", legendgroup=id(self), showlegend=show_legend,
line=dict(color=color))
if len(event.keys()) == 3:
event_trace.update(name="Event", legendgroup=id(self), showlegend=show_legend, color=color)
show_legend = False
traces.append(event_trace)
return traces

def plotly_layout(self) -> Dict:
"""
Create a layout for the plotly plot.
"""
return self.simple_sets[0].plotly_layout()

def add_simple_set(self, simple_set: AbstractSimpleSet):
"""
Add a simple set to this event.
:param simple_set: The simple set to add
"""
super().add_simple_set(simple_set)
self.fill_missing_variables()
59 changes: 37 additions & 22 deletions src/random_events/sigma_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from abc import abstractmethod
from typing import Tuple, Dict, Any

from typing_extensions import Self, Set, Iterable, Optional
from sortedcontainers import SortedSet
from typing_extensions import Self, Iterable, Optional

from .utils import SubclassJSONSerializer

Expand Down Expand Up @@ -84,7 +84,6 @@ def difference_with(self, other: Self) -> SortedSet[Self]:
# if it intersects with this set
intersection = element.intersection_with(self)
if not intersection.is_empty():

# add the intersection to the result
result.add(intersection)

Expand Down Expand Up @@ -173,8 +172,8 @@ def intersection_with_simple_sets(self, other: SortedSet[AbstractSimpleSet]) ->
:return: The intersection of this set with the set of simple sets
"""
result = self.new_empty_set()
[result.simple_sets.update(self.intersection_with_simple_set(other_simple_set).simple_sets)
for other_simple_set in other]
[result.simple_sets.update(self.intersection_with_simple_set(other_simple_set).simple_sets) for other_simple_set
in other]
return result

def intersection_with(self, other: Self) -> Self:
Expand Down Expand Up @@ -322,8 +321,8 @@ def split_into_disjoint_and_non_disjoint(self) -> Tuple[Self, Self]:
This method requires:
- the intersection of two simple sets as a simple set
- the difference of a simple set (A) and another simple set (B) that is completely contained in A (B ⊆ A).
The result of that difference has to be a composite set with only one simple set in it.
- the difference_of_a_with_every_b of a simple set (A) and another simple set (B) that is completely contained in A (B ⊆ A).
The result of that difference_of_a_with_every_b has to be a composite set with only one simple set in it.
:return: A tuple of the disjoint and non-disjoint set.
"""
Expand All @@ -334,43 +333,45 @@ def split_into_disjoint_and_non_disjoint(self) -> Tuple[Self, Self]:

# for every simple set (a)
for simple_set_a in self.simple_sets:
simple_set_a: AbstractSimpleSet

# initialize the difference of a with every b
difference = simple_set_a
difference_of_a_with_every_b: Optional[AbstractSimpleSet] = simple_set_a

# for every other simple set (b)
for simple_set_b in self.simple_sets:
simple_set_b: AbstractSimpleSet

# skip symmetric iterations
if simple_set_a == simple_set_b:
continue

# get the intersection of a and b
intersection = simple_set_a.intersection_with(simple_set_b)
intersection_a_b: AbstractSimpleSet = simple_set_a.intersection_with(simple_set_b)

# if the intersection is not empty add it to the non-disjoint set
non_disjoint.add_simple_set(intersection)
non_disjoint.add_simple_set(intersection_a_b)

# get the difference of the simple set with the intersection.
difference_with_intersection = difference.difference_with(intersection)
difference_with_intersection = difference_of_a_with_every_b.difference_with(intersection_a_b)

# if the difference is empty
# if the difference of a with every b is empty
if len(difference_with_intersection) == 0:
# skip the rest of the loop and mark the set for discarding
difference = None
difference_of_a_with_every_b = None
continue

# the now should contain only 1 element
assert len(difference_with_intersection) == 1
difference = difference_with_intersection[0]
# assert len(difference_with_intersection) == 1
difference_of_a_with_every_b = difference_of_a_with_every_b.difference_with(intersection_a_b)[0]

# if the difference has become None
if difference is None:
# if the difference_of_a_with_every_b has become None
if difference_of_a_with_every_b is None:
# skip the rest of the loop
continue

# append the simple_set_a without every other simple set to the disjoint set
disjoint.simple_sets.add(difference)
disjoint.simple_sets.add(difference_of_a_with_every_b)

return disjoint, non_disjoint

Expand Down Expand Up @@ -410,11 +411,25 @@ def __hash__(self):
return hash(tuple(self.simple_sets))

def __lt__(self, other: Self):
if self.is_empty():
return True
if other.is_empty():
return False
return self.simple_sets[0] < other.simple_sets[0]
"""
Compare this set with another set.
The sets are compared by comparing the simple sets in order.
If the pair of simple sets are equal, the next pair is compared.
If all pairs are equal, the set with the least amount of simple sets is considered smaller.
..note:: This does not define a total order in the mathematical sense. In the mathematical sense, this defines
a partial order.
:param other: The other set
:return: Rather this set is smaller than the other set
"""
for a, b in zip(self.simple_sets, other.simple_sets):
if a == b:
continue
else:
return a < b
return len(self.simple_sets) < len(other.simple_sets)

def to_json(self) -> Dict[str, Any]:
return {**super().to_json(), "simple_sets": [simple_set.to_json() for simple_set in self.simple_sets]}
Expand Down
Loading

0 comments on commit f65ea4a

Please sign in to comment.