Skip to content

Commit

Permalink
feat(config): config validation via pydantic
Browse files Browse the repository at this point in the history
Fail early, fail gracefully.

Implements config validation via pydantic for CosipyConfig
(main_config), SlurmConfig (slurm_config) and Constants
(constants_config, still needs more detailed model description).

NOTE: Other configs are still missing.

This provides config validation similar to dataclasses but more
elaborate. The validation at runtime allows us to check early on
if any config is missing or invalid and to stop the model with informative
error messages.

Furthermore, pydantic provides us with type information, which makes it easier
to extend the code as the project gets larger.

While implementing this, I already found that the previous `Constants`
class was missing some entries that are in the `constants.toml` file.

Flattening the config structure (as done before) is now done before
model validation by use of pydantic's `@model_validator` decorator.

NOTE: This is supposed to be a first draft and further details should be
implemented.

IMPORTANT: Look for NOTE, FIXME, and HACK comments in the code.
  • Loading branch information
benatouba committed Nov 26, 2024
1 parent 67c0550 commit 1c4e4dc
Show file tree
Hide file tree
Showing 30 changed files with 769 additions and 608 deletions.
1 change: 1 addition & 0 deletions conda_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ nco
cdo
cartopy
vtk
pydantic
richdem
coveralls
codecov
Expand Down
195 changes: 130 additions & 65 deletions cosipy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
import sys
from importlib.metadata import entry_points
from pathlib import Path
from typing import Annotated, Literal, Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic.types import StringConstraints
from rich import print # noqa: A004
from typing_extensions import Self

if sys.version_info >= (3, 11):
import tomllib
Expand Down Expand Up @@ -95,6 +101,74 @@ def get_user_arguments() -> argparse.Namespace:
return arguments


# Timestamp string shaped like ``YYYY-MM-DDTHH:MM``; surrounding whitespace is
# stripped (``strip_whitespace=True``).
# NOTE(review): the pattern carries no ``^``/``$`` anchors and only bounds each
# digit position (month "19" or hour "29" fit it) -- confirm that this
# leniency, and acceptance of longer strings containing a match, is intended.
DatetimeStr = Annotated[
str,
StringConstraints(
strip_whitespace=True, pattern=r"\d{4}-[01]\d-[0-3]\dT[0-2]\d:[0-5]\d"
),
]


class CosipyConfigModel(BaseModel):
    """Validated COSIPY main configuration.

    Every attribute is a flat configuration key. Field constraints are
    checked first; ``validate_output_variables`` then reconciles settings
    that depend on each other (WRF dimension names, full-field output and
    the worker count).
    """

    model_config = ConfigDict(from_attributes=True)

    # --- Simulation period and input/output locations ---
    time_start: DatetimeStr = Field(
        description="Start time of the simulation in ISO format"
    )
    time_end: DatetimeStr = Field(
        description="End time of the simulation in ISO format"
    )
    data_path: Path = Field(description="Path to the data directory")
    input_netcdf: Path = Field(description="Input NetCDF file path")
    output_prefix: str = Field(description="Prefix for output files")
    restart: bool = Field(description="Restart flag")

    # --- Stake evaluation ---
    stake_evaluation: bool = Field(description="Flag for stake data evaluation")
    stakes_loc_file: Path = Field(description="Path to stake location file")
    stakes_data_file: Path = Field(description="Path to stake data file")
    eval_method: Literal["rmse"] = Field(
        "rmse", description="Evaluation method for simulations"
    )
    obs_type: Literal["mb", "snowheight"] = Field(
        description="Type of stake data used"
    )

    # --- Dimensions / WRF coupling ---
    WRF: bool = Field(description="Flag for WRF input")
    WRF_X_CSPY: bool = Field(description="Interactive simulation with WRF flag")
    northing: str = Field(description="Name of northing dimension")
    easting: str = Field(description="Name of easting dimension")

    # --- Output compression and parallelisation ---
    compression_level: int = Field(
        ge=0, le=9, description="Output NetCDF compression level"
    )
    slurm_use: bool = Field(description="Use SLURM flag")
    workers: Optional[int] = Field(
        default=None,
        ge=0,
        description=(
            "Setting is only used if slurm_use is False. "
            "Number of workers (cores), with 0 all available cores are used."
        ),
    )
    local_port: int = Field(default=8786, gt=0, description="Port for local cluster")

    # --- Output and forcing switches ---
    full_field: bool = Field(description="Flag for writing full fields to file")
    force_use_TP: bool = Field(..., description="Total precipitation flag")
    force_use_N: bool = Field(..., description="Cloud cover fraction flag")

    # --- Tiling ---
    tile: bool = Field(description="Flag for tiling")
    xstart: int = Field(ge=0, description="Start x index")
    xend: int = Field(ge=0, description="End x index")
    ystart: int = Field(ge=0, description="Start y index")
    yend: int = Field(ge=0, description="End y index")

    # --- Output variable selections ---
    output_atm: str = Field(description="Atmospheric output variables")
    output_internal: str = Field(description="Internal output variables")
    output_full: str = Field(description="Full output variables")

    @model_validator(mode="after")
    def validate_output_variables(self) -> Self:
        """Reconcile interdependent settings after field validation.

        Despite its name, this hook does not inspect the ``output_*`` fields.

        Returns:
            The same model instance, adjusted in place.
        """
        # WRF input uses fixed dimension names, overriding configured ones.
        if self.WRF:
            self.northing = "south_north"
            self.easting = "west_east"
        # Interactive simulations with WRF always write full fields.
        if self.WRF_X_CSPY:
            self.full_field = True
        # TOML doesn't support null values, so 0 stands in for "not set".
        if self.workers == 0:
            self.workers = None
        return self


def get_help():
"""Print help for commands."""
parser = set_parser()
Expand Down Expand Up @@ -135,7 +209,7 @@ class TomlLoader(object):
"""Load and parse configuration files."""

@staticmethod
def get_raw_toml(file_path: str = "./config.toml") -> dict:
def get_raw_toml(file_path: Path = default_path) -> dict:
"""Open and load .toml configuration file.
Args:
Expand All @@ -144,21 +218,20 @@ def get_raw_toml(file_path: str = "./config.toml") -> dict:
Returns:
Loaded .toml data.
"""
with open(file_path, "rb") as f:
raw_config = tomllib.load(f)

return raw_config
with file_path.open("rb") as f:
return tomllib.load(f)

@classmethod
def set_config_values(cls, config_table: dict):
@staticmethod
def flatten(config_table: dict[str, dict]) -> dict:
"""Overwrite attributes with configuration data.
Args:
config_table: Loaded .toml data.
"""
for _, table in config_table.items():
for k, v in table.items():
setattr(cls, k, v)
flat_dict = {}
for table in config_table.values():
flat_dict = {**flat_dict, **table}
return flat_dict


class Config(TomlLoader):
Expand All @@ -168,37 +241,45 @@ class Config(TomlLoader):
.toml file.
"""

def __init__(self):
self.args = get_user_arguments()
self.load(self.args.config_path)

@classmethod
def load(cls, path: str = "./config.toml"):
raw_toml = cls.get_raw_toml(path)
parsed_toml = cls.set_correct_config(raw_toml)
cls.set_config_values(parsed_toml)

@classmethod
def set_correct_config(cls, config_table: dict) -> dict:
"""Adjust invalid or mutually exclusive configuration values.
def __init__(self, path: Path = default_path) -> None:
raw_toml = self.get_raw_toml(path)
self.raw_toml = self.flatten(raw_toml)

Args:
config_table: Loaded .toml data.
def validate(self) -> CosipyConfigModel:
"""Validate configuration using Pydantic class.
Returns:
Adjusted .toml data.
CosipyConfigModel: Validated configuration.
"""
# WRF Compatibility
if config_table["DIMENSIONS"]["WRF"]:
config_table["DIMENSIONS"]["northing"] = "south_north"
config_table["DIMENSIONS"]["easting"] = "west_east"
if config_table["DIMENSIONS"]["WRF_X_CSPY"]:
config_table["FULL_FIELDS"]["full_field"] = True
# TOML doesn't support null values
if config_table["PARALLELIZATION"]["workers"] == 0:
config_table["PARALLELIZATION"]["workers"] = None
return CosipyConfigModel(**self.raw_toml)


# String constrained to start with ``#!``; surrounding whitespace is stripped
# (``strip_whitespace=True``).
ShebangStr = Annotated[str, StringConstraints(strip_whitespace=True, pattern=r"^#!")]


class SlurmConfigModel(BaseModel):
    """Validated Slurm configuration.

    Every attribute is a flat configuration key. After field validation,
    ``validate_output_variables`` applies the optional per-process memory
    override to ``memory``.
    """

    account: str = Field(description="Slurm account/group")
    name: str = Field(description="Equivalent to Slurm parameter `--job-name`")
    queue: str = Field(description="Queue name")
    slurm_parameters: list[str] = Field(description="Additional Slurm parameters")
    shebang: ShebangStr = Field(description="Shebang string")
    local_directory: Path = Field(description="Local directory")
    port: int = Field(description="Network port number")
    cores: int = Field(description="One grid point per core")
    nodes: int = Field(description="Grid points submitted in one sbatch script")
    processes: int = Field(description="Number of processes")
    memory: str = Field(description="Memory per process")
    # TOML has no null value, so a disabled override arrives as 0 (or is left
    # out entirely). Both must validate; the validator below skips falsy values.
    memory_per_process: Optional[int] = Field(
        default=None,
        ge=0,
        description=(
            "Memory per process in GB; when set to a non-zero value it "
            "overrides `memory`, 0 or unset disables the override"
        ),
    )

    @model_validator(mode="after")
    def validate_output_variables(self) -> Self:
        """Apply the per-process memory override after field validation.

        Despite its name, this hook only touches ``memory``.

        Returns:
            The same model instance, adjusted in place.
        """
        if self.memory_per_process:
            # The override scales with the number of cores.
            memory = self.memory_per_process * self.cores
            self.memory = f"{memory}GB"

        return self


class SlurmConfig(TomlLoader):
Expand All @@ -222,43 +303,27 @@ class SlurmConfig(TomlLoader):
slurm_parameters (List[str]): Additional Slurm parameters.
"""

def __init__(self):
self.args = get_user_arguments()
self.load(self.args.slurm_path)

@classmethod
def load(cls, path: str = "./slurm_config.toml"):
raw_toml = cls.get_raw_toml(path)
parsed_toml = cls.set_correct_config(raw_toml)
cls.set_config_values(parsed_toml)
def __init__(self, path: Path = default_slurm_path) -> None:
raw_toml = self.get_raw_toml(path)
self.raw_toml = self.flatten(raw_toml)

@classmethod
def set_correct_config(cls, config_table: dict) -> dict:
"""Adjust invalid or mutually exclusive configuration values.
Args:
config_table: Loaded .toml data.
def validate(self) -> SlurmConfigModel:
"""Validate configuration using Pydantic class.
Returns:
Adjusted .toml data.
CosipyConfigModel: Validated configuration.
"""
if config_table["OVERRIDES"]["memory_per_process"]:
memory = (
config_table["OVERRIDES"]["memory_per_process"]
* config_table["MEMORY"]["cores"]
)
config_table["MEMORY"]["memory"] = f"{memory}GB"

return config_table
return SlurmConfigModel(**self.raw_toml)


def main():
cfg = Config()
if cfg.slurm_use:
SlurmConfig()
def main() -> tuple[CosipyConfigModel, Optional[SlurmConfigModel]]:
    """Load and validate the configuration files named on the command line.

    Returns:
        The validated main configuration, followed by the validated Slurm
        configuration when ``slurm_use`` is enabled and ``None`` otherwise.
    """
    user_args = get_user_arguments()
    main_cfg = Config(user_args.config_path).validate()
    # The Slurm file is only read when the main configuration asks for it.
    if not main_cfg.slurm_use:
        return main_cfg, None
    return main_cfg, SlurmConfig(user_args.slurm_path).validate()


if __name__ == "__main__":
main()
else:
main()
main_config, slurm_config = main()
Loading

0 comments on commit 1c4e4dc

Please sign in to comment.