diff --git a/src/random_events/__init__.py b/src/random_events/__init__.py index 668c344..e7c12d2 100644 --- a/src/random_events/__init__.py +++ b/src/random_events/__init__.py @@ -1 +1 @@ -__version__ = '2.0.2' +__version__ = '2.0.3' diff --git a/src/random_events/events.py b/src/random_events/events.py index 3914eba..e0e7bb3 100644 --- a/src/random_events/events.py +++ b/src/random_events/events.py @@ -7,7 +7,7 @@ import portion import plotly.graph_objects as go -from typing_extensions import Set, Union, Any, TYPE_CHECKING, Iterable, List, Self, Dict +from typing_extensions import Set, Union, Any, TYPE_CHECKING, Iterable, List, Self, Dict, Tuple from .variables import Variable, Continuous, Discrete @@ -173,7 +173,7 @@ def union(self, other: EventType) -> ComplexEvent: # add the fragments of the other event complex_self.events.extend(fragments_of_other.events) - return complex_self + return ComplexEvent(complex_self.events) def difference(self, other: EventType) -> ComplexEvent: # if the other is a complex event @@ -413,6 +413,14 @@ def plotly_layout(self) -> Dict: def __hash__(self): return hash(tuple(sorted(self.items()))) + def fill_missing_variables(self, variables: Iterable[Variable]): + """ + Fill missing variables with their entire domain. + """ + for variable in variables: + if variable not in self: + self[variable] = variable.domain + class EncodedEvent(Event): """ @@ -457,6 +465,11 @@ def decode(self) -> Event: """ return Event({variable: variable.decode_many(value) for variable, value in self.items()}) + def fill_missing_variables(self, variables: Iterable[Variable]): + for variable in variables: + if variable not in self: + self[variable] = variable.encode_many(variable.domain) + class ComplexEvent(SupportsSetOperations): """ @@ -466,8 +479,20 @@ class ComplexEvent(SupportsSetOperations): def __init__(self, events: Iterable[Event]): self.events = list(event for event in events if not event.is_empty()) + variables = self.variables + for event in self.events: + event.fill_missing_variables(variables) - def union(self, other: Self) -> Self: + @property + def variables(self) -> Tuple[Variable, ...]: + """ + Get the variables of the complex event. + """ + return tuple(sorted(set(variable for event in self.events for variable in event.keys()))) + + def union(self, other: EventType) -> Self: + if isinstance(other, Event): + return self.union(ComplexEvent([other])) result = ComplexEvent(self.events + other.events) return result.make_events_disjoint().simplify() @@ -573,11 +598,15 @@ def simplify(self) -> Self: # if no simplification is possible, return the current complex event return self.__copy__() - def intersection(self, other: Self) -> Self: + def intersection(self, other: EventType) -> Self: + if isinstance(other, Event): + return self.intersection(ComplexEvent([other])) intersections = [event.intersection(other_event) for other_event in other.events for event in self.events] return ComplexEvent(intersections) - def difference(self, other: Self) -> Self: + def difference(self, other: EventType) -> Self: + if isinstance(other, Event): + return self.difference(ComplexEvent([other])) return self.intersection(other.complement()) def complement(self) -> Self: @@ -632,4 +661,4 @@ def is_empty(self) -> bool: return len(self.events) == 0 -EventType = Union[Event, EncodedEvent, ComplexEvent] \ No newline at end of file +EventType = Union[Event, EncodedEvent, ComplexEvent] diff --git a/test/test_events.py b/test/test_events.py index 848d6c6..21adfce 100644 --- a/test/test_events.py +++ b/test/test_events.py @@ -352,6 +352,20 @@ def test_chained_complement(self): self.assertEqual(len(copied_event.events), 1) self.assertEqual(copied_event.events[0], event) + def test_union_of_simple_with_complex(self): + event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)}) + complex_event = ComplexEvent([event]) + union1 = event.union(complex_event) + union2 = complex_event.union(event) + self.assertEqual(union1, union2) + + def test_union_with_different_variables(self): + event1 = Event({self.x: portion.closed(0, 1)}) + event2 = Event({self.y: portion.closed(0, 1)}) + union = event1.union(event2) + for event in union.events: + self.assertEqual(len(event), 2) + class PlottingTestCase(unittest.TestCase): x: Continuous = Continuous("x")