diff --git a/src/random_events/product_algebra.py b/src/random_events/product_algebra.py index a6da2bc..22516c4 100644 --- a/src/random_events/product_algebra.py +++ b/src/random_events/product_algebra.py @@ -68,7 +68,6 @@ def __init__(self, *args, **kwargs): for key, value in self.items(): self[key] = value - def as_composite_set(self) -> Event: return Event(self) @@ -166,6 +165,18 @@ def __lt__(self, other: Self): else: return self[variable] < other[variable] + def marginal(self, variables: VariableSet) -> SimpleEvent: + """ + Create the marginal event, that only contains the variables given.. + + :param variables: The variables to contain in the marginal event + :return: The marginal event + """ + result = self.__class__() + for variable in variables: + result[variable] = self[variable] + return result + def non_empty_to_string(self) -> str: return "{" + ", ".join(f"{variable.name} = {assignment}" for variable, assignment in self.items()) + "}" @@ -185,6 +196,8 @@ def plot(self) -> Union[List[go.Scatter], List[go.Mesh3d]]: """ assert all(isinstance(variable, Continuous) for variable in self.keys()), \ "Plotting is only supported for events that consist of only continuous variables." + if len(self.keys()) == 1: + return self.plot_1d() if len(self.keys()) == 2: return self.plot_2d() elif len(self.keys()) == 3: @@ -192,6 +205,21 @@ def plot(self) -> Union[List[go.Scatter], List[go.Mesh3d]]: else: raise NotImplementedError("Plotting is only supported for two and three dimensional events") + def plot_1d(self) -> List[go.Scatter]: + """ + Plot the event in 1D. + """ + xs = [] + ys = [] + + interval: Interval = list(self.values())[0] + for simple_interval in interval.simple_sets: + simple_interval: SimpleInterval + xs.extend([simple_interval.lower, simple_interval.upper, None]) + ys.extend([0, 0, None]) + + return [go.Scatter(x=xs, y=ys, mode="lines", name="Event", fill="toself")] + def plot_2d(self) -> List[go.Scatter]: """ Plot the event in 2D. @@ -253,7 +281,9 @@ def plotly_layout(self) -> Dict: """ Create a layout for the plotly plot. """ - if len(self.variables) == 2: + if len(self.variables) == 1: + result = {"xaxis_title": self.variables[0].name} + elif len(self.variables) == 2: result = {"xaxis_title": self.variables[0].name, "yaxis_title": self.variables[1].name} elif len(self.variables) == 3: @@ -369,6 +399,18 @@ def new_empty_set(self) -> Self: def complement_if_empty(self) -> Self: raise NotImplementedError("Complement of an empty Event is not yet supported.") + def marginal(self, variables: VariableSet) -> Event: + """ + Create the marginal event, that only contains the variables given.. + + :param variables: The variables to contain in the marginal event + :return: The marginal event + """ + result = self.__class__() + for simple_set in self.simple_sets: + result.add_simple_set(simple_set.marginal(variables)) + return result.make_disjoint() + def plot(self, color="#636EFA") -> Union[List[go.Scatter], List[go.Mesh3d]]: """ Plot the complex event. diff --git a/src/random_events/variable.py b/src/random_events/variable.py index b10da3b..52b421f 100644 --- a/src/random_events/variable.py +++ b/src/random_events/variable.py @@ -81,6 +81,9 @@ def __init__(self, name: str, domain: Union[Type[SetElement], SetElement]): else: super().__init__(name, domain) + def domain_type(self) -> Type[SetElement]: + return self.domain.simple_sets[0].all_elements + class Integer(Variable): """ diff --git a/test/test_product_algebra.py b/test/test_product_algebra.py index d857f0f..4236ef7 100644 --- a/test/test_product_algebra.py +++ b/test/test_product_algebra.py @@ -99,6 +99,16 @@ def test_union(self): SimpleEvent({self.a: TestEnum.B, self.x: open(1, 4)})) self.assertEqual(union, result) + def test_marginal_event(self): + event_1 = SimpleEvent({self.x: closed(0, 1), self.y: SimpleInterval(0, 1)}) + event_2 = SimpleEvent({self.x: closed(1, 2), self.y: Interval(SimpleInterval(3, 4))}) + event_3 = SimpleEvent({self.x: closed(5, 6), self.y: Interval(SimpleInterval(5, 6))}) + event = Event(event_1, event_2, event_3) + marginal = event.marginal(SortedSet([self.x])) + self.assertEqual(marginal, SimpleEvent({self.x: closed(0, 2) | closed(5, 6)}).as_composite_set()) + fig = go.Figure(marginal.plot()) + # fig.show() + if __name__ == '__main__': unittest.main()