Skip to content

Commit

Permalink
Fixed missing variables bug.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Mar 15, 2024
1 parent ffec3a4 commit d00e93d
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/random_events/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.0.2'
__version__ = '2.0.3'
41 changes: 35 additions & 6 deletions src/random_events/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import portion
import plotly.graph_objects as go

from typing_extensions import Set, Union, Any, TYPE_CHECKING, Iterable, List, Self, Dict
from typing_extensions import Set, Union, Any, TYPE_CHECKING, Iterable, List, Self, Dict, Tuple

from .variables import Variable, Continuous, Discrete

Expand Down Expand Up @@ -173,7 +173,7 @@ def union(self, other: EventType) -> ComplexEvent:

# add the fragments of the other event
complex_self.events.extend(fragments_of_other.events)
return complex_self
return ComplexEvent(complex_self.events)

def difference(self, other: EventType) -> ComplexEvent:
# if the other is a complex event
Expand Down Expand Up @@ -413,6 +413,14 @@ def plotly_layout(self) -> Dict:
def __hash__(self):
return hash(tuple(sorted(self.items())))

def fill_missing_variables(self, variables: Iterable[Variable]):
"""
Fill missing variables with their entire domain.
"""
for variable in variables:
if variable not in self:
self[variable] = variable.domain


class EncodedEvent(Event):
"""
Expand Down Expand Up @@ -457,6 +465,11 @@ def decode(self) -> Event:
"""
return Event({variable: variable.decode_many(value) for variable, value in self.items()})

def fill_missing_variables(self, variables: Iterable[Variable]):
for variable in variables:
if variable not in self:
self[variable] = variable.encode_many(variable.domain)


class ComplexEvent(SupportsSetOperations):
"""
Expand All @@ -466,8 +479,20 @@ class ComplexEvent(SupportsSetOperations):

def __init__(self, events: Iterable[Event]):
self.events = list(event for event in events if not event.is_empty())
variables = self.variables
for event in self.events:
event.fill_missing_variables(variables)

def union(self, other: Self) -> Self:
@property
def variables(self) -> Tuple[Variable, ...]:
"""
Get the variables of the complex event.
"""
return tuple(sorted(set(variable for event in self.events for variable in event.keys())))

def union(self, other: EventType) -> Self:
if isinstance(other, Event):
return self.union(ComplexEvent([other]))
result = ComplexEvent(self.events + other.events)
return result.make_events_disjoint().simplify()

Expand Down Expand Up @@ -573,11 +598,15 @@ def simplify(self) -> Self:
# if no simplification is possible, return the current complex event
return self.__copy__()

def intersection(self, other: Self) -> Self:
def intersection(self, other: EventType) -> Self:
if isinstance(other, Event):
return self.intersection(ComplexEvent([other]))
intersections = [event.intersection(other_event) for other_event in other.events for event in self.events]
return ComplexEvent(intersections)

def difference(self, other: Self) -> Self:
def difference(self, other: EventType) -> Self:
if isinstance(other, Event):
return self.difference(ComplexEvent([other]))
return self.intersection(other.complement())

def complement(self) -> Self:
Expand Down Expand Up @@ -632,4 +661,4 @@ def is_empty(self) -> bool:
return len(self.events) == 0


EventType = Union[Event, EncodedEvent, ComplexEvent]
EventType = Union[Event, EncodedEvent, ComplexEvent]
14 changes: 14 additions & 0 deletions test/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,20 @@ def test_chained_complement(self):
self.assertEqual(len(copied_event.events), 1)
self.assertEqual(copied_event.events[0], event)

def test_union_of_simple_with_complex(self):
event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)})
complex_event = ComplexEvent([event])
union1 = event.union(complex_event)
union2 = complex_event.union(event)
self.assertEqual(union1, union2)

def test_union_with_different_variables(self):
event1 = Event({self.x: portion.closed(0, 1)})
event2 = Event({self.y: portion.closed(0, 1)})
union = event1.union(event2)
for event in union.events:
self.assertEqual(len(event), 2)


class PlottingTestCase(unittest.TestCase):
x: Continuous = Continuous("x")
Expand Down

0 comments on commit d00e93d

Please sign in to comment.