From 691fbf5e85ce9eb248ded97d07e5c2c6413eb21d Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Fri, 16 Feb 2024 17:08:58 +0100 Subject: [PATCH] Add custom ArchitectureError --- src/metatensor/models/cli/eval_model.py | 26 +++++++++++++++------ src/metatensor/models/cli/train_model.py | 22 ++++++++++------- src/metatensor/models/utils/compute_loss.py | 14 +++++++---- src/metatensor/models/utils/errors.py | 20 ++++++++++++++++ tests/utils/test_errors.py | 11 +++++++++ 5 files changed, 73 insertions(+), 20 deletions(-) create mode 100644 src/metatensor/models/utils/errors.py create mode 100644 tests/utils/test_errors.py diff --git a/src/metatensor/models/cli/eval_model.py b/src/metatensor/models/cli/eval_model.py index 31ca2a481..bcbc3d8da 100644 --- a/src/metatensor/models/cli/eval_model.py +++ b/src/metatensor/models/cli/eval_model.py @@ -8,6 +8,7 @@ from ..utils.compute_loss import compute_model_loss from ..utils.data import collate_fn, read_structures, read_targets, write_predictions +from ..utils.errors import ArchitectureError from ..utils.extract_targets import get_outputs_dict from ..utils.info import finalize_aggregated_info, update_aggregated_info from ..utils.loss import TensorMapDictLoss @@ -90,7 +91,13 @@ def _eval_targets(model, dataset: Union[_BaseDataset, torch.utils.data.Subset]) finalized_info = finalize_aggregated_info(aggregated_info) energy_counter = 0 - for output in model.capabilities.outputs.values(): + + try: + outputs_capabilities = model.capabilities.outputs + except Exception as e: + raise ArchitectureError(e) from e + + for output in outputs_capabilities.values(): if output.quantity == "energy": energy_counter += 1 if energy_counter == 1: @@ -104,7 +111,7 @@ def _eval_targets(model, dataset: Union[_BaseDataset, torch.utils.data.Subset]) if key.endswith("_positions_gradients"): # check if this is a force target_name = key[: -len("_positions_gradients")] - if model.capabilities.outputs[target_name].quantity == "energy": + if 
outputs_capabilities[target_name].quantity == "energy": # if this is a force, replace the ugly name with "force" if only_one_energy: new_key = "force" @@ -113,9 +120,8 @@ def _eval_targets(model, dataset: Union[_BaseDataset, torch.utils.data.Subset]) elif key.endswith("_displacement_gradients"): # check if this is a virial/stress target_name = key[: -len("_displacement_gradients")] - if model.capabilities.outputs[target_name].quantity == "energy": - # if this is a virial/stress, - # replace the ugly name with "virial/stress" + if outputs_capabilities[target_name].quantity == "energy": + # if this is a virial/stress, replace the ugly name with "virial/stress" if only_one_energy: new_key = "virial/stress" else: @@ -139,7 +145,10 @@ def eval_model( """ logging.basicConfig(level=logging.INFO, format="%(message)s") logger.info("Setting up evaluation set.") - dtype = next(model.parameters()).dtype + try: + dtype = next(model.parameters()).dtype + except Exception as e: + raise ArchitectureError(e) from e options = expand_dataset_config(options) eval_structures = read_structures( @@ -154,5 +163,8 @@ def eval_model( _eval_targets(model, eval_dataset) # Predict structures - predictions = model(eval_structures, model.capabilities.outputs) + try: + predictions = model(eval_structures, model.capabilities.outputs) + except Exception as e: + raise ArchitectureError(e) from e write_predictions(output, predictions, eval_structures) diff --git a/src/metatensor/models/cli/train_model.py b/src/metatensor/models/cli/train_model.py index dc87a56f7..5b7a751f2 100644 --- a/src/metatensor/models/cli/train_model.py +++ b/src/metatensor/models/cli/train_model.py @@ -19,6 +19,7 @@ from .. 
import CONFIG_PATH from ..utils.data import get_all_species, read_structures, read_targets from ..utils.data.dataset import _train_test_random_split +from ..utils.errors import ArchitectureError from ..utils.model_io import save_model from ..utils.omegaconf import check_units, expand_dataset_config from .eval_model import _eval_targets @@ -279,15 +280,18 @@ def _train_model_hydra(options: DictConfig) -> None: ) logger.info("Calling architecture trainer") - model = architecture.train( - train_datasets=[train_dataset], - validation_datasets=[validation_dataset], - requested_capabilities=requested_capabilities, - hypers=OmegaConf.to_container(options["architecture"]), - continue_from=options["continue_from"], - output_dir=output_dir, - device_str=options["device"], - ) + try: + model = architecture.train( + train_datasets=[train_dataset], + validation_datasets=[validation_dataset], + requested_capabilities=requested_capabilities, + hypers=OmegaConf.to_container(options["architecture"]), + continue_from=options["continue_from"], + output_dir=output_dir, + device_str=options["device"], + ) + except Exception as e: + raise ArchitectureError(e) from e save_model(model, options["output_path"]) diff --git a/src/metatensor/models/utils/compute_loss.py b/src/metatensor/models/utils/compute_loss.py index 696f966a9..3fe5cf79a 100644 --- a/src/metatensor/models/utils/compute_loss.py +++ b/src/metatensor/models/utils/compute_loss.py @@ -4,6 +4,7 @@ from metatensor.torch import Labels, TensorBlock, TensorMap from metatensor.torch.atomistic import System +from .errors import ArchitectureError from .loss import TensorMapDictLoss from .output_gradient import compute_gradient @@ -24,12 +25,17 @@ def compute_model_loss( :returns: The loss as a scalar `torch.Tensor`. 
""" + try: + device = next(model.parameters()).device + outputs_capabilities = model.capabilities.outputs + except Exception as e: + raise ArchitectureError(e) from e + # Assert that all targets are within the model's capabilities: - if not set(targets.keys()).issubset(model.capabilities.outputs.keys()): + if not set(targets.keys()).issubset(outputs_capabilities.keys()): raise ValueError("Not all targets are within the model's capabilities.") - # Infer model device, move systems and targets to the same device: - device = next(model.parameters()).device + # Move systems and targets to the same device as the model: systems = [system.to(device=device) for system in systems] targets = {key: target.to(device=device) for key, target in targets.items()} @@ -39,7 +45,7 @@ def compute_model_loss( energy_targets_that_require_strain_gradients = [] for target_name in targets.keys(): # Check if the target is an energy: - if model.capabilities.outputs[target_name].quantity == "energy": + if outputs_capabilities[target_name].quantity == "energy": energy_targets.append(target_name) # Check if the energy requires gradients: if targets[target_name].block().has_gradient("positions"): diff --git a/src/metatensor/models/utils/errors.py b/src/metatensor/models/utils/errors.py new file mode 100644 index 000000000..131cf7c7d --- /dev/null +++ b/src/metatensor/models/utils/errors.py @@ -0,0 +1,20 @@ +class ArchitectureError(Exception): + """ + Exception raised for errors originating from architectures + + This exception should be raised when an error occurs within an architecture's + operation, indicating that the problem is not directly related to the + metatensor-models infrastructure but rather to the specific architecture being used. + + :param exception: The original exception that was caught, which led to raising this + custom exception. 
+ :type exception: Exception + + The message is the original message plus a note pointing to the architecture. + """ + + def __init__(self, exception): + super().__init__( + f"{exception}\n\nThis error originates from an architecture, and is likely " + "not a problem with metatensor-models." + ) diff --git a/tests/utils/test_errors.py b/tests/utils/test_errors.py new file mode 100644 index 000000000..82d31ae3d --- /dev/null +++ b/tests/utils/test_errors.py @@ -0,0 +1,11 @@ +import pytest + +from metatensor.models.utils.errors import ArchitectureError + + +def test_architecture_error(): + with pytest.raises(ArchitectureError, match="not a problem with metatensor-models"): + try: + raise ValueError("An example error from the architecture") + except Exception as e: + raise ArchitectureError(e)