From ffec3a46937bafa3253f387d37d7bbb503c53da0 Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Fri, 15 Mar 2024 15:24:54 +0100 Subject: [PATCH] Fixed bug of empty intersections. --- src/random_events/__init__.py | 2 +- src/random_events/events.py | 3 --- test/test_events.py | 7 +++++++ test/test_variables.py | 1 + 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/random_events/__init__.py b/src/random_events/__init__.py index 3f39079..668c344 100644 --- a/src/random_events/__init__.py +++ b/src/random_events/__init__.py @@ -1 +1 @@ -__version__ = '2.0.1' +__version__ = '2.0.2' diff --git a/src/random_events/events.py b/src/random_events/events.py index be34780..3914eba 100644 --- a/src/random_events/events.py +++ b/src/random_events/events.py @@ -135,13 +135,10 @@ def intersection(self, other: EventType) -> EventType: result = self.__class__() variables = set(self.keys()) | set(other.keys()) - for variable in variables: assignment1 = self.get(variable, variable.domain) assignment2 = other.get(variable, variable.domain) intersection = variable.intersection_of_assignments(assignment1, assignment2) - if len(intersection) == 0: - return self.__class__() result[variable] = intersection return result diff --git a/test/test_events.py b/test/test_events.py index 011ea78..848d6c6 100644 --- a/test/test_events.py +++ b/test/test_events.py @@ -237,6 +237,13 @@ def test_set_operations_return_type(self): self.assertEqual(type(event | event), ComplexEvent) self.assertEqual(type(event - event), ComplexEvent) + def test_intersection_with_empty(self): + event = Event({self.integer: ()}) + complete_event = Event({self.integer: self.integer.domain}) + intersection = event.intersection(complete_event) + self.assertIn(self.integer, intersection.keys()) + self.assertTrue(intersection.is_empty()) + class ComplexEventTestCase(unittest.TestCase): diff --git a/test/test_variables.py b/test/test_variables.py index e06a50f..869d638 100644 --- a/test/test_variables.py +++ b/test/test_variables.py @@ -101,5 +101,6 @@ def test_complement_of_assignment(self): self.assertEqual(self.symbol.complement_of_assignment(("a",)), ("b", "c", )) + if __name__ == '__main__': unittest.main()