From f1b1fd07c18c1dc4a87ed653c5b82d9fad514880 Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Thu, 21 Mar 2024 15:49:47 +0100 Subject: [PATCH] Fixed bug in intersection of encoded events. --- src/random_events/__init__.py | 2 +- src/random_events/events.py | 4 ++-- src/random_events/variables.py | 4 ++++ test/test_events.py | 9 +++++++++ 4 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/random_events/__init__.py b/src/random_events/__init__.py index 13ce17d..4b259db 100644 --- a/src/random_events/__init__.py +++ b/src/random_events/__init__.py @@ -1 +1 @@ -__version__ = '2.0.6' +__version__ = '2.0.7' diff --git a/src/random_events/events.py b/src/random_events/events.py index d3233f2..ccd9a57 100644 --- a/src/random_events/events.py +++ b/src/random_events/events.py @@ -137,8 +137,8 @@ def intersection(self, other: EventType) -> EventType: variables = set(self.keys()) | set(other.keys()) for variable in variables: - assignment1 = self.get(variable, variable.domain) - assignment2 = other.get(variable, variable.domain) + assignment1 = self.get(variable, variable.encoded_domain if isinstance(self, EncodedEvent) else variable.domain) + assignment2 = other.get(variable, variable.encoded_domain if isinstance(self, EncodedEvent) else variable.domain) intersection = variable.intersection_of_assignments(assignment1, assignment2) result[variable] = intersection diff --git a/src/random_events/variables.py b/src/random_events/variables.py index 43741a9..3991513 100644 --- a/src/random_events/variables.py +++ b/src/random_events/variables.py @@ -151,6 +151,10 @@ def assignment_from_json(self, data: Any) -> AssignmentType: """ raise NotImplementedError + @property + def encoded_domain(self): + return self.encode_many(self.domain) + class Continuous(Variable): """ diff --git a/test/test_events.py b/test/test_events.py index 820cc2e..4b1dcb7 100644 --- a/test/test_events.py +++ b/test/test_events.py @@ -421,6 +421,15 @@ def test_serialization(self): complement_ = ComplexEvent.from_json(json) self.assertEqual(complement, complement_) + def test_intersection_symbol_and_real(self): + event = ComplexEvent([EncodedEvent({self.x: portion.closed(0, 1)})]) + event2 = EncodedEvent({self.a: (0, )}) + result = event & event2 + self.assertEqual(len(result.events), 1) + event_ = result.events[0] + self.assertEqual(event_[self.x], portion.closed(0, 1)) + self.assertEqual(event_[self.a], (0, )) + class PlottingTestCase(unittest.TestCase): x: Continuous = Continuous("x")