diff --git a/src/random_events/__init__.py b/src/random_events/__init__.py index bc50bee..99d2a6f 100644 --- a/src/random_events/__init__.py +++ b/src/random_events/__init__.py @@ -1 +1 @@ -__version__ = '1.1.4' +__version__ = '1.1.5' diff --git a/src/random_events/utils.py b/src/random_events/utils.py new file mode 100644 index 0000000..0c08258 --- /dev/null +++ b/src/random_events/utils.py @@ -0,0 +1,16 @@ +def get_full_class_name(cls): + """ + Returns the full name of a class, including the module name. + + :param cls: The class. + :return: The full name of the class + """ + return cls.__module__ + "." + cls.__name__ + + +def recursive_subclasses(cls): + """ + :param cls: The class. + :return: A list of the classes subclasses. + """ + return cls.__subclasses__() + [g for s in cls.__subclasses__() for g in recursive_subclasses(s)] diff --git a/src/random_events/variables.py b/src/random_events/variables.py index 9cb474c..aad5d7b 100644 --- a/src/random_events/variables.py +++ b/src/random_events/variables.py @@ -1,9 +1,11 @@ import json -from typing import Any, Union, Iterable +from typing import Any, Union, Iterable, Dict import portion import pydantic +from . import utils + class Variable(pydantic.BaseModel): """ @@ -20,8 +22,14 @@ class Variable(pydantic.BaseModel): The set of possible events of the variable. """ + type: str = pydantic.Field(repr=False, init_var=False, default=None) + """ + The type of the variable. This is used for de-serialization and set automatically in the constructor. + """ + def __init__(self, name: str, domain: Any): super().__init__(name=name, domain=domain) + self.type = utils.get_full_class_name(self.__class__) def __lt__(self, other: "Variable") -> bool: """ @@ -74,6 +82,20 @@ def decode_many(self, elements: Iterable) -> Iterable[Any]: """ return elements + @staticmethod + def from_json(data: Dict[str, Any]) -> 'Variable': + """ + Create the correct instanceof the subclass from a json dict. + + :param data: The json dict + :return: The correct instance of the subclass + """ + for subclass in utils.recursive_subclasses(Variable): + if utils.get_full_class_name(subclass) == data["type"]: + return subclass(**{key: value for key, value in data.items() if key != "type"}) + + raise ValueError("Unknown type for variable. Type is {}".format(data["type"])) + class Continuous(Variable): """ diff --git a/test/test_variables.py b/test/test_variables.py index 994c0fb..9238c94 100644 --- a/test/test_variables.py +++ b/test/test_variables.py @@ -1,8 +1,9 @@ +import json import unittest import portion -from random_events.variables import Integer, Symbolic, Continuous +from random_events.variables import Integer, Symbolic, Continuous, Variable class VariablesTestCase(unittest.TestCase): @@ -66,19 +67,6 @@ def test_to_json(self): self.assertTrue(self.integer.model_dump_json()) self.assertTrue(self.real.model_dump_json()) - def test_from_json(self): - """ - Test that the variables can be loaded from json. - """ - real = Continuous.model_validate_json(self.real.model_dump_json()) - self.assertEqual(real, self.real) - - integer = Integer.model_validate_json(self.integer.model_dump_json()) - self.assertEqual(integer, self.integer) - - symbol = Symbolic.model_validate_json(self.symbol.model_dump_json()) - self.assertEqual(symbol, self.symbol) - def test_encode(self): """ Test that the variables can be encoded. @@ -95,6 +83,22 @@ def test_decode(self): self.assertEqual(self.symbol.decode(1), "b") self.assertEqual(self.real.decode(1.0), 1.0) + def test_type_setting(self): + self.assertEqual(self.real.type, "random_events.variables.Continuous") + self.assertEqual(self.integer.type, "random_events.variables.Integer") + self.assertEqual(self.symbol.type, "random_events.variables.Symbolic") + + def test_polymorphic_serialization(self): + real = Variable.from_json(json.loads(self.real.model_dump_json())) + self.assertEqual(real, self.real) + + integer = Variable.from_json(json.loads(self.integer.model_dump_json())) + print(integer) + self.assertEqual(integer, self.integer) + + symbol = Variable.from_json(json.loads(self.symbol.model_dump_json())) + self.assertEqual(symbol, self.symbol) + if __name__ == '__main__': unittest.main()