diff --git a/requirements.txt b/requirements.txt index ee22720..97fc0eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ portion>=2.4.1 -pydantic>=2.4.2 diff --git a/src/random_events/__init__.py b/src/random_events/__init__.py index 99d2a6f..09964d6 100644 --- a/src/random_events/__init__.py +++ b/src/random_events/__init__.py @@ -1 +1 @@ -__version__ = '1.1.5' +__version__ = '1.2.5' diff --git a/src/random_events/variables.py b/src/random_events/variables.py index aad5d7b..4d07890 100644 --- a/src/random_events/variables.py +++ b/src/random_events/variables.py @@ -1,13 +1,11 @@ -import json -from typing import Any, Union, Iterable, Dict +from typing import Any, Iterable, Dict, Tuple import portion -import pydantic from . import utils -class Variable(pydantic.BaseModel): +class Variable: """ Abstract base class for all variables. """ @@ -17,19 +15,14 @@ class Variable(pydantic.BaseModel): The name of the variable. The name is used for comparison and hashing. """ - domain: Any = pydantic.Field(repr=False) + domain: Any """ The set of possible events of the variable. """ - type: str = pydantic.Field(repr=False, init_var=False, default=None) - """ - The type of the variable. This is used for de-serialization and set automatically in the constructor. - """ - def __init__(self, name: str, domain: Any): - super().__init__(name=name, domain=domain) - self.type = utils.get_full_class_name(self.__class__) + self.name = name + self.domain = domain def __lt__(self, other: "Variable") -> bool: """ @@ -46,6 +39,15 @@ def __gt__(self, other: "Variable") -> bool: def __hash__(self) -> int: return self.name.__hash__() + def __eq__(self, other): + return self.name == other.name and self.domain == other.domain + + def __str__(self): + return f"{self.__class__.__name__}({self.name}, {self.domain})" + + def __repr__(self): + return f"{self.__class__.__name__}({self.name})" + def encode(self, value: Any) -> Any: """ Encode an element of the domain to a representation that is usable for computations. @@ -82,8 +84,22 @@ def decode_many(self, elements: Iterable) -> Iterable[Any]: """ return elements - @staticmethod - def from_json(data: Dict[str, Any]) -> 'Variable': + def to_json(self) -> Dict[str, Any]: + return {"name": self.name, "type": utils.get_full_class_name(self.__class__), "domain": self.domain} + + @classmethod + def _from_json(cls, data: Dict[str, Any]) -> 'Variable': + """ + Create a variable from a json dict. + This method is called from the from_json method after the correct subclass is determined. + + :param data: The json dict + :return: The variable + """ + return cls(name=data["name"], domain=data["domain"]) + + @classmethod + def from_json(cls, data: Dict[str, Any]) -> 'Variable': """ Create the correct instanceof the subclass from a json dict. @@ -92,7 +108,7 @@ def from_json(data: Dict[str, Any]) -> 'Variable': """ for subclass in utils.recursive_subclasses(Variable): if utils.get_full_class_name(subclass) == data["type"]: - return subclass(**{key: value for key, value in data.items() if key != "type"}) + return subclass._from_json(data) raise ValueError("Unknown type for variable. Type is {}".format(data["type"])) @@ -102,37 +118,25 @@ class Continuous(Variable): Class for real valued random variables. """ - model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) - - domain: portion.Interval = pydantic.Field(portion.open(-portion.inf, portion.inf), repr=False) + domain: portion.Interval def __init__(self, name: str, domain: portion.Interval = portion.open(-portion.inf, portion.inf)): super().__init__(name=name, domain=domain) - @pydantic.field_serializer("domain") - def serialize_domain(self, interval: portion.Interval) -> str: - """ - Serialize the domain of this variable to a string. - :param interval: The domain - :return: A json string of it - """ - return json.dumps(portion.to_data(interval)) + def to_json(self) -> Dict[str, Any]: + return {"name": self.name, "type": utils.get_full_class_name(self.__class__), + "domain": portion.to_data(self.domain)} - @pydantic.field_validator("domain", mode="before") - def validate_domain(cls, interval: Union[portion.Interval, str]) -> portion.Interval: - if isinstance(interval, str): - return portion.from_data(json.loads(interval)) - elif isinstance(interval, portion.Interval): - return interval - else: - raise ValueError("Unknown type for domain. Type is {}".format(type(interval))) + @classmethod + def _from_json(cls, data: Dict[str, Any]) -> 'Variable': + return cls(name=data["name"], domain=portion.from_data(data["domain"])) class Discrete(Variable): """ Class for discrete countable random variables. """ - domain: tuple = pydantic.Field(repr=False) + domain: Tuple def __init__(self, name: str, domain: Iterable): super().__init__(name=name, domain=tuple(sorted(set(domain)))) diff --git a/test/test_variables.py b/test/test_variables.py index 9238c94..1c99268 100644 --- a/test/test_variables.py +++ b/test/test_variables.py @@ -1,4 +1,3 @@ -import json import unittest import portion @@ -63,9 +62,9 @@ def test_to_json(self): """ Test that the variables can be dumped to json. """ - self.assertTrue(self.symbol.model_dump_json()) - self.assertTrue(self.integer.model_dump_json()) - self.assertTrue(self.real.model_dump_json()) + self.assertTrue(self.symbol.to_json()) + self.assertTrue(self.integer.to_json()) + self.assertTrue(self.real.to_json()) def test_encode(self): """ @@ -83,20 +82,14 @@ def test_decode(self): self.assertEqual(self.symbol.decode(1), "b") self.assertEqual(self.real.decode(1.0), 1.0) - def test_type_setting(self): - self.assertEqual(self.real.type, "random_events.variables.Continuous") - self.assertEqual(self.integer.type, "random_events.variables.Integer") - self.assertEqual(self.symbol.type, "random_events.variables.Symbolic") - def test_polymorphic_serialization(self): - real = Variable.from_json(json.loads(self.real.model_dump_json())) + real = Variable.from_json(self.real.to_json()) self.assertEqual(real, self.real) - integer = Variable.from_json(json.loads(self.integer.model_dump_json())) - print(integer) + integer = Variable.from_json(self.integer.to_json()) self.assertEqual(integer, self.integer) - symbol = Variable.from_json(json.loads(self.symbol.model_dump_json())) + symbol = Variable.from_json(self.symbol.to_json()) self.assertEqual(symbol, self.symbol)