diff --git a/src/random_events/__init__.py b/src/random_events/__init__.py index e7c12d2..f593cd5 100644 --- a/src/random_events/__init__.py +++ b/src/random_events/__init__.py @@ -1 +1 @@ -__version__ = '2.0.3' +__version__ = '2.0.4' diff --git a/src/random_events/events.py b/src/random_events/events.py index e0e7bb3..ad62c94 100644 --- a/src/random_events/events.py +++ b/src/random_events/events.py @@ -421,6 +421,12 @@ def fill_missing_variables(self, variables: Iterable[Variable]): if variable not in self: self[variable] = variable.domain + def decode(self): + """ + Decode the event to a normal event. + :return: The decoded event + """ + return self.__copy__() class EncodedEvent(Event): """ @@ -458,18 +464,17 @@ def check_element(variable: Variable, element: Any) -> Union[tuple, portion.Inte return element - def decode(self) -> Event: - """ - Decode the event to a normal event. - :return: The decoded event - """ - return Event({variable: variable.decode_many(value) for variable, value in self.items()}) - def fill_missing_variables(self, variables: Iterable[Variable]): for variable in variables: if variable not in self: self[variable] = variable.encode_many(variable.domain) + def decode(self) -> Event: + return Event({variable: variable.decode_many(value) for variable, value in self.items()}) + + def encode(self) -> Self: + return self.__copy__() + class ComplexEvent(SupportsSetOperations): """ @@ -660,5 +665,18 @@ def plotly_layout(self) -> Dict: def is_empty(self) -> bool: return len(self.events) == 0 + def encode(self) -> 'ComplexEvent': + """ + Encode the event to an encoded event. + :return: The encoded event + """ + return ComplexEvent([event.encode() for event in self.events]) + + def decode(self) -> ComplexEvent: + """ + Decode the event to a normal event. + """ + return ComplexEvent([event.decode() for event in self.events]) + EventType = Union[Event, EncodedEvent, ComplexEvent] diff --git a/test/test_events.py b/test/test_events.py index 21adfce..bb2de7b 100644 --- a/test/test_events.py +++ b/test/test_events.py @@ -366,6 +366,18 @@ def test_union_with_different_variables(self): for event in union.events: self.assertEqual(len(event), 2) + def test_copy(self): + event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)}) + copied = event.copy() + self.assertEqual(event, copied) + self.assertIsNot(event, copied) + + def test_decode_encode(self): + event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)}) + encoded = event.encode() + decoded = encoded.decode() + self.assertEqual(event, decoded) + class PlottingTestCase(unittest.TestCase): x: Continuous = Continuous("x")