diff --git a/src/random_events/better_variables.py b/src/random_events/better_variables.py new file mode 100644 index 0000000..0090266 --- /dev/null +++ b/src/random_events/better_variables.py @@ -0,0 +1,59 @@ +from dataclasses import dataclass + +from typing_extensions import Self, Type + +from .interval import Interval, SimpleInterval +from .set import Set, SetElement +from .sigma_algebra import AbstractCompositeSet + + +@dataclass +class Variable: + name: str + domain: AbstractCompositeSet + + def __lt__(self, other: Self) -> bool: + """ + Returns True if self < other, False otherwise. + """ + return self.name < other.name + + def __gt__(self, other: Self) -> bool: + """ + Returns True if self > other, False otherwise. + """ + return self.name > other.name + + def __hash__(self) -> int: + return self.name.__hash__() + + def __eq__(self, other): + return self.name == other.name + + def __str__(self): + return f"{self.__class__.__name__}({self.name}, {self.domain})" + + def __repr__(self): + return f"{self.__class__.__name__}({self.name})" + + +@dataclass +class Continuous(Variable): + domain: Interval = Interval([SimpleInterval(-float("inf"), float("inf"))]) + + +@dataclass +class Symbolic(Variable): + """ + Class for unordered, finite, discrete random variables. + """ + domain: Set + + def __init__(self, name: str, domain: Type): + super().__init__(name, Set([value for value in domain if value != domain.EMPTY_SET])) + + +@dataclass +class Integer(Variable): + """Class for ordered, discrete random variables.""" + domain: Interval = Interval([SimpleInterval(-float("inf"), float("inf"))]) diff --git a/src/random_events/interval.py b/src/random_events/interval.py index b4e7046..031c88b 100644 --- a/src/random_events/interval.py +++ b/src/random_events/interval.py @@ -177,3 +177,6 @@ def simplify(self) -> Self: def new_empty_set(self) -> Self: return Interval() + + def complement_if_empty(self) -> Self: + return Interval([SimpleInterval(float('-inf'), float('inf'), Bound.OPEN, Bound.OPEN)]) diff --git a/src/random_events/product_algebra.py b/src/random_events/product_algebra.py new file mode 100644 index 0000000..6d60d7c --- /dev/null +++ b/src/random_events/product_algebra.py @@ -0,0 +1,96 @@ +from collections.abc import dict_keys, dict_values +from typing import Any + +from sortedcontainers import SortedDict, SortedKeysView, SortedValuesView +from typing_extensions import Union, Any + +from .better_variables import * +from .sigma_algebra import * +from .variables import Variable + + +class VariableMap(SortedDict[Variable, Any]): + """ + A map of variables to values. + + Accessing a variable by name is also supported. + """ + + @property + def variables(self) -> dict_keys[Variable]: + return self.keys() + + def variable_of(self, name: str) -> Variable: + """ + Get the variable with the given name. + :param name: The variable's name + :return: The variable itself + """ + variable = [variable for variable in self.variables if variable.name == name] + if len(variable) == 0: + raise KeyError(f"Variable {name} not found in event {self}") + return variable[0] + + def __getitem__(self, item: Union[str, Variable]): + if isinstance(item, str): + item = self.variable_of(item) + return super().__getitem__(item) + + def __setitem__(self, key: Union[str, Variable], value: Any): + if isinstance(key, str): + key = self.variable_of(key) + + if not isinstance(key, Variable): + raise TypeError(f"Key must be a Variable if not already present, got {type(key)} instead.") + + super().__setitem__(key, value) + + def __copy__(self): + return self.__class__({variable: value for variable, value in self.items()}) + + +class SimpleEvent(AbstractSimpleSet, VariableMap[Variable, AbstractCompositeSet]): + + @property + def assignments(self) -> dict_values[AbstractCompositeSet]: + return self.values() + + def intersection_with(self, other: Self) -> Self: + variables = self.keys() | other.keys() + result = SimpleEvent() + for variable in variables: + if variable in self and variable in other: + result[variable] = self[variable].intersection_with(other[variable]) + elif variable in self: + result[variable] = self[variable] + else: + result[variable] = other[variable] + + return result + + def complement(self) -> SortedSet[Self]: + pass + + def is_empty(self) -> bool: + + if len(self) == 0: + return True + + for assignment in self.values(): + if assignment.is_empty(): + return True + + return False + + def contains(self, item) -> bool: + pass + + def __hash__(self): + return hash(tuple(self.items())) + + def __lt__(self, other): + pass + + +class Event(AbstractCompositeSet): + ... diff --git a/src/random_events/set.py b/src/random_events/set.py new file mode 100644 index 0000000..e99d7ab --- /dev/null +++ b/src/random_events/set.py @@ -0,0 +1,66 @@ +import enum +from abc import abstractmethod + +from sortedcontainers import SortedSet +from typing_extensions import Self + +from . import sigma_algebra + + +class SetElement(sigma_algebra.AbstractSimpleSet, enum.Enum): + """ + Base class for enums that are used as elements in a set. + + Classes that inherit from this class have to define an attribute called EMPTY_SET. + """ + + @property + @abstractmethod + def EMPTY_SET(self): + raise NotImplementedError + + @property + def all_elements(self): + return self.__class__ + + def intersection_with(self, other: Self) -> Self: + if self == other: + return self + else: + return self.all_elements.EMPTY_SET + + def complement(self) -> SortedSet[Self]: + result = SortedSet() + for element in self.all_elements: + if element != self and element != self.all_elements.EMPTY_SET: + result.add(element) + return result + + def is_empty(self) -> bool: + return self is self.all_elements.EMPTY_SET + + def contains(self, item: Self) -> bool: + return self == item + + def non_empty_to_string(self) -> str: + return self.name + + def __hash__(self): + return enum.Enum.__hash__(self) + + def __lt__(self, other): + return self.value < other.value + + +class Set(sigma_algebra.AbstractCompositeSet): + + simple_sets: SortedSet[SetElement] + + def complement_if_empty(self) -> Self: + raise NotImplementedError("I don't know how to do this yet.") + + def simplify(self) -> Self: + return self + + def new_empty_set(self) -> Self: + return Set([]) diff --git a/src/random_events/sigma_algebra.py b/src/random_events/sigma_algebra.py index b597e0a..e073513 100644 --- a/src/random_events/sigma_algebra.py +++ b/src/random_events/sigma_algebra.py @@ -163,16 +163,25 @@ def intersection_with_simple_set(self, other: AbstractSimpleSet) -> Self: [result.add_simple_set(simple_set.intersection_with(other)) for simple_set in self.simple_sets] return result + def intersection_with_simple_sets(self, other: SortedSet[AbstractSimpleSet]) -> Self: + """ + Form the intersection of this object with a set of simple sets. + + :param other: The set of simple sets + :return: The intersection of this set with the set of simple sets + """ + result = self.new_empty_set() + [result.simple_sets.update(self.intersection_with_simple_set(other_simple_set).simple_sets) + for other_simple_set in other] + return result + def intersection_with(self, other: Self) -> Self: """ Form the intersection of this object with another object. :param other: The other set :return: The intersection of this set with the other set """ - result = self.new_empty_set() - [result.simple_sets.update(self.intersection_with_simple_set(other_simple_set).simple_sets) - for other_simple_set in other.simple_sets] - return result + return self.intersection_with_simple_sets(other.simple_sets) def __and__(self, other): return self.intersection_with(other) @@ -187,12 +196,7 @@ def difference_with_simple_set(self, other: AbstractSimpleSet) -> Self: [result.simple_sets.update(simple_set.difference_with(other)) for simple_set in self.simple_sets] return result.make_disjoint() - def difference_with(self, other: Self) -> Self: - """ - Form the difference with another composite set. - :param other: The other set - :return: The difference of this set with the other set - """ + def difference_with_simple_sets(self, other: SortedSet[AbstractSimpleSet]) -> Self: # initialize the result result = self.new_empty_set() @@ -205,7 +209,7 @@ def difference_with(self, other: Self) -> Self: first_iteration = True # for every simple set in the other set - for other_simple_set in other.simple_sets: + for other_simple_set in other: # form the element wise difference difference_with_other_simple_set = own_simple_set.difference_with(other_simple_set) @@ -227,6 +231,14 @@ def difference_with(self, other: Self) -> Self: return result.make_disjoint() + def difference_with(self, other: Self) -> Self: + """ + Form the difference with another composite set. + :param other: The other set + :return: The difference of this set with the other set + """ + return self.difference_with_simple_sets(other.simple_sets) + def __sub__(self, other): return self.difference_with(other) @@ -234,16 +246,23 @@ def complement(self) -> Self: """ :return: The complement of this set """ + + if self.is_empty(): + return self.complement_if_empty() + result = self.new_empty_set() - first_iteration = True - for simple_set in self.simple_sets: - if first_iteration: - result = simple_set.complement() - first_iteration = False - else: - result = result.intersection_with(simple_set.complement()) + result.simple_sets = self.simple_sets[0].complement() + for simple_set in self.simple_sets[1:]: + result = result.intersection_with_simple_sets(simple_set.complement()) return result.make_disjoint() + @abstractmethod + def complement_if_empty(self) -> Self: + """ + :return: The complement of this if it is empty. + """ + raise NotImplementedError + def __invert__(self): return self.complement() diff --git a/test/test_set.py b/test/test_set.py new file mode 100644 index 0000000..7183d1b --- /dev/null +++ b/test/test_set.py @@ -0,0 +1,62 @@ +from random_events.set import SetElement, Set +import enum +import unittest + + +class TestEnum(SetElement): + EMPTY_SET = 0 + A = 1 + B = 2 + C = 4 + + +class SetElementTestCase(unittest.TestCase): + + def test_intersection_with(self): + a = TestEnum.A + b = TestEnum.B + + intersection_a_b = a.intersection_with(b) + self.assertEqual(intersection_a_b, TestEnum.EMPTY_SET) + self.assertEqual(a.intersection_with(TestEnum.A), a) + self.assertEqual(TestEnum.EMPTY_SET.intersection_with(TestEnum.A), TestEnum.EMPTY_SET) + + def test_complement(self): + a = TestEnum.A + complement = a.complement() + self.assertEqual(complement, {TestEnum.B, TestEnum.C}) + + def test_is_empty(self): + a = TestEnum.EMPTY_SET + b = TestEnum.B + self.assertTrue(a.is_empty()) + self.assertFalse(b.is_empty()) + + def test_contains(self): + a = TestEnum.A + self.assertTrue(a.contains(TestEnum.A)) + self.assertFalse(a.contains(TestEnum.B)) + self.assertFalse(a.contains(TestEnum.C)) + + +class SetTestCase(unittest.TestCase): + + def test_simplify(self): + a = TestEnum.A + b = TestEnum.B + c = TestEnum.C + s = Set([a, b, c, c]) + self.assertEqual(len(s.simple_sets), 3) + self.assertEqual(s.simplify(), s) + + def test_difference(self): + s = Set([TestEnum.A, TestEnum.B]) + s_ = Set([TestEnum.A]) + self.assertEqual(s.difference_with(s_), Set([TestEnum.B])) + + def test_complement(self): + s = Set([TestEnum.A, TestEnum.B]) + self.assertEqual(s.complement(), Set([TestEnum.C])) + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_variable.py b/test/test_variable.py new file mode 100644 index 0000000..f9afca4 --- /dev/null +++ b/test/test_variable.py @@ -0,0 +1,39 @@ +import unittest + +from random_events.better_variables import * +from random_events.interval import * + + +class TestEnum(SetElement): + EMPTY_SET = 0 + A = 1 + B = 2 + C = 4 + + +class ContinuousTestCase(unittest.TestCase): + + def test_creation(self): + x = Continuous("x") + self.assertEqual(x.name, "x") + self.assertEqual(x.domain, Interval([SimpleInterval(-float("inf"), float("inf"))])) + + +class IntegerTestCase(unittest.TestCase): + + def test_creation(self): + x = Integer("x") + self.assertEqual(x.name, "x") + self.assertEqual(x.domain, Interval([SimpleInterval(-float("inf"), float("inf"))])) + + +class SymbolicTestCase(unittest.TestCase): + + def test_creation(self): + x = Symbolic("x", TestEnum) + self.assertEqual(x.name, "x") + self.assertEqual(x.domain, Set([TestEnum.A, TestEnum.B, TestEnum.C])) + + +if __name__ == '__main__': + unittest.main()