Skip to content

Commit

Permalink
Add custom ArchitectureError
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Feb 16, 2024
1 parent 56f2a64 commit 691fbf5
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 20 deletions.
26 changes: 19 additions & 7 deletions src/metatensor/models/cli/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ..utils.compute_loss import compute_model_loss
from ..utils.data import collate_fn, read_structures, read_targets, write_predictions
from ..utils.errors import ArchitectureError
from ..utils.extract_targets import get_outputs_dict
from ..utils.info import finalize_aggregated_info, update_aggregated_info
from ..utils.loss import TensorMapDictLoss
Expand Down Expand Up @@ -90,7 +91,13 @@ def _eval_targets(model, dataset: Union[_BaseDataset, torch.utils.data.Subset])
finalized_info = finalize_aggregated_info(aggregated_info)

energy_counter = 0
for output in model.capabilities.outputs.values():

try:
    outputs_capabilities = model.capabilities.outputs
except Exception as e:
    # Accessing the capabilities runs architecture-provided code: re-raise any
    # failure as an ArchitectureError. The exception must actually be raised;
    # merely instantiating it would silently swallow the error and leave
    # ``outputs_capabilities`` undefined for the code below.
    raise ArchitectureError(e) from e

for output in outputs_capabilities.values():
if output.quantity == "energy":
energy_counter += 1
if energy_counter == 1:
Expand All @@ -104,7 +111,7 @@ def _eval_targets(model, dataset: Union[_BaseDataset, torch.utils.data.Subset])
if key.endswith("_positions_gradients"):
# check if this is a force
target_name = key[: -len("_positions_gradients")]
if model.capabilities.outputs[target_name].quantity == "energy":
if outputs_capabilities[target_name].quantity == "energy":
# if this is a force, replace the ugly name with "force"
if only_one_energy:
new_key = "force"
Expand All @@ -113,9 +120,8 @@ def _eval_targets(model, dataset: Union[_BaseDataset, torch.utils.data.Subset])
elif key.endswith("_displacement_gradients"):
# check if this is a virial/stress
target_name = key[: -len("_displacement_gradients")]
if model.capabilities.outputs[target_name].quantity == "energy":
# if this is a virial/stress,
# replace the ugly name with "virial/stress"
if outputs_capabilities[target_name].quantity == "energy":
# if this is a virial/stress, replace the ugly name with "virial/stress"
if only_one_energy:
new_key = "virial/stress"
else:
Expand All @@ -139,7 +145,10 @@ def eval_model(
"""
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger.info("Setting up evaluation set.")
dtype = next(model.parameters()).dtype
try:
    dtype = next(model.parameters()).dtype
except Exception as e:
    # The model comes from an architecture: surface any failure while
    # inspecting its parameters as an ArchitectureError instead of swallowing
    # it (which would leave ``dtype`` undefined).
    raise ArchitectureError(e) from e

options = expand_dataset_config(options)
eval_structures = read_structures(
Expand All @@ -154,5 +163,8 @@ def eval_model(
_eval_targets(model, eval_dataset)

# Predict structures
predictions = model(eval_structures, model.capabilities.outputs)
try:
    predictions = model(eval_structures, model.capabilities.outputs)
except Exception as e:
    # Errors during the forward pass originate from the architecture. Raise
    # them as ArchitectureError; without the ``raise`` the failure would be
    # ignored and ``predictions`` would be undefined when written out below.
    raise ArchitectureError(e) from e
write_predictions(output, predictions, eval_structures)
22 changes: 13 additions & 9 deletions src/metatensor/models/cli/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .. import CONFIG_PATH
from ..utils.data import get_all_species, read_structures, read_targets
from ..utils.data.dataset import _train_test_random_split
from ..utils.errors import ArchitectureError
from ..utils.model_io import save_model
from ..utils.omegaconf import check_units, expand_dataset_config
from .eval_model import _eval_targets
Expand Down Expand Up @@ -279,15 +280,18 @@ def _train_model_hydra(options: DictConfig) -> None:
)

logger.info("Calling architecture trainer")
model = architecture.train(
train_datasets=[train_dataset],
validation_datasets=[validation_dataset],
requested_capabilities=requested_capabilities,
hypers=OmegaConf.to_container(options["architecture"]),
continue_from=options["continue_from"],
output_dir=output_dir,
device_str=options["device"],
)
try:
    model = architecture.train(
        train_datasets=[train_dataset],
        validation_datasets=[validation_dataset],
        requested_capabilities=requested_capabilities,
        hypers=OmegaConf.to_container(options["architecture"]),
        continue_from=options["continue_from"],
        output_dir=output_dir,
        device_str=options["device"],
    )
except Exception as e:
    # Training is implemented by the architecture: re-raise any failure as an
    # ArchitectureError. Only constructing the exception would discard the
    # error and leave ``model`` undefined for ``save_model`` below.
    raise ArchitectureError(e) from e

save_model(model, options["output_path"])

Expand Down
14 changes: 10 additions & 4 deletions src/metatensor/models/utils/compute_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import System

from .errors import ArchitectureError
from .loss import TensorMapDictLoss
from .output_gradient import compute_gradient

Expand All @@ -24,12 +25,17 @@ def compute_model_loss(
:returns: The loss as a scalar `torch.Tensor`.
"""
try:
    device = next(model.parameters()).device
    outputs_capabilities = model.capabilities.outputs
except Exception as e:
    # Both lookups touch architecture-provided objects: report failures as an
    # ArchitectureError. The exception has to be raised, otherwise ``device``
    # and ``outputs_capabilities`` would be undefined in the code that follows.
    raise ArchitectureError(e) from e

# Assert that all targets are within the model's capabilities:
if not set(targets.keys()).issubset(model.capabilities.outputs.keys()):
if not set(targets.keys()).issubset(outputs_capabilities.keys()):
raise ValueError("Not all targets are within the model's capabilities.")

# Infer model device, move systems and targets to the same device:
device = next(model.parameters()).device
# Move systems and targets to the same device as the model:
systems = [system.to(device=device) for system in systems]
targets = {key: target.to(device=device) for key, target in targets.items()}

Expand All @@ -39,7 +45,7 @@ def compute_model_loss(
energy_targets_that_require_strain_gradients = []
for target_name in targets.keys():
# Check if the target is an energy:
if model.capabilities.outputs[target_name].quantity == "energy":
if outputs_capabilities[target_name].quantity == "energy":
energy_targets.append(target_name)
# Check if the energy requires gradients:
if targets[target_name].block().has_gradient("positions"):
Expand Down
20 changes: 20 additions & 0 deletions src/metatensor/models/utils/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
class ArchitectureError(Exception):
    """
    Exception raised for errors originating from architectures.

    This exception should be raised when an error occurs within an architecture's
    operation, indicating that the problem is not directly related to the
    metatensor-models infrastructure but rather to the specific architecture being used.

    The exception message consists of the message of the original exception,
    followed by a note emphasizing that the error likely originates from an
    architecture.

    :param exception: The original exception that was caught, which led to raising
        this custom exception.
    """

    def __init__(self, exception: Exception):
        # Only the string form of ``exception`` is used, to build the message.
        super().__init__(
            f"{exception}\n\nThis error originates from an architecture, and is likely "
            "not a problem with metatensor-models."
        )
11 changes: 11 additions & 0 deletions tests/utils/test_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest

from metatensor.models.utils.errors import ArchitectureError


def test_architecture_erro():
    """Wrapping a caught exception in ArchitectureError appends the explanatory note."""
    expected_note = "not a problem with metatensor-models"
    original_error = ValueError("An example error from the architecture")

    with pytest.raises(ArchitectureError, match=expected_note):
        try:
            raise original_error
        except Exception as caught:
            raise ArchitectureError(caught)

0 comments on commit 691fbf5

Please sign in to comment.