diff --git a/src/random_events/events.py b/src/random_events/events.py index 76699c8..f83fa14 100644 --- a/src/random_events/events.py +++ b/src/random_events/events.py @@ -301,8 +301,7 @@ def encode(self) -> 'EncodedEvent': Encode the event to an encoded event. :return: The encoded event """ - return EncodedEvent({variable: variable.encode_many(element) if isinstance(variable, Discrete) else element for - variable, element in self.items()}) + return EncodedEvent({variable: variable.encode_many(element) for variable, element in self.items()}) def is_empty(self) -> bool: """ @@ -346,5 +345,5 @@ def decode(self) -> Event: :return: The decoded event """ return Event( - {variable: variable.decode_many(index) if isinstance(variable, Discrete) else index for variable, index in + {variable: variable.decode_many(value) for variable, value in self.items()}) diff --git a/src/random_events/variables.py b/src/random_events/variables.py index 25f0a8e..9cb474c 100644 --- a/src/random_events/variables.py +++ b/src/random_events/variables.py @@ -63,16 +63,16 @@ def encode_many(self, elements: Iterable) -> Iterable[Any]: :param elements: The elements to encode :return: The encoded elements """ - return tuple(map(self.encode, elements)) + return elements - def decode_many(self, indices: Iterable[int]) -> Iterable[Any]: + def decode_many(self, elements: Iterable) -> Iterable[Any]: """ Decode many elements from the representations that are usable for computations to their domains. - :param indices: The encoded elements + :param elements: The encoded elements :return: The decoded elements """ - return tuple(map(self.decode, indices)) + return elements class Continuous(Variable): @@ -133,6 +133,24 @@ def decode(self, index: int) -> Any: """ return self.domain[index] + def encode_many(self, elements: Iterable) -> Iterable[int]: + """ + Encode many elements of the domain to the indices of the elements. + + :param elements: The elements to encode + :return: The encoded elements + """ + return tuple(map(self.encode, elements)) + + def decode_many(self, elements: Iterable[int]) -> Iterable[Any]: + """ + Decode many elements from indices to their domains. + + :param elements: The encoded elements + :return: The decoded elements + """ + return tuple(map(self.decode, elements)) + class Symbolic(Discrete): """ diff --git a/test/test_events.py b/test/test_events.py index c8b7d10..2e194f1 100644 --- a/test/test_events.py +++ b/test/test_events.py @@ -105,9 +105,12 @@ def test_encode(self): """ Test that events are correctly encoded. """ + print(self.event) encoded = self.event.encode() + print(encoded) self.assertIsInstance(encoded, EncodedEvent) decoded = encoded.decode() + print(decoded) self.assertEqual(self.event, decoded) def test_intersection(self):