From 1403f7409decdbc3b851cf3e04ddc135bcff292e Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Fri, 15 Mar 2024 16:55:55 +0100 Subject: [PATCH] Added marginal views of events. --- src/random_events/__init__.py | 2 +- src/random_events/events.py | 30 ++++++++++++++++++++++++++++++ test/test_events.py | 16 ++++++++++++++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/src/random_events/__init__.py b/src/random_events/__init__.py index f593cd5..4c354e0 100644 --- a/src/random_events/__init__.py +++ b/src/random_events/__init__.py @@ -1 +1 @@ -__version__ = '2.0.4' +__version__ = '2.0.5' diff --git a/src/random_events/events.py b/src/random_events/events.py index ad62c94..63b3349 100644 --- a/src/random_events/events.py +++ b/src/random_events/events.py @@ -428,6 +428,13 @@ def decode(self): """ return self.__copy__() + def marginal_event(self, variables: Iterable[Variable]) -> Self: + """ + Get the marginal event of this event with respect to a variable. + """ + return self.__class__({variable: self[variable] for variable in variables if variable in self}) + + class EncodedEvent(Event): """ A map of variables to indices of their respective domains. @@ -476,6 +483,7 @@ def encode(self) -> Self: return self.__copy__() + class ComplexEvent(SupportsSetOperations): """ A complex event is a set of mutually exclusive events. @@ -577,6 +585,9 @@ def simplify(self) -> Self: are merged. """ + if len(self.variables) == 1: + return self.merge_if_one_dimensional() + # for every pair of events for index, event in enumerate(self.events): for other_event in self.events[index + 1:]: @@ -678,5 +689,24 @@ def decode(self) -> ComplexEvent: """ return ComplexEvent([event.decode() for event in self.events]) + def marginal_event(self, variables: Iterable[Variable]) -> Self: + """ + Get the marginal event of this complex event with respect to a variable. + """ + return ComplexEvent([event.marginal_event(variables) for event in self.events]).simplify() + + def merge_if_one_dimensional(self) -> Self: + """ + Merge all events into a single event if they are all one-dimensional. + """ + if not len(self.variables) == 1: + return self + variable = self.variables[0] + value = self.events[0][variable] + + for event in self.events[1:]: + value = variable.union_of_assignments(value, event[variable]) + return ComplexEvent([Event({variable: value})]) + EventType = Union[Event, EncodedEvent, ComplexEvent] diff --git a/test/test_events.py b/test/test_events.py index bb2de7b..ed25c67 100644 --- a/test/test_events.py +++ b/test/test_events.py @@ -378,6 +378,22 @@ def test_decode_encode(self): decoded = encoded.decode() self.assertEqual(event, decoded) + def test_marginal_event(self): + event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)}) + complement = event.complement() + marginal_event = complement.marginal_event([self.x]) + self.assertEqual(len(marginal_event.events), 1) + self.assertEqual(marginal_event.events[0][self.x], portion.open(-portion.inf, portion.inf)) + + def test_merge_if_1d(self): + event1 = Event({self.x: portion.closed(0, 1)}) + event2 = Event({self.x: portion.closed(3, 4)}) + complex_event = ComplexEvent([event1, event2]) + merged = complex_event.merge_if_one_dimensional() + self.assertEqual(len(merged.events), 1) + self.assertEqual(merged.events[0][self.x], portion.closed(0, 1) | portion.closed(3, 4)) + + class PlottingTestCase(unittest.TestCase): x: Continuous = Continuous("x")