From 33ed6bd80e641b200d9e094665c2bb280ae37033 Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Thu, 26 Oct 2023 15:00:37 +0200 Subject: [PATCH] EncodedEvents are now more forgiving and tested. --- src/random_events/__init__.py | 2 +- src/random_events/events.py | 46 ++++++++++--------------------- test/test_events.py | 51 +++++++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 33 deletions(-) diff --git a/src/random_events/__init__.py b/src/random_events/__init__.py index 1a72d32..b3ddbc4 100644 --- a/src/random_events/__init__.py +++ b/src/random_events/__init__.py @@ -1 +1 @@ -__version__ = '1.1.0' +__version__ = '1.1.1' diff --git a/src/random_events/events.py b/src/random_events/events.py index 2fde256..76699c8 100644 --- a/src/random_events/events.py +++ b/src/random_events/events.py @@ -319,44 +319,26 @@ class EncodedEvent(Event): @staticmethod def check_element(variable: Variable, element: Any) -> Union[tuple, portion.Interval]: - if isinstance(element, Iterable) and not isinstance(element, (str, portion.Interval)): - element = tuple(element) - - # if the element is already wrapped - if isinstance(element, (tuple, portion.Interval)): - - # check that the element is in the variable's domain - if isinstance(variable, Discrete): - if not all(0 <= elem < len(variable.domain) for elem in element): - # raise an error - raise ValueError(f"Element {element} not in domain {variable.domain}") - - element = tuple(sorted(element)) - # return the element directly - return element - - # if the variable is continuous + # if the variable is continuous, don't process the element if isinstance(variable, Continuous): + return element - if element not in variable.domain: - # raise an error - raise ValueError(f"Element {element} not in domain {variable.domain}") + # if its any kind of iterable that's not an interval convert it to a tuple + if isinstance(element, Iterable) and not isinstance(element, portion.Interval): + element = tuple(sorted(element)) - # return the element as a singleton interval - return portion.singleton(element) + # if it is just an int, convert it to a tuple containing the int + elif isinstance(element, int): + element = (element, ) - # if the variable is discrete - elif isinstance(variable, Discrete): + if not isinstance(element, tuple): + raise ValueError("Element for a discrete domain must be a tuple, not {}".format(type(element))) - # if the element is not in the variables' domain - if 0 <= element < len(variable.domain): - # raise an error - raise ValueError(f"Element {element} not in domain {variable.domain}") + # if any element is not in the index set of the domain, raise an error + if not all(0 <= elem < len(variable.domain) for elem in element): + raise ValueError(f"Element {element} not in the index set of the domain {variable.domain}") - # return the element as a set - return (element,) - else: - raise TypeError(f"Unknown variable type {type(variable)}") + return element def decode(self) -> Event: """ diff --git a/test/test_events.py b/test/test_events.py index ee33e4c..c8b7d10 100644 --- a/test/test_events.py +++ b/test/test_events.py @@ -185,5 +185,56 @@ def test_equality(self): self.assertNotEqual(self.event, Event()) +class EncodedEventTestCase(unittest.TestCase): + + integer: Integer + symbol: Symbolic + real: Continuous + + @classmethod + def setUpClass(cls): + """ + Create some event for testing. + """ + cls.integer = Integer("integer", set(range(10))) + cls.symbol = Symbolic("symbol", {"a", "b", "c"}) + cls.real = Continuous("real") + + def test_creation(self): + event = EncodedEvent() + event[self.integer] = 1 + self.assertEqual(event[self.integer], (1,)) + event[self.integer] = (1, 2) + self.assertEqual(event[self.integer], (1, 2)) + event[self.symbol] = 0 + self.assertEqual(event[self.symbol], (0, )) + event[self.symbol] = {1, 0} + self.assertEqual(event[self.symbol], (0, 1)) + + interval = portion.open(0, 1) + event[self.real] = interval + self.assertEqual(interval, event[self.real]) + + def test_raises(self): + event = EncodedEvent() + with self.assertRaises(ValueError): + event[self.symbol] = 3 + + with self.assertRaises(ValueError): + event[self.symbol] = portion.open(0, 1) + + with self.assertRaises(ValueError): + event[self.symbol] = (1, 2, 3, 4) + + def test_dict_like_creation(self): + event = EncodedEvent(zip([self.integer, self.symbol], [1, 0])) + self.assertEqual(event[self.integer], (1,)) + self.assertEqual(event[self.symbol], (0,)) + + event = EncodedEvent(zip([self.integer, self.symbol], [[0, 1], 0])) + self.assertEqual(event[self.integer], (0, 1)) + self.assertEqual(event[self.symbol], (0,)) + + if __name__ == '__main__': unittest.main()