Skip to content

Commit

Permalink
Migrate to pydantic v2 API (#13)
Browse files Browse the repository at this point in the history
* make tests compatible with pytest >= 8.0.0

* migrate to pydantic v2

* remove redundant import
  • Loading branch information
uermel authored May 16, 2024
1 parent 2e752a5 commit f597bbc
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 19 deletions.
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dynamic = ["version"]
dependencies = [
"numpy",
"pandas",
"pydantic",
"pydantic>=2",
]

# extras
Expand Down Expand Up @@ -112,11 +112,10 @@ test = "pytest {args:tests}"
matrix-name-format = "{variable}_{value}"

[[tool.hatch.envs.all.matrix]]
pydantic_version = ["1","2"]
pydantic_version = ["2"]

[tool.hatch.envs.all.overrides]
matrix.pydantic_version.dependencies = [
{ value="pydantic<2", if = ["1"] },
{ value="pydantic>=2", if = ["2"] }
]

Expand Down
27 changes: 11 additions & 16 deletions src/imodmodel/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os
import warnings
from typing import Tuple, List, Optional, Union
from typing import List, Optional, Tuple, Union

import numpy as np
from pydantic import BaseModel, validator
from pydantic.version import VERSION as PYDANTIC_VERSION
from pydantic import BaseModel, ConfigDict, field_validator


class ID(BaseModel):
Expand All @@ -20,10 +19,6 @@ class GeneralStorage(BaseModel):
index: Union[float, int, Tuple[int, int], Tuple[int, int, int, int]]
value: Union[float, int, Tuple[int, int], Tuple[int, int, int, int]]

if PYDANTIC_VERSION < '2.0':
class Config:
smart_union = True


class ModelHeader(BaseModel):
"""https://bio3d.colorado.edu/imod/doc/binspec.html"""
Expand Down Expand Up @@ -55,7 +50,8 @@ class ModelHeader(BaseModel):
beta: float
gamma: float

@validator('name', pre=True)
@field_validator('name', mode="before")
@classmethod
def decode_null_terminated_byte_string(cls, value: bytes):
end = value.find(b'\x00')
return value[:end].decode('utf-8')
Expand Down Expand Up @@ -84,7 +80,8 @@ class ObjectHeader(BaseModel):
meshsize: int
surfsize: int

@validator('name', pre=True)
@field_validator('name', mode="before")
@classmethod
def decode_null_terminated_byte_string(cls, value: bytes):
end = value.find(b'\x00')
return value[:end].decode('utf-8')
Expand All @@ -104,8 +101,7 @@ class Contour(BaseModel):
points: np.ndarray # pt
extra: List[GeneralStorage] = []

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)


class MeshHeader(BaseModel):
Expand All @@ -123,10 +119,9 @@ class Mesh(BaseModel):
raw_indices: np.ndarray
extra: List[GeneralStorage] = []

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)

@validator('raw_indices')
@field_validator('raw_indices')
@classmethod
def validate_indices(cls, indices: np.ndarray):
if indices.ndim > 1:
Expand All @@ -140,7 +135,7 @@ def validate_indices(cls, indices: np.ndarray):
warnings.warn(f'Unsupported mesh type: {i}')
return indices

@validator('raw_vertices')
@field_validator('raw_vertices')
@classmethod
def validate_vertices(cls, vertices: np.ndarray):
if vertices.ndim > 1:
Expand Down Expand Up @@ -242,7 +237,7 @@ class ImodModel(BaseModel):
id: ID
header: ModelHeader
objects: List[Object]
imat: Optional[IMAT]
imat: Optional[IMAT] = None
extra: List[GeneralStorage] = []

@classmethod
Expand Down

0 comments on commit f597bbc

Please sign in to comment.