diff --git a/pyproject.toml b/pyproject.toml index 1a98b0a..be96255 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dynamic = ["version"] dependencies = [ "numpy", "pandas", - "pydantic", + "pydantic>=2", ] # extras @@ -112,11 +112,10 @@ test = "pytest {args:tests}" matrix-name-format = "{variable}_{value}" [[tool.hatch.envs.all.matrix]] -pydantic_version = ["1","2"] +pydantic_version = ["2"] [tool.hatch.envs.all.overrides] matrix.pydantic_version.dependencies = [ - { value="pydantic<2", if = ["1"] }, { value="pydantic>=2", if = ["2"] } ] diff --git a/src/imodmodel/models.py b/src/imodmodel/models.py index 4f54c86..abda612 100644 --- a/src/imodmodel/models.py +++ b/src/imodmodel/models.py @@ -1,10 +1,9 @@ import os import warnings -from typing import Tuple, List, Optional, Union +from typing import List, Optional, Tuple, Union import numpy as np -from pydantic import BaseModel, validator -from pydantic.version import VERSION as PYDANTIC_VERSION +from pydantic import BaseModel, ConfigDict, field_validator class ID(BaseModel): @@ -20,10 +19,6 @@ class GeneralStorage(BaseModel): index: Union[float, int, Tuple[int, int], Tuple[int, int, int, int]] value: Union[float, int, Tuple[int, int], Tuple[int, int, int, int]] - if PYDANTIC_VERSION < '2.0': - class Config: - smart_union = True - class ModelHeader(BaseModel): """https://bio3d.colorado.edu/imod/doc/binspec.html""" @@ -55,7 +50,8 @@ class ModelHeader(BaseModel): beta: float gamma: float - @validator('name', pre=True) + @field_validator('name', mode="before") + @classmethod def decode_null_terminated_byte_string(cls, value: bytes): end = value.find(b'\x00') return value[:end].decode('utf-8') @@ -84,7 +80,8 @@ class ObjectHeader(BaseModel): meshsize: int surfsize: int - @validator('name', pre=True) + @field_validator('name', mode="before") + @classmethod def decode_null_terminated_byte_string(cls, value: bytes): end = value.find(b'\x00') return value[:end].decode('utf-8') @@ -104,8 +101,7 @@ class Contour(BaseModel): points: np.ndarray # pt extra: List[GeneralStorage] = [] - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) class MeshHeader(BaseModel): @@ -123,10 +119,9 @@ class Mesh(BaseModel): raw_indices: np.ndarray extra: List[GeneralStorage] = [] - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) - @validator('raw_indices') + @field_validator('raw_indices') @classmethod def validate_indices(cls, indices: np.ndarray): if indices.ndim > 1: @@ -140,7 +135,7 @@ def validate_indices(cls, indices: np.ndarray): warnings.warn(f'Unsupported mesh type: {i}') return indices - @validator('raw_vertices') + @field_validator('raw_vertices') @classmethod def validate_vertices(cls, vertices: np.ndarray): if vertices.ndim > 1: @@ -242,7 +237,7 @@ class ImodModel(BaseModel): id: ID header: ModelHeader objects: List[Object] - imat: Optional[IMAT] + imat: Optional[IMAT] = None extra: List[GeneralStorage] = [] @classmethod