From 62acad7877cde6266d9dd9dfda5f1acebf6342b6 Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Thu, 30 May 2024 18:26:47 +0200 Subject: [PATCH] Implementation besides plotting is done. Proof for complement of product algebra is in the making --- examples/product_spaces.ipynb | 52 ++ requirements.txt | 2 - src/random_events/better_variables.py | 59 -- src/random_events/events.py | 786 -------------------------- src/random_events/interval.py | 20 +- src/random_events/product_algebra.py | 146 ++++- src/random_events/set.py | 8 + src/random_events/sigma_algebra.py | 25 +- src/random_events/utils.py | 3 + src/random_events/variable.py | 94 +++ src/random_events/variables.py | 291 ---------- test/test_events.py | 482 ---------------- test/test_interval.py | 14 + test/test_product_algebra.py | 74 +++ test/test_set.py | 22 + test/test_variable.py | 17 +- test/test_variables.py | 105 ---- 17 files changed, 448 insertions(+), 1752 deletions(-) delete mode 100644 src/random_events/better_variables.py delete mode 100644 src/random_events/events.py create mode 100644 src/random_events/variable.py delete mode 100644 src/random_events/variables.py delete mode 100644 test/test_events.py create mode 100644 test/test_product_algebra.py delete mode 100644 test/test_variables.py diff --git a/examples/product_spaces.ipynb b/examples/product_spaces.ipynb index 85e50d7..9ec175f 100644 --- a/examples/product_spaces.ipynb +++ b/examples/product_spaces.ipynb @@ -651,6 +651,58 @@ "execution_count": 62, "outputs": [] }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Complement of the Product Algebra\n", + "\n", + "[This](https://www.math.ucdavis.edu/~hunter/m206/ch4_measure_notes.pdf) mentions that the complement of an element of the product measure is constructed by\n", + "$$\n", + " (A \\times B)^c = (A^c \\times B) \\cup (A \\times B^c) \\cup (A^c \\times B^c).\n", + "$$\n", + "It is easy to see that this construction would produce exponential many elements with respect to the number of variables. This is unfortunate.\n", + "However, the correct complement can be formed with linear many terms, which is way more efficient. The following equations describe a proof by induction on how that can be done.\n", + "\n", + "Let\n", + "\\begin{align*}\n", + " \\mathbb{A} &= A \\cup A^c \\, , \\\\\n", + " \\mathbb{B} &= B \\cup B^c \\text{ and }\\\\\n", + " \\mathbb{C} &= C \\cup C^c.\n", + "\\end{align*}\n", + "\n", + "### Induction Assumption\n", + "\n", + "\\begin{align*}\n", + " (A \\times B)^c = (A^c \\times \\mathbb{B}) \\cup (A \\times B^C)\n", + "\\end{align*}\n", + "Proof:\n", + "\\begin{align*}\n", + " (A \\times B)^c &= (A^c \\times B) \\cup (A \\times B^c) \\cup (A^c \\times B^c) \\\\\n", + " &= (A^c \\times B) \\cup (A^c \\times B^c) \\cup (A \\times B^c) \\\\\n", + " &= ( A^c \\times (B \\cup B^c) ) \\cup (A \\times B^c) \\\\\n", + " &= (A^c \\times \\mathbb{B}) \\cup (A \\times B^C) \\square\n", + "\\end{align*}\n", + "\n", + "### Induction Step\n", + "\n", + "\\begin{align*}\n", + " (A \\times B \\times C)^c = (A^c \\times \\mathbb{B} \\times \\mathbb{C}) \\cup (A \\times B^C \\times \\mathbb{C} ) \\cup (A \\times B \\times C^c)\n", + "\\end{align*}\n", + "Proof:\n", + "\\begin{align*}\n", + " (A \\times B \\times C)^c &= (A^c \\times B \\times C) \\cup (A \\times B^c \\times C) \\cup (A \\times B \\times C^c) \\cup \n", + " (A^c \\times B^c \\times C) \\cup (A^c \\times B \\times C^c) \\cup (A \\times B^c \\times C^c) \\cup \n", + " (A^c \\times B^c \\times C^c) \\\\\n", + " &= (C \\times \\underbrace{(A^c \\times B) \\cup (A \\times B^c) \\cup (A^c \\times B^c))}_{\\text{Induction Assumption}} \\cup\n", + " (C^c \\times \\underbrace{(A^c \\times B) \\cup (A \\times B^c) \\cup (A^c \\times B^c))}_{\\text{Induction Assumption}} \\cup (A \\times B \\times C^c) \\\\\n", + " &= (C \\times (A^c \\times \\mathbb{B}) \\cup (A \\times B^C)) \\cup \n", + " (C^c \\times (A^c \\times \\mathbb{B}) \\cup (A \\times B^C)) \\cup (A \\times B \\times C^c)\\\\\n", + " &= \n", + "\\end{align*}\n" + ], + "id": "511cdcad45f76bab" + }, { "cell_type": "markdown", "source": [ diff --git a/requirements.txt b/requirements.txt index 063a8bf..fdb32df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,4 @@ -portion~=2.4.2 numpy~=1.26.1 plotly~=5.20.0 typing_extensions - sortedcontainers~=2.4.0 \ No newline at end of file diff --git a/src/random_events/better_variables.py b/src/random_events/better_variables.py deleted file mode 100644 index 0090266..0000000 --- a/src/random_events/better_variables.py +++ /dev/null @@ -1,59 +0,0 @@ -from dataclasses import dataclass - -from typing_extensions import Self, Type - -from .interval import Interval, SimpleInterval -from .set import Set, SetElement -from .sigma_algebra import AbstractCompositeSet - - -@dataclass -class Variable: - name: str - domain: AbstractCompositeSet - - def __lt__(self, other: Self) -> bool: - """ - Returns True if self < other, False otherwise. - """ - return self.name < other.name - - def __gt__(self, other: Self) -> bool: - """ - Returns True if self > other, False otherwise. - """ - return self.name > other.name - - def __hash__(self) -> int: - return self.name.__hash__() - - def __eq__(self, other): - return self.name == other.name - - def __str__(self): - return f"{self.__class__.__name__}({self.name}, {self.domain})" - - def __repr__(self): - return f"{self.__class__.__name__}({self.name})" - - -@dataclass -class Continuous(Variable): - domain: Interval = Interval([SimpleInterval(-float("inf"), float("inf"))]) - - -@dataclass -class Symbolic(Variable): - """ - Class for unordered, finite, discrete random variables. - """ - domain: Set - - def __init__(self, name: str, domain: Type): - super().__init__(name, Set([value for value in domain if value != domain.EMPTY_SET])) - - -@dataclass -class Integer(Variable): - """Class for ordered, discrete random variables.""" - domain: Interval = Interval([SimpleInterval(-float("inf"), float("inf"))]) diff --git a/src/random_events/events.py b/src/random_events/events.py deleted file mode 100644 index 5dd86a6..0000000 --- a/src/random_events/events.py +++ /dev/null @@ -1,786 +0,0 @@ -from __future__ import annotations - -import itertools -from collections import UserDict - -import numpy as np -import portion -import plotly.graph_objects as go - -from typing_extensions import Set, Union, Any, TYPE_CHECKING, Iterable, List, Self, Dict, Tuple - -from .variables import Variable, Continuous, Discrete -from .utils import SubclassJSONSerializer - - -# Type hinting for Python 3.7 to 3.9 -if TYPE_CHECKING: - VariableMapType = UserDict[str, Variable] -else: - VariableMapType = UserDict - - -class VariableMap(VariableMapType): - """ - A map of variables to values. - - Accessing a variable by name is also supported. - """ - - def variable_of(self, name: str) -> Variable: - """ - Get the variable with the given name. - :param name: The variable's name - :return: The variable itself - """ - variable = [variable for variable in self.keys() if variable.name == name] - if len(variable) == 0: - raise KeyError(f"Variable {name} not found in event {self}") - return variable[0] - - def __getitem__(self, item: Union[str, Variable]): - if isinstance(item, str): - item = self.variable_of(item) - return super().__getitem__(item) - - def __setitem__(self, key: Union[str, Variable], value: Any): - if isinstance(key, str): - key = self.variable_of(key) - - if not isinstance(key, Variable): - raise TypeError(f"Key must be a Variable, not {type(key)}") - super().__setitem__(key, value) - - def __copy__(self): - return self.__class__({variable: value for variable, value in self.items()}) - - -# Type hinting for Python 3.7 to 3.9 -if TYPE_CHECKING: - EventMapType = VariableMap[str, Union[tuple, portion.Interval]] -else: - EventMapType = VariableMap - - -class SupportsSetOperations: - """ - A class that supports set operations. - """ - - def union(self, other: Self) -> Self: - """ - Form the union of this object with another object. - """ - raise NotImplementedError - - def __or__(self, other: Self): - return self.union(other) - - def intersection(self, other: Self) -> Self: - """ - Form the intersection of this object with another object. - """ - raise NotImplementedError - - def __and__(self, other): - return self.intersection(other) - - def difference(self, other: Self) -> Self: - """ - Form the difference of this object with another object. - """ - raise NotImplementedError - - def __sub__(self, other): - return self.difference(other) - - def complement(self) -> Self: - """ - Form the complement of this object. - """ - raise NotImplementedError - - def __invert__(self): - return self.complement() - - def is_empty(self) -> bool: - """ - Check if this object is empty. - """ - raise NotImplementedError - - -class Event(SupportsSetOperations, EventMapType, SubclassJSONSerializer): - """ - A map of variables to values of their respective domains. - """ - - def __str__(self): - return "{" + ", ".join(f"{variable.name}: {value}" for variable, value in self.items()) + "}" - - def check_same_type(self, other: Any): - """ - Check that both self and other are of the same type. - - :param other: The other object - """ - if type(self) is not type(other): - raise TypeError(f"Cannot use operation on {type(self)} with {type(other)}") - - def intersection(self, other: EventType) -> EventType: - - # if the other is a complex event - if isinstance(other, ComplexEvent): - - # flip the call - return other.intersection(ComplexEvent([self])) - - self.check_same_type(other) - result = self.__class__() - - variables = set(self.keys()) | set(other.keys()) - for variable in variables: - assignment1 = self.get(variable, variable.encoded_domain if isinstance(self, EncodedEvent) else variable.domain) - assignment2 = other.get(variable, variable.encoded_domain if isinstance(self, EncodedEvent) else variable.domain) - intersection = variable.intersection_of_assignments(assignment1, assignment2) - result[variable] = intersection - - return result - - def union(self, other: EventType) -> ComplexEvent: - # create complex event from self - complex_self = ComplexEvent([self]) - - # if the other is a complex event - if isinstance(other, ComplexEvent): - - # flip the call - return other.union(complex_self) - - self.check_same_type(other) - - # form the intersection - intersection = self.intersection(other) - - # if the intersection of the two events is empty - if intersection.is_empty(): - - # trivially mount it - complex_self.events.append(other) - return complex_self - - # form complement of intersection - complement_of_intersection = intersection.complement() - - # intersect the other event with the complement of the intersection - fragments_of_other = complement_of_intersection.intersection(ComplexEvent([other])) - - # add the fragments of the other event - complex_self.events.extend(fragments_of_other.events) - return ComplexEvent(complex_self.events) - - def difference(self, other: EventType) -> ComplexEvent: - # if the other is a complex event - if isinstance(other, ComplexEvent): - - # flip the call - return other.complement().intersection(ComplexEvent([self])) - - self.check_same_type(other) - - # form the intersection - intersection = self.intersection(other) - - # if the intersection of the two events is empty - if intersection.is_empty(): - return ComplexEvent([self]) - - # form complement of intersection - complement_of_intersection = intersection.complement() - - # construct intersection of complement - return ComplexEvent([event.intersection(self) for event in complement_of_intersection.events - if not event.is_empty()]) - - def complement(self) -> ComplexEvent: - # initialize events - events = [] - - # get variables as set - variables: Set[Variable] = set(self.keys()) - - # memorize processed variables - processed_variables = [] - - # for every assignment - for variable, value in self.items(): - - # create the current complementary event - complement_event = self.__class__() - - # invert this variables assignment - complement_event[variable] = variable.complement_of_assignment( - value, encoded=isinstance(self, EncodedEvent)) - - # for every other variable - for other_variable in variables.difference({variable}): - - # if the other variable is already processed - if other_variable in processed_variables: - # add the assignment to the current event - complement_event[other_variable] = self[other_variable] - else: - # add the entire domain to the current event - other_domain = other_variable.domain - - # encode if necessary - if isinstance(self, EncodedEvent): - other_domain = other_variable.encode_many(other_domain) - complement_event[other_variable] = other_domain - - # add to processed variables - processed_variables.append(variable) - - # add to complex event - if not complement_event.is_empty(): - events.append(complement_event) - return ComplexEvent(events) - - def __eq__(self, other: Self) -> bool: - """ - Check if two events are equal. - - If one variable is only in one of the events, it is assumed that the other event has the entire domain as - default value. - """ - - variables = set(self.keys()) | set(other.keys()) - - equal = True - - for variable in variables: - if variable in self and variable not in other: - value_equal = variable.domain == self[variable] - elif variable in other and variable not in self: - value_equal = variable.domain == other[variable] - else: - value_equal = self[variable] == other[variable] - equal &= value_equal - - return equal - - @staticmethod - def check_element(variable: Variable, element: Any) -> Union[tuple, portion.Interval]: - """ - Check that elements can be regarded as elements of the variable's domain. - - Wrap a single element into a set or interval, depending on the variable type. - For any Iterable type that is not a string, the element is converted to a tuple of that iterable. - - :param variable: The variable where the element should belong to - :param element: The element to wrap - :return: The wrapped element - """ - - if isinstance(element, Iterable) and not isinstance(element, (str, portion.Interval)): - element = tuple(element) - - # if the element is already wrapped - if isinstance(element, (tuple, portion.Interval)): - - # check that the element is in the variable's domain - if isinstance(variable, Discrete): - if not all(elem in variable.domain for elem in element): - # raise an error - raise ValueError(f"Element {element} not in domain {variable.domain}") - - element = tuple(sorted(element)) - - # return the element directly - return element - - # if the element is not in the variables' domain - if element not in variable.domain: - # raise an error - raise ValueError(f"Element {element} not in domain {variable.domain}") - - # if the variable is continuous - if isinstance(variable, Continuous): - # return the element as a singleton interval - return portion.singleton(element) - - # if the variable is discrete - elif isinstance(variable, Discrete): - # return the element as a set - return (element,) - else: - raise TypeError(f"Unknown variable type {type(variable)}") - - def __setitem__(self, key: Union[str, Variable], value: Any): - EventMapType.__setitem__(self, key, self.check_element(key, value)) - - def encode(self) -> 'EncodedEvent': - """ - Encode the event to an encoded event. - :return: The encoded event - """ - return EncodedEvent({variable: variable.encode_many(element) for variable, element in self.items()}) - - def is_empty(self) -> bool: - return any(len(value) == 0 for value in self.values()) or len(self.keys()) == 0 - - def plot(self) -> Union[List[go.Scatter], List[go.Mesh3d]]: - """ - Plot the event. - """ - assert all(isinstance(variable, Continuous) for variable in self.keys()), "Can only plot continuous events" - if len(self.keys()) == 2: - return self.plot_2d() - elif len(self.keys()) == 3: - return self.plot_3d() - else: - raise ValueError("Can only plot 2D and 3D events") - - def plot_2d(self) -> List[go.Scatter]: - """ - Plot the event in 2D. - """ - - # form cartesian product of all intervals - intervals = [value._intervals for value in self.values()] - simple_events = list(itertools.product(*intervals)) - - xs = [] - ys = [] - - # for every atomic interval - for simple_event in simple_events: - - # plot a rectangle - points = np.asarray(list(itertools.product(*[[axis.lower, axis.upper] for axis in simple_event]))) - y_points = points[:, 1] - y_points[len(y_points) // 2:] = y_points[len(y_points) // 2:][::-1] - xs.extend(points[:, 0].tolist() + [points[0, 0], None]) - ys.extend(y_points.tolist()+ [y_points[0], None]) - - return [go.Scatter(x=xs, y=ys, mode="lines", name="Event", fill="toself")] - - def plot_3d(self) -> List[go.Mesh3d]: - """ - Plot the event in 3D. - """ - - # form cartesian product of all intervals - intervals = [value._intervals for _, value in sorted(self.items())] - simple_events = list(itertools.product(*intervals)) - traces = [] - - # shortcut for the dimensions - x, y, z = 0, 1, 2 - - # for every atomic interval - for simple_event in simple_events: - - # Create a 3D mesh trace for the rectangle - traces.append(go.Mesh3d( - # 8 vertices of a cube - x=[simple_event[x].lower, simple_event[x].lower, simple_event[x].upper, simple_event[x].upper, - simple_event[x].lower, simple_event[x].lower, simple_event[x].upper, simple_event[x].upper], - y=[simple_event[y].lower, simple_event[y].upper, simple_event[y].upper, simple_event[y].lower, - simple_event[y].lower, simple_event[y].upper, simple_event[y].upper, simple_event[y].lower], - z=[simple_event[z].lower, simple_event[z].lower, simple_event[z].lower, simple_event[z].lower, - simple_event[z].upper, simple_event[z].upper, simple_event[z].upper, simple_event[z].upper], - # i, j and k give the vertices of triangles - i=[7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2], - j=[3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3], - k=[0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6], - flatshading=True - )) - return traces - - def plotly_layout(self) -> Dict: - """ - Create a layout for the plotly plot. - """ - variables = list(sorted(self.keys())) - if len(variables) == 2: - result = {"xaxis_title": variables[0].name, - "yaxis_title": variables[1].name} - elif len(variables) == 3: - result = dict(scene=dict( - xaxis_title=variables[0].name, - yaxis_title=variables[1].name, - zaxis_title=variables[2].name) - ) - else: - raise NotImplementedError("Can only plot 2D and 3D events") - - return result - - def __hash__(self): - return hash(tuple(sorted(self.items()))) - - def fill_missing_variables(self, variables: Iterable[Variable]): - """ - Fill missing variables with their entire domain. - """ - for variable in variables: - if variable not in self: - self[variable] = variable.domain - - def decode(self): - """ - Decode the event to a normal event. - :return: The decoded event - """ - return self.__copy__() - - def marginal_event(self, variables: Iterable[Variable]) -> Self: - """ - Get the marginal event of this event with respect to a variable. - """ - return self.__class__({variable: self[variable] for variable in variables if variable in self}) - - def to_json(self) -> Dict[str, Any]: - result = super().to_json() - event = [(variable.to_json(), variable.assignment_to_json(assignment)) for variable, assignment in self.items()] - result["event"] = event - return result - - @classmethod - def _from_json(cls, data: Dict[str, Any]) -> Self: - result = cls() - for variable_json, assignment_json in data["event"]: - variable = Variable.from_json(variable_json) - assignment = variable.assignment_from_json(assignment_json) - result[variable] = assignment - return result - - def get_variables_where_assignment_is_different(self, other: Self) -> List[Variable]: - """ - Get all variables where the assignment is different from the other event's assignment - """ - return [variable for variable in self.keys() if self[variable] != other[variable]] - - def to_typst(self) -> str: - """ - Convert the event to a typst string. - """ - return " times ".join(f"{variable.name}_({variable.assignment_to_typst(value)})" for variable, value in - self.items()) - - -class EncodedEvent(Event): - """ - A map of variables to indices of their respective domains. - """ - - @staticmethod - def check_element(variable: Variable, element: Any) -> Union[tuple, portion.Interval]: - - # if the variable is continuous - if isinstance(variable, Continuous): - - # if it's not an interval - if not isinstance(element, portion.Interval): - - # try to convert it to one - element = portion.singleton(element) - - return element - - # if its any kind of iterable that's not an interval convert it to a tuple - if isinstance(element, Iterable) and not isinstance(element, portion.Interval): - element = tuple(sorted(element)) - - # if it is just an int, convert it to a tuple containing the int - elif isinstance(element, int): - element = (element, ) - - if not isinstance(element, tuple): - raise ValueError("Element for a discrete domain must be a tuple, not {}".format(type(element))) - - # if any element is not in the index set of the domain, raise an error - if not all(0 <= elem < len(variable.domain) for elem in element): - raise ValueError(f"Element {element} not in the index set of the domain {variable.domain}") - - return element - - def fill_missing_variables(self, variables: Iterable[Variable]): - for variable in variables: - if variable not in self: - self[variable] = variable.encode_many(variable.domain) - - def decode(self) -> Event: - return Event({variable: variable.decode_many(value) for variable, value in self.items()}) - - def encode(self) -> Self: - return self.__copy__() - - -class ComplexEvent(SupportsSetOperations, SubclassJSONSerializer): - """ - A complex event is a set of mutually exclusive events. - """ - events: List[Event] - - def __init__(self, events: Iterable[Event]): - self.events = list(event for event in events if not event.is_empty()) - variables = self.variables - for event in self.events: - event.fill_missing_variables(variables) - self.events = [event.__class__(sorted(event.items())) for event in self.events] - - @property - def variables(self) -> Tuple[Variable, ...]: - """ - Get the variables of the complex event. - """ - return tuple(sorted(set(variable for event in self.events for variable in event.keys()))) - - def union(self, other: EventType) -> Self: - if isinstance(other, Event): - return self.union(ComplexEvent([other])) - result = ComplexEvent(self.events + other.events) - return result.make_events_disjoint().simplify() - - def make_events_disjoint(self) -> Self: - """ - Make all events in this complex event disjoint. - - This is done by forming the intersection of all events recursively until no more intersections are found. - Then, the original events are decomposed into their disjoint components. - Finally, a complex event is formed from the disjoint components. Note that the result may not be the minimal - representation of the complex event, so it is advised to call ``simplify`` afterward. - """ - - # initialize previous intersections - previous_intersections = [] - - # for every pair of events - for index, event in enumerate(self.events): - for other_event in self.events[index + 1:]: - - # append intersection of pairwise events - intersection = event.intersection(other_event) - if not intersection.is_empty() and intersection not in previous_intersections: - previous_intersections.append(intersection) - - # if there are no intersections, skip the rest - if len(previous_intersections) == 0: - return self - - # while - while len(previous_intersections) > 0: - - # initialize new intersections - new_intersections = [] - - # form pairwise intersections of previous intersections - for index, intersection in enumerate(previous_intersections): - for other_intersection in previous_intersections[index + 1:]: - if not intersection.intersection(other_intersection).is_empty(): - new_intersections.append(intersection) - - if len(new_intersections) == 0: - break - - previous_intersections = new_intersections - - # sanity check - complex_event_of_intersections = ComplexEvent(previous_intersections) - assert complex_event_of_intersections.are_events_disjoint(), "Events are not disjoint" - - # initialize result - result = ComplexEvent(complex_event_of_intersections.events) - - # for every original event - for original_event in self.events: - - # initialize the difference of the original events with the intersection - decomposed_original_events = set() - - # for every atomic intersection in the disjoint intersections - for atomic_intersection in complex_event_of_intersections.events: - - # get the difference of the original event with the atomic intersection - original_event_disjoint_component = original_event.difference(atomic_intersection) - - # unify the differences - decomposed_original_events = decomposed_original_events.union( - set(original_event_disjoint_component.events)) - - # add the differences to the result - result.events.extend(decomposed_original_events) - return result - - def simplify(self) -> Self: - """ - Simplify the complex event such that sub-unions of events that can be expressed as a single events - are merged. - - This is done by seeking events that are equal in all but one dimensions to another event and merging them. - """ - - if len(self.variables) == 1: - return self.merge_if_one_dimensional() - - # for every unique pair of events - for index, event in enumerate(self.events): - for other_index, other_event in enumerate(self.events[index + 1:]): - other_index += index + 1 - - # get the different variables - different_variables = event.get_variables_where_assignment_is_different(other_event) - - # if they are the same event - if len(different_variables) == 0: - # recurse into the simpler complex event - result = ComplexEvent([]) - result.events.extend([event_ for index_, event_ in enumerate(self.events) - if index_ != other_index]) - return result.simplify() - - # if they differ in only one dimension - if len(different_variables) == 1: - mismatching_variable = different_variables[0] - - unified_event = event.__copy__() - unified_event[mismatching_variable] = (mismatching_variable. - union_of_assignments(event[mismatching_variable], - other_event[mismatching_variable])) - - # recurse into the simpler complex event - result = ComplexEvent([]) - result.events.append(unified_event) - result.events.extend([event_ for index_, event_ in enumerate(self.events) - if index_ not in [index, other_index]]) - return result.simplify() - - # if no simplification is possible, return the current complex event - return self.__copy__() - - def intersection(self, other: EventType) -> Self: - if isinstance(other, Event): - return self.intersection(ComplexEvent([other])) - intersections = [event.intersection(other_event) for other_event in other.events for event in self.events] - return ComplexEvent(intersections).simplify() - - def difference(self, other: EventType) -> Self: - if isinstance(other, Event): - return self.difference(ComplexEvent([other])) - return self.intersection(other.complement()) - - def complement(self) -> Self: - result = self.events[0].complement() - for event in self.events[1:]: - current_complement = event.complement() - result = result.intersection(current_complement) - return result.make_events_disjoint().simplify() - - def are_events_disjoint(self) -> bool: - """ - Check if all events inside this complex event are disjoint. - """ - for index, event in enumerate(self.events): - for event_ in self.events[index + 1:]: - if not event.intersection(event_).is_empty(): - return False - return True - - def __str__(self): - return " u ".join(str(event) for event in self.events) - - def __repr__(self): - return f"Union of {len(self.events)} events" - - def __eq__(self, other: ComplexEvent) -> bool: - """ - Check if two complex events are equal. - """ - return (all(event in other.events for event in self.events) - and all(event in self.events for event in other.events)) - - def __copy__(self): - return self.__class__([event.copy() for event in self.events]) - - def plot(self, color="#636EFA") -> Union[List[go.Scatter], List[go.Mesh3d]]: - """ - Plot the complex event. - - :param color: The color to use for this event - """ - traces = [] - show_legend = True - for index, event in enumerate(self.events): - event_traces = event.plot() - for event_trace in event_traces: - if len(event.keys()) == 2: - event_trace.update(name="Event", legendgroup=id(self), showlegend=show_legend, - line=dict(color=color)) - if len(event.keys()) == 3: - event_trace.update(name="Event", legendgroup=id(self), showlegend=show_legend, color=color) - show_legend = False - traces.append(event_trace) - return traces - - def plotly_layout(self) -> Dict: - """ - Create a layout for the plotly plot. - """ - return self.events[0].plotly_layout() - - def is_empty(self) -> bool: - return len(self.events) == 0 - - def encode(self) -> 'ComplexEvent': - """ - Encode the event to an encoded event. - :return: The encoded event - """ - return ComplexEvent([event.encode() for event in self.events]) - - def decode(self) -> ComplexEvent: - """ - Decode the event to a normal event. - """ - return ComplexEvent([event.decode() for event in self.events]) - - def marginal_event(self, variables: Iterable[Variable]) -> Self: - """ - Get the marginal event of this complex event with respect to a variable. - """ - return ComplexEvent([event.marginal_event(variables) for event in self.events]).simplify() - - def merge_if_one_dimensional(self) -> Self: - """ - Merge all events into a single event if they are all one-dimensional. - """ - if not len(self.variables) == 1: - return self - variable = self.variables[0] - value = self.events[0][variable] - - for event in self.events[1:]: - value = variable.union_of_assignments(value, event[variable]) - return ComplexEvent([Event({variable: value})]) - - def to_json(self) -> Dict[str, Any]: - result = super().to_json() - events = [event.to_json() for event in self.events] - result["events"] = events - return result - - @classmethod - def _from_json(cls, data: Dict[str, Any]) -> Self: - events = [Event.from_json(event) for event in data["events"]] - return cls(events) - - def to_typst(self) -> str: - """ - Convert the event to a typst string. - """ - return " union ".join(f"({event.to_typst()})" for event in self.events) - - -EventType = Union[Event, EncodedEvent, ComplexEvent] diff --git a/src/random_events/interval.py b/src/random_events/interval.py index 031c88b..63479bd 100644 --- a/src/random_events/interval.py +++ b/src/random_events/interval.py @@ -1,6 +1,6 @@ import enum from dataclasses import dataclass -from typing import Set +from typing import Dict, Any from sortedcontainers import SortedSet from typing_extensions import Self @@ -124,7 +124,7 @@ def complement(self) -> SortedSet[Self]: def contains(self, item: float) -> bool: return (self.lower < item < self.upper or (self.lower == item and self.left == Bound.CLOSED) or ( - self.upper == item and self.right == Bound.CLOSED)) + self.upper == item and self.right == Bound.CLOSED)) def __hash__(self): return hash((self.lower, self.upper, self.left, self.right)) @@ -140,9 +140,17 @@ def __repr__(self): def __str__(self): return sigma_algebra.AbstractSimpleSet.to_string(self) + def to_json(self) -> Dict[str, Any]: + return {**super().to_json(), 'lower': self.lower, 'upper': self.upper, 'left': self.left.name, + 'right': self.right.name} + + @classmethod + def _from_json(cls, data: Dict[str, Any]) -> Self: + return cls(data['lower'], data['upper'], Bound[data['left']], Bound[data['right']]) -class Interval(sigma_algebra.AbstractCompositeSet): + +class Interval(sigma_algebra.AbstractCompositeSet): simple_sets: SortedSet[SimpleInterval] def simplify(self) -> Self: @@ -161,9 +169,9 @@ def simplify(self) -> Self: last_simple_interval = result.simple_sets[-1] # if the borders are connected - if (last_simple_interval.upper > current_simple_interval.lower or - (last_simple_interval.upper == current_simple_interval.lower and not( - last_simple_interval.right == Bound.OPEN and current_simple_interval.left == Bound.OPEN))): + if (last_simple_interval.upper > current_simple_interval.lower or ( + last_simple_interval.upper == current_simple_interval.lower and not ( + last_simple_interval.right == Bound.OPEN and current_simple_interval.left == Bound.OPEN))): # extend the upper bound of the last element last_simple_interval.upper = current_simple_interval.upper diff --git a/src/random_events/product_algebra.py b/src/random_events/product_algebra.py index 6d60d7c..fd77c75 100644 --- a/src/random_events/product_algebra.py +++ b/src/random_events/product_algebra.py @@ -1,12 +1,9 @@ -from collections.abc import dict_keys, dict_values -from typing import Any - from sortedcontainers import SortedDict, SortedKeysView, SortedValuesView from typing_extensions import Union, Any -from .better_variables import * +from .variable import * from .sigma_algebra import * -from .variables import Variable +from .variable import Variable class VariableMap(SortedDict[Variable, Any]): @@ -17,7 +14,7 @@ class VariableMap(SortedDict[Variable, Any]): """ @property - def variables(self) -> dict_keys[Variable]: + def variables(self) -> SortedKeysView[Variable]: return self.keys() def variable_of(self, name: str) -> Variable: @@ -50,9 +47,14 @@ def __copy__(self): class SimpleEvent(AbstractSimpleSet, VariableMap[Variable, AbstractCompositeSet]): + """ + A simple event is a set of assignments of variables to values. + + A simple event is logically equivalent to a conjunction of assignments. + """ @property - def assignments(self) -> dict_values[AbstractCompositeSet]: + def assignments(self) -> SortedValuesView[AbstractCompositeSet]: return self.values() def intersection_with(self, other: Self) -> Self: @@ -69,10 +71,47 @@ def intersection_with(self, other: Self) -> Self: return result def complement(self) -> SortedSet[Self]: - pass - def is_empty(self) -> bool: + # initialize result + result = SortedSet() + + # initialize variables where the complement has already been computed + processed_variables = [] + + # for every key, value pair + for variable, assignment in self.items(): + + # initialize the current complement + current_complement = SimpleEvent() + + # set the current variable to its complement + current_complement[variable] = assignment.complement() + + # for every other variable + for other_variable in self.variables: + + # skip this iteration if the other variable is the same as the current one + if other_variable == variable: + continue + # if it has been processed, set copy its assignment from this + if other_variable in processed_variables: + current_complement[other_variable] = self[other_variable] + + # otherwise, set it to its domain (set of all values) + else: + current_complement[other_variable] = other_variable.domain + + # memorize the processed variables + processed_variables.append(variable) + + # if the current complement is not empty, add it to the result + if not current_complement.is_empty(): + result.add(current_complement) + + return result + + def is_empty(self) -> bool: if len(self) == 0: return True @@ -82,15 +121,94 @@ def is_empty(self) -> bool: return False - def contains(self, item) -> bool: - pass + def contains(self, item: Tuple) -> bool: + for assignment, value in zip(self.assignments, item): + if not assignment.contains(value): + return False + return True def __hash__(self): return hash(tuple(self.items())) - def __lt__(self, other): - pass + def __lt__(self, other: Self): + for variable, assignment in self.items(): + if assignment < other[variable]: + return True + return False + + def non_empty_to_string(self) -> str: + return "{" + ", ".join(f"{variable.name} = {assignment}" for variable, assignment in self.items()) + "}" class Event(AbstractCompositeSet): - ... + """ + An event is a disjoint set of simple events. + + Every simple event added to this event that is missing variables that any other event in this event has, will be + extended with the missing variable. The missing variables are mapped to their domain. + + """ + + simple_sets: SortedSet[SimpleEvent] + + def simplify(self) -> Self: + simplified, changed = self.simplify_once() + while changed: + simplified, changed = simplified.simplify_once() + return simplified + + def simplify_once(self) -> Tuple[Self, bool]: + """ + Simplify the event once. This simplification is not guaranteed to as simple as possible. + + :return: The simplified event and a boolean indicating whether the event has changed or not. + """ + + for event_a, event_b in itertools.combinations(self.simple_sets, 2): + different_variables = SortedSet() + + # get all events where these two events differ + for variable in event_a.variables: + if event_a[variable] != event_b[variable]: + different_variables.add(variable) + + # if the pair of simple events mismatches in more than one dimension it cannot be simplified + if len(different_variables) > 1: + break + + # if the pair of simple events mismatches in more than one dimension skip it + if len(different_variables) > 1: + continue + + # get the dimension where the two events differ + different_variable = different_variables[0] + + # initialize the simplified event + simplified_event = SimpleEvent() + + # for every variable + for variable in event_a.variables: + + # if the variable is the one where the two events differ + if variable == different_variable: + # set it to the union of the two events + simplified_event[variable] = event_a[variable].union_with(event_b[variable]) + + # if the variable has the same assignment + else: + # copy to the simplified event + simplified_event[variable] = event_a[variable] + + # create a new event with the simplified event and all other events + result = Event( + [simplified_event] + [event for event in self.simple_sets if event != event_a and event != event_b]) + return result, True + + # if nothing happened, return the original event and False + return self, False + + def new_empty_set(self) -> Self: + return Event() + + def complement_if_empty(self) -> Self: + raise NotImplementedError("Complement of an empty Event is not yet supported.") diff --git a/src/random_events/set.py b/src/random_events/set.py index e99d7ab..a3c8968 100644 --- a/src/random_events/set.py +++ b/src/random_events/set.py @@ -1,5 +1,6 @@ import enum from abc import abstractmethod +from typing import Dict, Any from sortedcontainers import SortedSet from typing_extensions import Self @@ -51,6 +52,13 @@ def __hash__(self): def __lt__(self, other): return self.value < other.value + def to_json(self) -> Dict[str, Any]: + return {**super().to_json(), "value": self.value} + + @classmethod + def _from_json(cls, data: Dict[str, Any]) -> Self: + return cls(data["value"]) + class Set(sigma_algebra.AbstractCompositeSet): diff --git a/src/random_events/sigma_algebra.py b/src/random_events/sigma_algebra.py index e073513..855098c 100644 --- a/src/random_events/sigma_algebra.py +++ b/src/random_events/sigma_algebra.py @@ -1,14 +1,16 @@ import itertools from abc import abstractmethod -from typing import Tuple +from typing import Tuple, Dict, Any from typing_extensions import Self, Set, Iterable, Optional from sortedcontainers import SortedSet +from .utils import SubclassJSONSerializer + EMPTY_SET_SYMBOL = "∅" -class AbstractSimpleSet: +class AbstractSimpleSet(SubclassJSONSerializer): """ Abstract class for simple sets. @@ -104,7 +106,7 @@ def __lt__(self, other): raise NotImplementedError -class AbstractCompositeSet: +class AbstractCompositeSet(SubclassJSONSerializer): """ Abstract class for composite sets. @@ -403,3 +405,20 @@ def add_simple_set(self, simple_set: AbstractSimpleSet): def __eq__(self, other: Self): return self.simple_sets._list == other.simple_sets._list + + def __hash__(self): + return hash(tuple(self.simple_sets)) + + def __lt__(self, other: Self): + if self.is_empty(): + return True + if other.is_empty(): + return False + return self.simple_sets[0] < other.simple_sets[0] + + def to_json(self) -> Dict[str, Any]: + return {**super().to_json(), "simple_sets": [simple_set.to_json() for simple_set in self.simple_sets]} + + @classmethod + def _from_json(cls, data: Dict[str, Any]) -> Self: + return cls([AbstractSimpleSet.from_json(simple_set) for simple_set in data["simple_sets"]]) diff --git a/src/random_events/utils.py b/src/random_events/utils.py index ea87afe..2e5d6c5 100644 --- a/src/random_events/utils.py +++ b/src/random_events/utils.py @@ -1,3 +1,5 @@ +from abc import abstractmethod + from typing_extensions import Dict, Any, Self @@ -30,6 +32,7 @@ def to_json(self) -> Dict[str, Any]: return {"type": get_full_class_name(self.__class__)} @classmethod + @abstractmethod def _from_json(cls, data: Dict[str, Any]) -> Self: """ Create a variable from a json dict. diff --git a/src/random_events/variable.py b/src/random_events/variable.py new file mode 100644 index 0000000..9062361 --- /dev/null +++ b/src/random_events/variable.py @@ -0,0 +1,94 @@ +from typing_extensions import Self, Type, Dict, Any, Union + +from .interval import Interval, SimpleInterval +from .set import Set, SetElement +from .sigma_algebra import AbstractCompositeSet +from .utils import SubclassJSONSerializer + + +class Variable(SubclassJSONSerializer): + name: str + domain: AbstractCompositeSet + + def __init__(self, name: str, domain: AbstractCompositeSet): + self.name = name + self.domain = domain + + def __lt__(self, other: Self) -> bool: + """ + Returns True if self < other, False otherwise. + """ + return self.name < other.name + + def __gt__(self, other: Self) -> bool: + """ + Returns True if self > other, False otherwise. + """ + return self.name > other.name + + def __hash__(self) -> int: + return self.name.__hash__() + + def __eq__(self, other): + return self.name == other.name + + def __str__(self): + return f"{self.__class__.__name__}({self.name}, {self.domain})" + + def __repr__(self): + return f"{self.__class__.__name__}({self.name})" + + def to_json(self) -> Dict[str, Any]: + return { + **super().to_json(), + "name": self.name, + "domain": self.domain.to_json() + } + + @classmethod + def _from_json(cls, data: Dict[str, Any]) -> Self: + return cls(data["name"], AbstractCompositeSet.from_json(data["domain"])) + + +class Continuous(Variable): + """ + Class for continuous random variables. + + The domain of a continuous variable is the real line. + """ + domain: Interval + + def __init__(self, name: str, domain=None): + super().__init__(name, Interval([SimpleInterval(-float("inf"), float("inf"))])) + + +class Symbolic(Variable): + """ + Class for unordered, finite, discrete random variables. + + The domain of a symbolic variable is a set of values from an enumeration. + """ + domain: Set + + def __init__(self, name: str, domain: Union[Type[SetElement], SetElement]): + """ + Construct a symbolic variable. + :param name: The name. + :param domain: The enum class that lists all elements of the domain. + """ + if isinstance(domain, type) and issubclass(domain, SetElement): + super().__init__(name, Set([value for value in domain if value != domain.EMPTY_SET])) + else: + super().__init__(name, domain) + + +class Integer(Variable): + """ + Class for ordered, discrete random variables. + + The domain of an integer variable is the number line. + """ + domain: Interval = Interval([SimpleInterval(-float("inf"), float("inf"))]) + + def __init__(self, name: str, domain=None): + super().__init__(name, Interval([SimpleInterval(-float("inf"), float("inf"))])) diff --git a/src/random_events/variables.py b/src/random_events/variables.py deleted file mode 100644 index 5b55550..0000000 --- a/src/random_events/variables.py +++ /dev/null @@ -1,291 +0,0 @@ -from typing import Any, Iterable, Dict, Tuple - -import portion -from typing_extensions import Union - -from . import utils - -AssignmentType = Union[portion.Interval, Tuple] - - -class Variable(utils.SubclassJSONSerializer): - """ - Abstract base class for all variables. - """ - - name: str - """ - The name of the variable. The name is used for comparison and hashing. - """ - - domain: AssignmentType - """ - The set of possible events of the variable. - """ - - def __init__(self, name: str, domain: Any): - self.name = name - self.domain = domain - - def __lt__(self, other: "Variable") -> bool: - """ - Returns True if self < other, False otherwise. - """ - return self.name < other.name - - def __gt__(self, other: "Variable") -> bool: - """ - Returns True if self > other, False otherwise. - """ - return self.name > other.name - - def __hash__(self) -> int: - return self.name.__hash__() - - def __eq__(self, other): - return self.name == other.name and self.domain == other.domain - - def __str__(self): - return f"{self.__class__.__name__}({self.name}, {self.domain})" - - def __repr__(self): - return f"{self.__class__.__name__}({self.name})" - - def encode(self, value: Any) -> Any: - """ - Encode an element of the domain to a representation that is usable for computations. - - :param value: The element to encode - :return: The encoded element - """ - return value - - def decode(self, value: Any) -> Any: - """ - Decode an element to the domain from a representation that is usable for computations. - - :param value: The element to decode - :return: The decoded element - """ - return value - - def encode_many(self, elements: Iterable) -> Iterable[Any]: - """ - Encode many elements of the domain to representations that are usable for computations. - - :param elements: The elements to encode - :return: The encoded elements - """ - return elements - - def decode_many(self, elements: Iterable) -> Iterable[Any]: - """ - Decode many elements from the representations that are usable for computations to their domains. - - :param elements: The encoded elements - :return: The decoded elements - """ - return elements - - def to_json(self) -> Dict[str, Any]: - return {"name": self.name, "type": utils.get_full_class_name(self.__class__), "domain": self.domain} - - @classmethod - def _from_json(cls, data: Dict[str, Any]) -> 'Variable': - """ - Create a variable from a json dict. - This method is called from the from_json method after the correct subclass is determined. - - :param data: The json dict - :return: The variable - """ - return cls(name=data["name"], domain=data["domain"]) - - def complement_of_assignment(self, assignment: AssignmentType, encoded: bool = False) -> AssignmentType: - """ - Returns the complement of the assignment for the variable. - - :param assignment: The assignment - :param encoded: If the assignment is encoded - :return: The complement of the assignment - """ - raise NotImplementedError - - @staticmethod - def intersection_of_assignments(assignment1: AssignmentType, - assignment2: AssignmentType, - encoded: bool = False) -> AssignmentType: - """ - Returns the intersection of two assignments - - :param assignment1: The first assignment - :param assignment2: The second assignment - :param encoded: If the assignment is encoded - :return: The intersection of the assignments - """ - raise NotImplementedError - - @staticmethod - def union_of_assignments(assignment1: AssignmentType, - assignment2: AssignmentType, - encoded: bool = False) -> AssignmentType: - """ - Returns the union of two assignments - - :param assignment1: The first assignment - :param assignment2: The second assignment - :param encoded: If the assignment is encoded - :return: The union of the assignments - """ - raise NotImplementedError - - def assignment_to_json(self, assignment: AssignmentType) -> Any: - """ - Convert an assignment to a json serializable object. - """ - raise NotImplementedError - - def assignment_from_json(self, data: Any) -> AssignmentType: - """ - Convert an assignment from a json serializable object. - """ - raise NotImplementedError - - @property - def encoded_domain(self): - return self.encode_many(self.domain) - - def assignment_to_typst(self, assignment: AssignmentType) -> str: - """ - Convert an assignment to typst string. - """ - raise NotImplementedError - - -class Continuous(Variable): - """ - Class for real valued random variables. - """ - - domain: portion.Interval - - def __init__(self, name: str, domain: portion.Interval = portion.open(-portion.inf, portion.inf)): - super().__init__(name=name, domain=domain) - - def to_json(self) -> Dict[str, Any]: - return {"name": self.name, "type": utils.get_full_class_name(self.__class__), - "domain": portion.to_data(self.domain)} - - @classmethod - def _from_json(cls, data: Dict[str, Any]) -> 'Variable': - return cls(name=data["name"], domain=portion.from_data(data["domain"])) - - def complement_of_assignment(self, assignment: portion.Interval, encoded: bool = False) -> portion.Interval: - return self.domain - assignment - - @staticmethod - def intersection_of_assignments(assignment1: portion.Interval, - assignment2: portion.Interval, - encoded: bool = False) -> portion.Interval: - return assignment1 & assignment2 - - @staticmethod - def union_of_assignments(assignment1: portion.Interval, - assignment2: portion.Interval, - encoded: bool = False) -> portion.Interval: - return assignment1 | assignment2 - - def assignment_to_json(self, assignment: portion.Interval) -> Any: - return portion.to_data(assignment) - - def assignment_from_json(self, data: Any) -> portion.Interval: - return portion.from_data(data) - - def assignment_to_typst(self, assignment: AssignmentType) -> str: - return " union ".join([interval.__str__() for interval in assignment]) - - -class Discrete(Variable): - """ - Class for discrete countable random variables. - """ - domain: Tuple - - def __init__(self, name: str, domain: Iterable): - super().__init__(name=name, domain=tuple(sorted(set(domain)))) - - def encode(self, element: Any) -> int: - """ - Encode an element of the domain to its index. - - :param element: The element to encode - :return: The index of the element - """ - return self.domain.index(element) - - def decode(self, index: int) -> Any: - """ - Decode an index to its element of the domain. - - :param index: The elements index - :return: The element itself - """ - return self.domain[index] - - def encode_many(self, elements: Iterable) -> Iterable[int]: - """ - Encode many elements of the domain to the indices of the elements. - - :param elements: The elements to encode - :return: The encoded elements - """ - return tuple(map(self.encode, elements)) - - def decode_many(self, elements: Iterable[int]) -> Iterable[Any]: - """ - Decode many elements from indices to their domains. - - :param elements: The encoded elements - :return: The decoded elements - """ - return tuple(map(self.decode, elements)) - - def complement_of_assignment(self, assignment: Tuple, encoded: bool = False) -> Tuple: - if not encoded: - return tuple(sorted(set(self.domain) - set(assignment))) - else: - return tuple(sorted(set(range(len(self.domain))) - set(assignment))) - - @staticmethod - def intersection_of_assignments(assignment1: Tuple, - assignment2: Tuple, - encoded: bool = False) -> Tuple: - - return tuple(sorted(set(assignment1) & set(assignment2))) - - @staticmethod - def union_of_assignments(assignment1: Tuple, - assignment2: Tuple, - encoded: bool = False) -> Tuple: - return tuple(sorted(set(assignment1) | set(assignment2))) - - def assignment_to_json(self, assignment: Tuple) -> Tuple: - return assignment - - def assignment_from_json(self, data: Any) -> AssignmentType: - return tuple(data) - - def assignment_to_typst(self, assignment: AssignmentType) -> str: - return "{" + ", ".join([str(element) for element in assignment]) + "}" - - -class Symbolic(Discrete): - """ - Class for unordered, finite, discrete random variables. - """ - ... - - -class Integer(Discrete): - """Class for ordered, discrete random variables.""" - ... diff --git a/test/test_events.py b/test/test_events.py deleted file mode 100644 index c921ae8..0000000 --- a/test/test_events.py +++ /dev/null @@ -1,482 +0,0 @@ -import unittest - -import portion - -from random_events.events import VariableMap, Event, EncodedEvent, ComplexEvent -from random_events.variables import Continuous, Integer, Symbolic - -import plotly.graph_objects as go - - -class VariableTestCase(unittest.TestCase): - - integer: Integer - symbol: Symbolic - real: Continuous - event: VariableMap - - @classmethod - def setUpClass(cls): - """ - Create some event for testing. - """ - cls.integer = Integer("integer", set(range(10))) - cls.symbol = Symbolic("symbol", {"a", "b", "c"}) - cls.real = Continuous("real") - cls.event = VariableMap({cls.integer: 1, cls.symbol: "a", cls.real: 1.0}) - - def test_creation(self): - """ - Test that the event is created correctly. - """ - self.assertEqual(self.event[self.integer], 1) - self.assertEqual(self.event[self.symbol], "a") - self.assertEqual(self.event[self.real], 1.0) - - def test_string_access(self): - """ - Test that the event is accessible by string. - """ - self.event["integer"] = self.event[self.integer] - self.assertEqual(self.event["integer"], self.event[self.integer]) - self.assertEqual(self.event["symbol"], self.event[self.symbol]) - self.assertEqual(self.event["real"], self.event[self.real]) - - def test_raising(self): - """ - Test that the event raises an error if the variable is not in the map. - """ - self.assertRaises(KeyError, lambda: self.event["not_in_map"]) - - -class EventTestCase(unittest.TestCase): - - integer: Integer - symbol: Symbolic - real: Continuous - event: Event - - @classmethod - def setUpClass(cls): - """ - Create some event for testing. - """ - cls.integer = Integer("integer", set(range(10))) - cls.symbol = Symbolic("symbol", {"a", "b", "c"}) - cls.real = Continuous("real") - cls.event = Event({cls.integer: 1, cls.symbol: "a", cls.real: 1.0}) - - def test_wrapping(self): - """ - Test that the event is wrapped correctly. - """ - self.assertEqual(self.event[self.integer], (1,)) - self.assertEqual(self.event[self.symbol], ("a",)) - self.assertEqual(self.event[self.real], portion.singleton(1.0)) - - def test_set_assignment(self): - """ - Test that the event is set correctly. - """ - event = self.event.copy() - event[self.integer] = (2, 3) - self.assertEqual(event[self.integer], (2, 3)) - event[self.symbol] = ("b", "c") - self.assertEqual(event[self.symbol], ("b", "c")) - event[self.real] = portion.closed(0.0, 1.0) - self.assertEqual(event[self.real], portion.closed(0.0, 1.0)) - - def test_raising(self): - """ - Test that errors are raised correctly. - """ - event = self.event.copy() - with self.assertRaises(ValueError): - event[self.integer] = 11 - - with self.assertRaises(ValueError): - event[self.integer] = (-1,) - - with self.assertRaises(ValueError): - event[self.symbol] = "d" - - with self.assertRaises(ValueError): - event[self.symbol] = ("d",) - - def test_encode(self): - """ - Test that events are correctly encoded. - """ - encoded = self.event.encode() - self.assertIsInstance(encoded, EncodedEvent) - decoded = encoded.decode() - self.assertEqual(self.event, decoded) - - def test_intersection(self): - """ - Test ordinary intersection of events - """ - event_1 = Event() - event_1[self.integer] = (1, 2) - event_1[self.symbol] = ("a", "b") - - result = event_1.intersection(self.event) - - self.assertEqual(result["integer"], (1, )) - self.assertEqual(result["symbol"], ("a", )) - self.assertEqual(result["real"], self.event["real"]) - - def test_empty_intersection(self): - """ - Test empty intersection of events - """ - event_1 = Event() - event_1[self.integer] = (1, 2) - event_1[self.symbol] = ("c", ) - result = event_1.intersection(self.event) - self.assertTrue(result.is_empty()) - - def test_intersection_alias(self): - event_1 = Event() - event_1[self.integer] = (1, 2) - event_1[self.symbol] = ("a", "b") - self.assertEqual(event_1 & self.event, event_1.intersection(self.event)) - self.assertEqual(event_1 & self.event, self.event & event_1) - - def test_union_without_intersection(self): - event1 = Event({self.integer: 1, self.symbol: "a", self.real: 2.0}) - union = event1.union(self.event) - self.assertIsInstance(union, ComplexEvent) - self.assertEqual(len(union.events), 2) - self.assertTrue(union.are_events_disjoint()) - - def test_difference(self): - event_1 = Event() - event_1[self.integer] = (1, 2, 5) - event_1[self.symbol] = ("a", "b") - result = event_1.difference(self.event) - self.assertEqual(len(result.events), 3) - self.assertTrue(result.are_events_disjoint()) - - def test_difference_alias(self): - event_1 = Event() - event_1[self.integer] = (1, 2, 5) - event_1[self.symbol] = ("a", "b") - self.assertEqual(event_1 - self.event, event_1.difference(self.event)) - # differences are not symmetric - self.assertNotEqual(event_1 - self.event, self.event - event_1) - - def test_equality(self): - self.assertEqual(self.event, self.event) - self.assertNotEqual(self.event, Event()) - - def test_raises_on_operation_with_different_types(self): - with self.assertRaises(TypeError): - self.event & self.event.encode() - - with self.assertRaises(TypeError): - self.event | self.event.encode() - - with self.assertRaises(TypeError): - self.event - self.event.encode() - - def test_serialization(self): - json = self.event.to_json() - event = Event.from_json(json) - self.assertEqual(event, self.event) - - def test_serialization_with_complex_interval(self): - event = Event({self.real: portion.closed(0, 1) | portion.closed(2, 3)}) - json = event.to_json() - event_ = Event.from_json(json) - self.assertEqual(event_, event) - - -class EncodedEventTestCase(unittest.TestCase): - - integer: Integer - symbol: Symbolic - real: Continuous - - @classmethod - def setUpClass(cls): - """ - Create some event for testing. - """ - cls.integer = Integer("integer", set(range(10))) - cls.symbol = Symbolic("symbol", {"a", "b", "c"}) - cls.real = Continuous("real") - - def test_creation(self): - event = EncodedEvent() - event[self.integer] = 1 - self.assertEqual(event[self.integer], (1,)) - event[self.integer] = (1, 2) - self.assertEqual(event[self.integer], (1, 2)) - event[self.symbol] = 0 - self.assertEqual(event[self.symbol], (0, )) - event[self.symbol] = {1, 0} - self.assertEqual(event[self.symbol], (0, 1)) - - interval = portion.open(0, 1) - event[self.real] = interval - self.assertEqual(interval, event[self.real]) - - def test_raises(self): - event = EncodedEvent() - with self.assertRaises(ValueError): - event[self.symbol] = 3 - - with self.assertRaises(ValueError): - event[self.symbol] = portion.open(0, 1) - - with self.assertRaises(ValueError): - event[self.symbol] = (1, 2, 3, 4) - - def test_dict_like_creation(self): - event = EncodedEvent(zip([self.integer, self.symbol], [1, 0])) - self.assertEqual(event[self.integer], (1,)) - self.assertEqual(event[self.symbol], (0,)) - - event = EncodedEvent(zip([self.integer, self.symbol], [[0, 1], 0])) - self.assertEqual(event[self.integer], (0, 1)) - self.assertEqual(event[self.symbol], (0,)) - - def test_set_operations_return_type(self): - event = EncodedEvent(zip([self.integer, self.symbol], [1, 0])) - self.assertEqual(type(event & event), EncodedEvent) - self.assertEqual(type(event | event), ComplexEvent) - self.assertEqual(type(event - event), ComplexEvent) - - def test_intersection_with_empty(self): - event = Event({self.integer: ()}) - complete_event = Event({self.integer: self.integer.domain}) - intersection = event.intersection(complete_event) - self.assertIn(self.integer, intersection.keys()) - self.assertTrue(intersection.is_empty()) - - def test_serialization(self): - event = EncodedEvent() - event[self.integer] = (1, 2) - event[self.symbol] = {1, 0} - event[self.real] = portion.open(0, 1) - - json = event.to_json() - event_ = EncodedEvent.from_json(json) - self.assertEqual(event, event_) - - -class ComplexEventTestCase(unittest.TestCase): - - x: Continuous = Continuous("x") - y: Continuous = Continuous("y") - z: Continuous = Continuous("z") - a: Symbolic = Symbolic("a", {"a1", "a2"}) - b: Symbolic = Symbolic("b", {"b1", "b2", "b3"}) - unit_interval = portion.closed(0, 1) - - def test_union(self): - event_1 = Event({self.x: self.unit_interval, self.y: self.unit_interval}) - event_2 = Event({self.x: portion.closed(0.5, 2), self.y: portion.closed(0.5, 2)}) - union = event_1.union(event_2) - self.assertIsInstance(union, ComplexEvent) - self.assertEqual(len(union.events), 3) - self.assertTrue(union.are_events_disjoint()) - # go.Figure(union.plot()).show() - - def test_union_of_complex_events(self): - event_1 = Event({self.x: self.unit_interval, self.y: self.unit_interval}) - event_2 = Event({self.x: portion.closed(0.5, 2), self.y: portion.closed(0.5, 2)}) - complex_event_1 = event_1.union(event_2) - event_3 = Event({self.x: portion.closed(0.5, 2), self.y: portion.closed(-0.5, 2)}) - complex_event_2 = event_1.union(event_3) - result = complex_event_1.union(complex_event_2) - self.assertIsInstance(result, ComplexEvent) - self.assertTrue(result.are_events_disjoint()) - - def test_make_events_disjoint_and_simplify(self): - event_1 = Event({self.x: self.unit_interval, self.y: self.unit_interval}) - event_2 = Event({self.x: portion.closed(0.5, 2), self.y: portion.closed(0.5, 2)}) - complex_event = ComplexEvent([event_1, event_2]) - self.assertFalse(complex_event.are_events_disjoint()) - complex_event = complex_event.make_events_disjoint() - self.assertEqual(len(complex_event.events), 5) - self.assertTrue(complex_event.are_events_disjoint()) - simplified_event = complex_event.simplify() - self.assertEqual(len(simplified_event.events), 3) - self.assertTrue(simplified_event.are_events_disjoint()) - - def test_are_events_disjoint(self): - event1 = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)}) - event2 = Event({self.x: portion.closed(0.5, 2), self.y: portion.closed(0.5, 2)}) - event3 = Event({self.x: portion.closed(2, 3), self.y: portion.closed(2, 3)}) - - complex_event = ComplexEvent((event1, event2)) - self.assertFalse(complex_event.are_events_disjoint()) - - complex_event = ComplexEvent((event1, event3)) - self.assertTrue(complex_event.are_events_disjoint()) - - complex_event = ComplexEvent((event1, event3, event2,)) - self.assertFalse(complex_event.are_events_disjoint()) - - def test_from_continuous_complement(self): - event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)}) - complement = event.complement() - - self.assertEqual(len(complement.events), 2) - self.assertTrue(complement.are_events_disjoint()) - - c1 = complement.events[0] - self.assertEqual(c1[self.x], portion.open(-portion.inf, 0) | portion.open(1, portion.inf)) - self.assertEqual(c1[self.y], portion.open(-portion.inf, portion.inf)) - - c2 = complement.events[1] - self.assertEqual(c2[self.y], portion.open(-portion.inf, 0) | portion.open(1, portion.inf)) - self.assertEqual(c2[self.x], event[self.x]) - - for sub_event in complement.events: - self.assertTrue(sub_event.intersection(event).is_empty()) - - def test_complement_3d(self): - event = Event({self.x: portion.closed(0, 1), - self.y: portion.closed(0, 1), - self.z: portion.closed(0, 1)}) - complement = event.complement() - self.assertEqual(len(complement.events), 3) - self.assertTrue(complement.are_events_disjoint()) - - def test_from_discrete_complement(self): - event = Event({self.a: "a1", self.b: "b1"}) - complement = event.complement() - - self.assertEqual(len(complement.events), 2) - self.assertTrue(complement.are_events_disjoint()) - - c1 = complement.events[0] - self.assertEqual(c1[self.a], ("a2", )) - self.assertEqual(c1[self.b], self.b.domain) - - c2 = complement.events[1] - self.assertEqual(c2[self.a], ("a1", )) - self.assertEqual(c2[self.b], ("b2", "b3")) - - for sub_event in complement.events: - self.assertTrue(sub_event.intersection(event).is_empty()) - - def test_chained_complement(self): - event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)}) - complement = event.complement() - self.assertEqual(len(complement.events), 2) - copied_event = complement.complement() - self.assertEqual(len(copied_event.events), 1) - self.assertEqual(copied_event.events[0], event) - - def test_union_of_simple_with_complex(self): - event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)}) - complex_event = ComplexEvent([event]) - union1 = event.union(complex_event) - union2 = complex_event.union(event) - self.assertEqual(union1, union2) - - def test_union_with_different_variables(self): - event1 = Event({self.x: portion.closed(0, 1)}) - event2 = Event({self.y: portion.closed(0, 1)}) - union = event1.union(event2) - for event in union.events: - self.assertEqual(len(event), 2) - - def test_copy(self): - event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)}) - copied = event.copy() - self.assertEqual(event, copied) - self.assertIsNot(event, copied) - - def test_decode_encode(self): - event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)}) - encoded = event.encode() - decoded = encoded.decode() - self.assertEqual(event, decoded) - - def test_marginal_event(self): - event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)}) - complement = event.complement() - marginal_event = complement.marginal_event([self.x]) - self.assertEqual(len(marginal_event.events), 1) - self.assertEqual(marginal_event.events[0][self.x], portion.open(-portion.inf, portion.inf)) - - def test_merge_if_1d(self): - event1 = Event({self.x: portion.closed(0, 1)}) - event2 = Event({self.x: portion.closed(3, 4)}) - complex_event = ComplexEvent([event1, event2]) - merged = complex_event.merge_if_one_dimensional() - self.assertEqual(len(merged.events), 1) - self.assertEqual(merged.events[0][self.x], portion.closed(0, 1) | portion.closed(3, 4)) - - def test_serialization(self): - event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)}) - complement = event.complement() - json = complement.to_json() - complement_ = ComplexEvent.from_json(json) - self.assertEqual(complement, complement_) - - def test_intersection_symbol_and_real(self): - event = ComplexEvent([EncodedEvent({self.x: portion.closed(0, 1)})]) - event2 = EncodedEvent({self.a: (0, )}) - result = event & event2 - self.assertEqual(len(result.events), 1) - event_ = result.events[0] - self.assertEqual(event_[self.x], portion.closed(0, 1)) - self.assertEqual(event_[self.a], (0, )) - - def test_no_simplify_high_dimensions(self): - x = Continuous("x") - y = Continuous("y") - z = Continuous("z") - - event_1 = Event({x: portion.closed(0, 1), y: portion.closed(0, 1), z: portion.closed(0, 1)}) - event_2 = Event({x: portion.closed(0, 1), y: portion.closed(0.5, 1.5), z: portion.closed(0.5, 1.5)}) - complex_event = ComplexEvent([event_1, event_2]) - simplified = complex_event.simplify() - self.assertEqual(len(simplified.events), 2) - - - -class PlottingTestCase(unittest.TestCase): - x: Continuous = Continuous("x") - y: Continuous = Continuous("y") - z: Continuous = Continuous("z") - - def test_plot_2d(self): - event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)}) - fig = go.Figure(event.plot()) - # fig.show() - - def test_plot_3d(self): - event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1), self.z: portion.closed(0, 1)}) - fig = go.Figure(event.plot()) - # fig.show() - - def test_plot_complex_event_2d(self): - event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)}) - complement = event.complement() - limiting_event = Event({self.x: portion.closed(-1, 2), self.y: portion.closed(-1, 2)}) - result = complement.intersection(limiting_event) - fig = go.Figure(result.plot(), result.plotly_layout()) - # fig.show() - - def test_plot_complex_event_3d(self): - event = Event({self.x: portion.closed(0, 1), - self.y: portion.closed(0, 1), - self.z: portion.closed(0, 1)}) - complement = event.complement() - limiting_event = Event({self.x: portion.closed(-1, 2), - self.y: portion.closed(-1, 2), - self.z: portion.closed(-1, 2)}) - result = complement.intersection(ComplexEvent([limiting_event])) - fig = go.Figure(result.plot(), result.plotly_layout()) - # fig.show() - - -if __name__ == '__main__': - unittest.main() diff --git a/test/test_interval.py b/test/test_interval.py index c70611f..237c737 100644 --- a/test/test_interval.py +++ b/test/test_interval.py @@ -1,6 +1,7 @@ import unittest from random_events.interval import * +from random_events.sigma_algebra import AbstractSimpleSet class SimpleIntervalTestCase(unittest.TestCase): @@ -42,6 +43,12 @@ def test_contains(self): self.assertFalse(a.contains(-1)) self.assertFalse(a.contains(1.1)) + def test_to_json(self): + a = SimpleInterval(0, 1) + b = AbstractSimpleSet.from_json(a.to_json()) + self.assertIsInstance(b, SimpleInterval) + self.assertEqual(a, b) + class IntervalTestCase(unittest.TestCase): @@ -68,6 +75,13 @@ def test_union(self): self.assertEqual(union_a_d_b_c, union_a_d_b_c_) self.assertTrue(union_a_d_b_c.is_disjoint()) + def test_to_json(self): + a = SimpleInterval(0, 1) + b = Interval([a]) + c = AbstractSimpleSet.from_json(b.to_json()) + self.assertIsInstance(c, Interval) + self.assertEqual(b, c) + if __name__ == '__main__': unittest.main() diff --git a/test/test_product_algebra.py b/test/test_product_algebra.py new file mode 100644 index 0000000..8df55f6 --- /dev/null +++ b/test/test_product_algebra.py @@ -0,0 +1,74 @@ +import unittest + +from sortedcontainers import SortedSet + +from random_events.variable import Continuous, Symbolic +from random_events.interval import Interval, SimpleInterval +from random_events.product_algebra import SimpleEvent, Event +from random_events.set import SetElement, Set + + +class TestEnum(SetElement): + EMPTY_SET = 0 + A = 1 + B = 2 + C = 4 + + +class SimpleEventTestCase(unittest.TestCase): + x = Continuous("x") + y = Continuous("y") + a = Symbolic("a", TestEnum) + b = Symbolic("b", TestEnum) + + def test_constructor(self): + event = SimpleEvent({self.a: Set([TestEnum.A]), self.x: Interval([SimpleInterval(0, 1)]), + self.y: Interval([SimpleInterval(0, 1)])}) + + self.assertEqual(event[self.x], Interval([SimpleInterval(0, 1)])) + self.assertEqual(event[self.y], Interval([SimpleInterval(0, 1)])) + self.assertEqual(event[self.a], Set([TestEnum.A])) + + self.assertFalse(event.is_empty()) + self.assertTrue(event.contains((TestEnum.A, 0.5, 0.1,))) + self.assertFalse(event.contains((TestEnum.B, 0.5, 0.1,))) + + def test_intersection_with(self): + event_1 = SimpleEvent({self.a: Set([TestEnum.A, TestEnum.B]), self.x: Interval([SimpleInterval(0, 1)]), + self.y: Interval([SimpleInterval(0, 1)])}) + event_2 = SimpleEvent({self.a: Set([TestEnum.A]), self.x: Interval([SimpleInterval(0.5, 1)])}) + intersection = event_1.intersection_with(event_2) + intersection_ = SimpleEvent({self.a: Set([TestEnum.A]), self.x: Interval([SimpleInterval(0.5, 1)]), + self.y: Interval([SimpleInterval(0, 1)])}) + self.assertEqual(intersection, intersection_) + self.assertNotEqual(intersection, event_1) + + event_3 = SimpleEvent({self.a: Set([TestEnum.C])}) + intersection = event_1.intersection_with(event_3) + self.assertTrue(intersection.is_empty()) + + def test_complement(self): + event = SimpleEvent( + {self.a: Set([TestEnum.A, TestEnum.B]), self.x: Interval([SimpleInterval(0, 1)]), self.y: self.y.domain}) + complement = event.complement() + self.assertEqual(len(complement), 2) + complement_1 = SimpleEvent({self.a: Set([TestEnum.C]), self.x: self.x.domain, self.y: self.y.domain}) + complement_2 = SimpleEvent({self.a: event[self.a], self.x: event[self.x].complement(), self.y: self.y.domain}) + self.assertEqual(complement, SortedSet([complement_1, complement_2])) + + def test_simplify(self): + event_1 = SimpleEvent({self.a: Set([TestEnum.A, TestEnum.B]), self.x: Interval([SimpleInterval(0, 1)]), + self.y: Interval([SimpleInterval(0, 1)])}) + event_2 = SimpleEvent({self.a: Set([TestEnum.C]), self.x: Interval([SimpleInterval(0, 1)]), + self.y: Interval([SimpleInterval(0, 1)])}) + event = Event([event_1, event_2]) + simplified = event.simplify() + self.assertEqual(len(simplified.simple_sets), 1) + + result = Event([SimpleEvent({self.a: self.a.domain, self.x: Interval([SimpleInterval(0, 1)]), + self.y: Interval([SimpleInterval(0, 1)])})]) + self.assertEqual(simplified, result) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_set.py b/test/test_set.py index 7183d1b..bdb0679 100644 --- a/test/test_set.py +++ b/test/test_set.py @@ -2,6 +2,8 @@ import enum import unittest +from random_events.sigma_algebra import AbstractSimpleSet + class TestEnum(SetElement): EMPTY_SET = 0 @@ -38,6 +40,11 @@ def test_contains(self): self.assertFalse(a.contains(TestEnum.B)) self.assertFalse(a.contains(TestEnum.C)) + def test_to_json(self): + a = TestEnum.A + b = AbstractSimpleSet.from_json(a.to_json()) + self.assertEqual(a, b) + class SetTestCase(unittest.TestCase): @@ -58,5 +65,20 @@ def test_complement(self): s = Set([TestEnum.A, TestEnum.B]) self.assertEqual(s.complement(), Set([TestEnum.C])) + def test_to_json(self): + s = Set([TestEnum.A, TestEnum.B]) + s_ = AbstractSimpleSet.from_json(s.to_json()) + self.assertEqual(s, s_) + + def test_to_json_with_dynamic_enum(self): + enum_ = SetElement("Foo", "A B C") + s = Set([enum_.A, enum_.B]) + s_ = s.to_json() + del enum_ + s_ = AbstractSimpleSet.from_json(s_) + self.assertEqual(s, s_) + + + if __name__ == '__main__': unittest.main() diff --git a/test/test_variable.py b/test/test_variable.py index f9afca4..600b145 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -1,6 +1,6 @@ import unittest -from random_events.better_variables import * +from random_events.variable import * from random_events.interval import * @@ -12,11 +12,15 @@ class TestEnum(SetElement): class ContinuousTestCase(unittest.TestCase): + x = Continuous("x") def test_creation(self): - x = Continuous("x") - self.assertEqual(x.name, "x") - self.assertEqual(x.domain, Interval([SimpleInterval(-float("inf"), float("inf"))])) + self.assertEqual(self.x.name, "x") + self.assertEqual(self.x.domain, Interval([SimpleInterval(-float("inf"), float("inf"))])) + + def test_to_json(self): + x_ = Variable.from_json(self.x.to_json()) + self.assertEqual(self.x, x_) class IntegerTestCase(unittest.TestCase): @@ -34,6 +38,11 @@ def test_creation(self): self.assertEqual(x.name, "x") self.assertEqual(x.domain, Set([TestEnum.A, TestEnum.B, TestEnum.C])) + def test_to_json(self): + x = Symbolic("x", TestEnum) + x_ = Variable.from_json(x.to_json()) + self.assertEqual(x, x_) + if __name__ == '__main__': unittest.main() diff --git a/test/test_variables.py b/test/test_variables.py deleted file mode 100644 index e06a50f..0000000 --- a/test/test_variables.py +++ /dev/null @@ -1,105 +0,0 @@ -import unittest - -import portion - -from random_events.variables import Integer, Symbolic, Continuous, Variable - - -class VariablesTestCase(unittest.TestCase): - """Tests for `variables.py`.""" - - integer: Integer - symbol: Symbolic - real: Continuous - - @classmethod - def setUpClass(cls): - """ - Create some variables for testing. - """ - cls.integer = Integer("integer", set(range(10))) - cls.symbol = Symbolic("symbol", {"a", "b", "c"}) - cls.real = Continuous("real") - - def test_creation(self): - """ - Test that the variables are created correctly. - """ - self.assertEqual(self.integer.name, "integer") - self.assertEqual(self.integer.domain, tuple(range(10))) - self.assertEqual(self.symbol.name, "symbol") - self.assertEqual(self.symbol.domain, ("a", "b", "c")) - self.assertEqual(self.real.name, "real") - self.assertEqual(self.real.domain, portion.open(-portion.inf, portion.inf)) - - def test_hash(self): - """ - Test that the variables are hashable. - """ - self.assertTrue(hash(self.integer)) - self.assertTrue(hash(self.symbol)) - self.assertTrue(hash(self.real)) - - def test_ordering(self): - """ - Test that the variables are ordered correctly. - """ - self.assertLess(self.integer, self.symbol) - self.assertLess(self.real, self.symbol) - self.assertLess(self.integer, self.real) - self.assertEqual([self.integer, self.real, self.symbol], sorted([self.symbol, self.integer, self.real])) - - def test_equality(self): - """ - Test that the variables are equal to themselves and not equal to others. - """ - self.assertEqual(self.integer, Integer("integer", tuple(range(10)))) - self.assertNotEqual(self.symbol, Symbolic("symbol", ("d", "b", "c"))) - self.assertEqual(self.symbol, Symbolic("symbol", ("a", "b", "c"))) - self.assertEqual(self.real, Continuous("real")) - - def test_to_json(self): - """ - Test that the variables can be dumped to json. - """ - self.assertTrue(self.symbol.to_json()) - self.assertTrue(self.integer.to_json()) - self.assertTrue(self.real.to_json()) - - def test_encode(self): - """ - Test that the variables can be encoded. - """ - self.assertEqual(self.integer.encode(1), 1) - self.assertEqual(self.symbol.encode("b"), 1) - self.assertEqual(self.real.encode(1.0), 1.0) - - def test_decode(self): - """ - Test that the variables can be decoded. - """ - self.assertEqual(self.integer.decode(1), 1) - self.assertEqual(self.symbol.decode(1), "b") - self.assertEqual(self.real.decode(1.0), 1.0) - - def test_polymorphic_serialization(self): - real = Variable.from_json(self.real.to_json()) - self.assertEqual(real, self.real) - - integer = Variable.from_json(self.integer.to_json()) - self.assertEqual(integer, self.integer) - - symbol = Variable.from_json(self.symbol.to_json()) - self.assertEqual(symbol, self.symbol) - - def test_complement_of_assignment(self): - """ - Test that the complement of an assignment is correct. - """ - self.assertEqual(self.real.complement_of_assignment(portion.closed(0, 1)), - portion.open(-portion.inf, 0) | portion.open(1, portion.inf)) - self.assertEqual(self.symbol.complement_of_assignment(("a",)), ("b", "c", )) - - -if __name__ == '__main__': - unittest.main()