Skip to content

Commit

Permalink
Removed pydantic dependency.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Nov 13, 2023
1 parent b2d6591 commit 5345b22
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 50 deletions.
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
portion>=2.4.1
pydantic>=2.4.2
2 changes: 1 addition & 1 deletion src/random_events/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.1.5'
__version__ = '1.2.5'
74 changes: 39 additions & 35 deletions src/random_events/variables.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import json
from typing import Any, Union, Iterable, Dict
from typing import Any, Iterable, Dict, Tuple

import portion
import pydantic

from . import utils


class Variable(pydantic.BaseModel):
class Variable:
"""
Abstract base class for all variables.
"""
Expand All @@ -17,19 +15,14 @@ class Variable(pydantic.BaseModel):
The name of the variable. The name is used for comparison and hashing.
"""

domain: Any = pydantic.Field(repr=False)
domain: Any
"""
The set of possible events of the variable.
"""

type: str = pydantic.Field(repr=False, init_var=False, default=None)
"""
The type of the variable. This is used for de-serialization and set automatically in the constructor.
"""

def __init__(self, name: str, domain: Any):
super().__init__(name=name, domain=domain)
self.type = utils.get_full_class_name(self.__class__)
self.name = name
self.domain = domain

def __lt__(self, other: "Variable") -> bool:
"""
Expand All @@ -46,6 +39,15 @@ def __gt__(self, other: "Variable") -> bool:
def __hash__(self) -> int:
return self.name.__hash__()

def __eq__(self, other):
return self.name == other.name and self.domain == other.domain

def __str__(self):
return f"{self.__class__.__name__}({self.name}, {self.domain})"

def __repr__(self):
return f"{self.__class__.__name__}({self.name})"

def encode(self, value: Any) -> Any:
"""
Encode an element of the domain to a representation that is usable for computations.
Expand Down Expand Up @@ -82,8 +84,22 @@ def decode_many(self, elements: Iterable) -> Iterable[Any]:
"""
return elements

@staticmethod
def from_json(data: Dict[str, Any]) -> 'Variable':
def to_json(self) -> Dict[str, Any]:
return {"name": self.name, "type": utils.get_full_class_name(self.__class__), "domain": self.domain}

@classmethod
def _from_json(cls, data: Dict[str, Any]) -> 'Variable':
"""
Create a variable from a json dict.
This method is called from the from_json method after the correct subclass is determined.
:param data: The json dict
:return: The variable
"""
return cls(name=data["name"], domain=data["domain"])

@classmethod
def from_json(cls, data: Dict[str, Any]) -> 'Variable':
"""
Create the correct instanceof the subclass from a json dict.
Expand All @@ -92,7 +108,7 @@ def from_json(data: Dict[str, Any]) -> 'Variable':
"""
for subclass in utils.recursive_subclasses(Variable):
if utils.get_full_class_name(subclass) == data["type"]:
return subclass(**{key: value for key, value in data.items() if key != "type"})
return subclass._from_json(data)

raise ValueError("Unknown type for variable. Type is {}".format(data["type"]))

Expand All @@ -102,37 +118,25 @@ class Continuous(Variable):
Class for real valued random variables.
"""

model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

domain: portion.Interval = pydantic.Field(portion.open(-portion.inf, portion.inf), repr=False)
domain: portion.Interval

def __init__(self, name: str, domain: portion.Interval = portion.open(-portion.inf, portion.inf)):
super().__init__(name=name, domain=domain)

@pydantic.field_serializer("domain")
def serialize_domain(self, interval: portion.Interval) -> str:
"""
Serialize the domain of this variable to a string.
:param interval: The domain
:return: A json string of it
"""
return json.dumps(portion.to_data(interval))
def to_json(self) -> Dict[str, Any]:
return {"name": self.name, "type": utils.get_full_class_name(self.__class__),
"domain": portion.to_data(self.domain)}

@pydantic.field_validator("domain", mode="before")
def validate_domain(cls, interval: Union[portion.Interval, str]) -> portion.Interval:
if isinstance(interval, str):
return portion.from_data(json.loads(interval))
elif isinstance(interval, portion.Interval):
return interval
else:
raise ValueError("Unknown type for domain. Type is {}".format(type(interval)))
@classmethod
def _from_json(cls, data: Dict[str, Any]) -> 'Variable':
return cls(name=data["name"], domain=portion.from_data(data["domain"]))


class Discrete(Variable):
"""
Class for discrete countable random variables.
"""
domain: tuple = pydantic.Field(repr=False)
domain: Tuple

def __init__(self, name: str, domain: Iterable):
super().__init__(name=name, domain=tuple(sorted(set(domain))))
Expand Down
19 changes: 6 additions & 13 deletions test/test_variables.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import unittest

import portion
Expand Down Expand Up @@ -63,9 +62,9 @@ def test_to_json(self):
"""
Test that the variables can be dumped to json.
"""
self.assertTrue(self.symbol.model_dump_json())
self.assertTrue(self.integer.model_dump_json())
self.assertTrue(self.real.model_dump_json())
self.assertTrue(self.symbol.to_json())
self.assertTrue(self.integer.to_json())
self.assertTrue(self.real.to_json())

def test_encode(self):
"""
Expand All @@ -83,20 +82,14 @@ def test_decode(self):
self.assertEqual(self.symbol.decode(1), "b")
self.assertEqual(self.real.decode(1.0), 1.0)

def test_type_setting(self):
self.assertEqual(self.real.type, "random_events.variables.Continuous")
self.assertEqual(self.integer.type, "random_events.variables.Integer")
self.assertEqual(self.symbol.type, "random_events.variables.Symbolic")

def test_polymorphic_serialization(self):
real = Variable.from_json(json.loads(self.real.model_dump_json()))
real = Variable.from_json(self.real.to_json())
self.assertEqual(real, self.real)

integer = Variable.from_json(json.loads(self.integer.model_dump_json()))
print(integer)
integer = Variable.from_json(self.integer.to_json())
self.assertEqual(integer, self.integer)

symbol = Variable.from_json(json.loads(self.symbol.model_dump_json()))
symbol = Variable.from_json(self.symbol.to_json())
self.assertEqual(symbol, self.symbol)


Expand Down

0 comments on commit 5345b22

Please sign in to comment.