Skip to content

Commit

Permalink
EncodedEvents are now more forgiving and tested.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Oct 26, 2023
1 parent 0bd35d3 commit 33ed6bd
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/random_events/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.1.0'
__version__ = '1.1.1'
46 changes: 14 additions & 32 deletions src/random_events/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,44 +319,26 @@ class EncodedEvent(Event):
@staticmethod
def check_element(variable: Variable, element: Any) -> Union[tuple, portion.Interval]:

if isinstance(element, Iterable) and not isinstance(element, (str, portion.Interval)):
element = tuple(element)

# if the element is already wrapped
if isinstance(element, (tuple, portion.Interval)):

# check that the element is in the variable's domain
if isinstance(variable, Discrete):
if not all(0 <= elem < len(variable.domain) for elem in element):
# raise an error
raise ValueError(f"Element {element} not in domain {variable.domain}")

element = tuple(sorted(element))
# return the element directly
return element

# if the variable is continuous
# if the variable is continuous, don't process the element
if isinstance(variable, Continuous):
return element

if element not in variable.domain:
# raise an error
raise ValueError(f"Element {element} not in domain {variable.domain}")
# if its any kind of iterable that's not an interval convert it to a tuple
if isinstance(element, Iterable) and not isinstance(element, portion.Interval):
element = tuple(sorted(element))

# return the element as a singleton interval
return portion.singleton(element)
# if it is just an int, convert it to a tuple containing the int
elif isinstance(element, int):
element = (element, )

# if the variable is discrete
elif isinstance(variable, Discrete):
if not isinstance(element, tuple):
raise ValueError("Element for a discrete domain must be a tuple, not {}".format(type(element)))

# if the element is not in the variables' domain
if 0 <= element < len(variable.domain):
# raise an error
raise ValueError(f"Element {element} not in domain {variable.domain}")
# if any element is not in the index set of the domain, raise an error
if not all(0 <= elem < len(variable.domain) for elem in element):
raise ValueError(f"Element {element} not in the index set of the domain {variable.domain}")

# return the element as a set
return (element,)
else:
raise TypeError(f"Unknown variable type {type(variable)}")
return element

def decode(self) -> Event:
"""
Expand Down
51 changes: 51 additions & 0 deletions test/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,5 +185,56 @@ def test_equality(self):
self.assertNotEqual(self.event, Event())


class EncodedEventTestCase(unittest.TestCase):

integer: Integer
symbol: Symbolic
real: Continuous

@classmethod
def setUpClass(cls):
"""
Create some event for testing.
"""
cls.integer = Integer("integer", set(range(10)))
cls.symbol = Symbolic("symbol", {"a", "b", "c"})
cls.real = Continuous("real")

def test_creation(self):
event = EncodedEvent()
event[self.integer] = 1
self.assertEqual(event[self.integer], (1,))
event[self.integer] = (1, 2)
self.assertEqual(event[self.integer], (1, 2))
event[self.symbol] = 0
self.assertEqual(event[self.symbol], (0, ))
event[self.symbol] = {1, 0}
self.assertEqual(event[self.symbol], (0, 1))

interval = portion.open(0, 1)
event[self.real] = interval
self.assertEqual(interval, event[self.real])

def test_raises(self):
event = EncodedEvent()
with self.assertRaises(ValueError):
event[self.symbol] = 3

with self.assertRaises(ValueError):
event[self.symbol] = portion.open(0, 1)

with self.assertRaises(ValueError):
event[self.symbol] = (1, 2, 3, 4)

def test_dict_like_creation(self):
event = EncodedEvent(zip([self.integer, self.symbol], [1, 0]))
self.assertEqual(event[self.integer], (1,))
self.assertEqual(event[self.symbol], (0,))

event = EncodedEvent(zip([self.integer, self.symbol], [[0, 1], 0]))
self.assertEqual(event[self.integer], (0, 1))
self.assertEqual(event[self.symbol], (0,))


if __name__ == '__main__':
unittest.main()

0 comments on commit 33ed6bd

Please sign in to comment.