Skip to content

Commit

Permalink
Fixed bug in simplify method
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Mar 28, 2024
1 parent fe32ed1 commit 747f3c4
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 245 deletions.
53 changes: 23 additions & 30 deletions examples/door.ipynb

Large diffs are not rendered by default.

96 changes: 50 additions & 46 deletions examples/example.ipynb

Large diffs are not rendered by default.

106 changes: 51 additions & 55 deletions examples/logo_generation.ipynb

Large diffs are not rendered by default.

174 changes: 87 additions & 87 deletions examples/product_spaces.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/random_events/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.0.9'
__version__ = '2.0.10'
64 changes: 41 additions & 23 deletions src/random_events/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,13 @@ def _from_json(cls, data: Dict[str, Any]) -> Self:
result[variable] = assignment
return result

def get_variables_where_assignment_is_different(self, other: Self) -> List[Variable]:
"""
Get all variables where the assignment is different from the other event's assignment
"""
return [variable for variable in self.keys() if self[variable] != other[variable]]



class EncodedEvent(Event):
"""
Expand Down Expand Up @@ -605,33 +612,44 @@ def simplify(self) -> Self:
"""
Simplify the complex event such that sub-unions of events that can be expressed as a single events
are merged.
This is done by seeking events that are equal in all but one dimensions to another event and merging them.
"""

if len(self.variables) == 1:
return self.merge_if_one_dimensional()

# for every pair of events
# for every unique pair of events
for index, event in enumerate(self.events):
for other_event in self.events[index + 1:]:

# for every variable in the event
for variable, value in event.items():

# if the events match in this dimension
if other_event[variable] == value:

# form the simpler union of the two events
unified_event = Event({variable: value})
for variable_ in event.keys():
if variable_ != variable:
unified_event[variable_] = variable.union_of_assignments(event[variable_],
other_event[variable_])
# recurse into the simpler complex event
result = ComplexEvent([])
result.events.append(unified_event)
result.events.extend([event__ for event__ in self.events if event__ != event
and event__ != other_event])
return result.simplify()
for other_index, other_event in enumerate(self.events[index + 1:]):
other_index += index + 1

# get the different variables
different_variables = event.get_variables_where_assignment_is_different(other_event)

# if they are the same event
if len(different_variables) == 0:
# recurse into the simpler complex event
result = ComplexEvent([])
result.events.extend([event_ for index_, event_ in enumerate(self.events)
if index_ != other_index])
return result.simplify()

# if they differ in only one dimension
if len(different_variables) == 1:
mismatching_variable = different_variables[0]

unified_event = event.__copy__()
unified_event[mismatching_variable] = (mismatching_variable.
union_of_assignments(event[mismatching_variable],
other_event[mismatching_variable]))

# recurse into the simpler complex event
result = ComplexEvent([])
result.events.append(unified_event)
result.events.extend([event_ for index_, event_ in enumerate(self.events)
if index_ not in [index, other_index]])
return result.simplify()

# if no simplification is possible, return the current complex event
return self.__copy__()
Expand All @@ -640,7 +658,7 @@ def intersection(self, other: EventType) -> Self:
if isinstance(other, Event):
return self.intersection(ComplexEvent([other]))
intersections = [event.intersection(other_event) for other_event in other.events for event in self.events]
return ComplexEvent(intersections)
return ComplexEvent(intersections).simplify()

def difference(self, other: EventType) -> Self:
if isinstance(other, Event):
Expand All @@ -652,7 +670,7 @@ def complement(self) -> Self:
for event in self.events[1:]:
current_complement = event.complement()
result = result.intersection(current_complement)
return result.make_events_disjoint() # .simplify()
return result.make_events_disjoint().simplify()

def are_events_disjoint(self) -> bool:
"""
Expand Down
14 changes: 12 additions & 2 deletions test/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,7 @@ def test_union_of_complex_events(self):
complex_event_1 = event_1.union(event_2)
event_3 = Event({self.x: portion.closed(0.5, 2), self.y: portion.closed(-0.5, 2)})
complex_event_2 = event_1.union(event_3)

result = complex_event_1.union(complex_event_2)

self.assertIsInstance(result, ComplexEvent)
self.assertTrue(result.are_events_disjoint())

Expand Down Expand Up @@ -431,6 +429,18 @@ def test_intersection_symbol_and_real(self):
self.assertEqual(event_[self.x], portion.closed(0, 1))
self.assertEqual(event_[self.a], (0, ))

def test_no_simplify_high_dimensions(self):
x = Continuous("x")
y = Continuous("y")
z = Continuous("z")

event_1 = Event({x: portion.closed(0, 1), y: portion.closed(0, 1), z: portion.closed(0, 1)})
event_2 = Event({x: portion.closed(0, 1), y: portion.closed(0.5, 1.5), z: portion.closed(0.5, 1.5)})
complex_event = ComplexEvent([event_1, event_2])
simplified = complex_event.simplify()
self.assertEqual(len(simplified.events), 2)



class PlottingTestCase(unittest.TestCase):
x: Continuous = Continuous("x")
Expand Down
1 change: 0 additions & 1 deletion test/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,5 @@ def test_complement_of_assignment(self):
self.assertEqual(self.symbol.complement_of_assignment(("a",)), ("b", "c", ))



if __name__ == '__main__':
unittest.main()

0 comments on commit 747f3c4

Please sign in to comment.