Skip to content

Commit

Permalink
Add pydantic base class.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 720897019
  • Loading branch information
sbodenstein authored and Torax team committed Jan 29, 2025
1 parent 75db63c commit a6b4b48
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 1 deletion.
41 changes: 40 additions & 1 deletion torax/config/pydantic_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@

"""Pydantic utilities and base classes."""

from typing import Annotated, TypeAlias
from collections.abc import Mapping
from typing import Annotated, Any, TypeAlias
import numpy as np
import pydantic
from typing_extensions import Self

DataTypes: TypeAlias = float | int | bool

Expand Down Expand Up @@ -65,3 +67,40 @@ def _numpy_array_is_rank_1(x: np.ndarray) -> np.ndarray:
NumpyArray1D = Annotated[
NumpyArray, pydantic.AfterValidator(_numpy_array_is_rank_1)
]


class Base(pydantic.BaseModel):
"""Base config class. Any custom config classes should inherit from this.
See https://docs.pydantic.dev/latest/ for documentation on pydantic.
"""

model_config = pydantic.ConfigDict(
frozen=False,
# Do not allow attributes not defined in pydantic model.
extra='forbid',
# Re-run validation if the model is updated.
validate_assignment=True,
arbitrary_types_allowed=True,
)

@classmethod
def from_dict(cls: type[Self], cfg: Mapping[str, Any]) -> Self:
return cls.model_validate(cfg)

def to_dict(self) -> dict[str, Any]:
return self.model_dump()


class BaseFrozen(Base):
"""Base config with frozen fields.
See https://docs.pydantic.dev/latest/ for documentation on pydantic.
"""

model_config = pydantic.ConfigDict(
frozen=True,
# Do not allow attributes not defined in pydantic model.
extra='forbid',
arbitrary_types_allowed=True,
)
37 changes: 37 additions & 0 deletions torax/config/tests/pydantic_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Unit tests for the `torax.config.pydantic_base` module."""

import functools
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
Expand Down Expand Up @@ -68,6 +69,42 @@ def test_1d_array(self):
with self.assertRaises(ValueError):
array.validate_python(np.array([[1.0, 2.0], [3.0, 4.0]]))

def test_pydantic_base_frozen(self):

class TestModel(pydantic_base.BaseFrozen):
x: float
y: float

m = TestModel(y=4.0, x=2.0)

with self.subTest('frozen_model_cannot_be_updated'):
with self.assertRaises(ValueError):
m.x = 2.0

def test_pydantic_base(self):

class Test(pydantic_base.Base, validate_assignment=True):
name: str

@functools.cached_property
def computed(self):
return self.name + '_test' # pytype: disable=attribute-error

@pydantic.model_validator(mode='after')
def validate(self):
if hasattr(self, 'computed'):
del self.computed
return self

m = Test(name='test_string')
self.assertEqual(m.computed, 'test_string_test')

with self.subTest('field_is_mutable'):
m.name = 'new_test_string'

with self.subTest('after_model_validator_is_called_on_update'):
self.assertEqual(m.computed, 'new_test_string_test')


if __name__ == '__main__':
absltest.main()

0 comments on commit a6b4b48

Please sign in to comment.