Skip to content

Commit

Permalink
Added marginal views of events.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Mar 15, 2024
1 parent fc43377 commit 1403f74
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/random_events/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.0.4'
__version__ = '2.0.5'
30 changes: 30 additions & 0 deletions src/random_events/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,13 @@ def decode(self):
"""
return self.__copy__()

def marginal_event(self, variables: Iterable[Variable]) -> Self:
"""
Get the marginal event of this event with respect to a variable.
"""
return self.__class__({variable: self[variable] for variable in variables if variable in self})


class EncodedEvent(Event):
"""
A map of variables to indices of their respective domains.
Expand Down Expand Up @@ -476,6 +483,7 @@ def encode(self) -> Self:
return self.__copy__()



class ComplexEvent(SupportsSetOperations):
"""
A complex event is a set of mutually exclusive events.
Expand Down Expand Up @@ -577,6 +585,9 @@ def simplify(self) -> Self:
are merged.
"""

if len(self.variables) == 1:
return self.merge_if_one_dimensional()

# for every pair of events
for index, event in enumerate(self.events):
for other_event in self.events[index + 1:]:
Expand Down Expand Up @@ -678,5 +689,24 @@ def decode(self) -> ComplexEvent:
"""
return ComplexEvent([event.decode() for event in self.events])

def marginal_event(self, variables: Iterable[Variable]) -> Self:
"""
Get the marginal event of this complex event with respect to a variable.
"""
return ComplexEvent([event.marginal_event(variables) for event in self.events]).simplify()

def merge_if_one_dimensional(self) -> Self:
"""
Merge all events into a single event if they are all one-dimensional.
"""
if not len(self.variables) == 1:
return self
variable = self.variables[0]
value = self.events[0][variable]

for event in self.events[1:]:
value = variable.union_of_assignments(value, event[variable])
return ComplexEvent([Event({variable: value})])


EventType = Union[Event, EncodedEvent, ComplexEvent]
16 changes: 16 additions & 0 deletions test/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,22 @@ def test_decode_encode(self):
decoded = encoded.decode()
self.assertEqual(event, decoded)

def test_marginal_event(self):
event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)})
complement = event.complement()
marginal_event = complement.marginal_event([self.x])
self.assertEqual(len(marginal_event.events), 1)
self.assertEqual(marginal_event.events[0][self.x], portion.open(-portion.inf, portion.inf))

def test_merge_if_1d(self):
event1 = Event({self.x: portion.closed(0, 1)})
event2 = Event({self.x: portion.closed(3, 4)})
complex_event = ComplexEvent([event1, event2])
merged = complex_event.merge_if_one_dimensional()
self.assertEqual(len(merged.events), 1)
self.assertEqual(merged.events[0][self.x], portion.closed(0, 1) | portion.closed(3, 4))



class PlottingTestCase(unittest.TestCase):
x: Continuous = Continuous("x")
Expand Down

0 comments on commit 1403f74

Please sign in to comment.