diff --git a/docs/src/dev-docs/utils/index.rst b/docs/src/dev-docs/utils/index.rst index 03d68cf13..83e486ad3 100644 --- a/docs/src/dev-docs/utils/index.rst +++ b/docs/src/dev-docs/utils/index.rst @@ -21,3 +21,4 @@ This is the API for the ``utils`` module of ``metatensor-models``. neighbor_lists omegaconf output_gradient + per_atom diff --git a/docs/src/dev-docs/utils/per_atom.rst b/docs/src/dev-docs/utils/per_atom.rst new file mode 100644 index 000000000..f9f0d8c49 --- /dev/null +++ b/docs/src/dev-docs/utils/per_atom.rst @@ -0,0 +1,7 @@ +Averaging predictions per atom +############################## + +.. automodule:: metatensor.models.utils.per_atom + :members: + :undoc-members: + :show-inheritance: diff --git a/src/metatensor/models/cli/eval.py b/src/metatensor/models/cli/eval.py index 691cf69f1..8aaafd45a 100644 --- a/src/metatensor/models/cli/eval.py +++ b/src/metatensor/models/cli/eval.py @@ -23,6 +23,7 @@ from ..utils.metrics import RMSEAccumulator from ..utils.neighbor_lists import get_system_with_neighbor_lists from ..utils.omegaconf import expand_dataset_config +from ..utils.per_atom import average_predictions_and_targets_by_num_atoms from .formatter import CustomHelpFormatter @@ -159,11 +160,19 @@ def _eval_targets( # Evaluate the model for batch in dataloader: - systems, targets = batch + systems, batch_targets = batch systems = [system.to(device=device) for system in systems] - targets = {key: value.to(device=device) for key, value in targets.items()} + batch_targets = { + key: value.to(device=device) for key, value in batch_targets.items() + } batch_predictions = evaluate_model(model, systems, options, is_training=False) - rmse_accumulator.update(batch_predictions, targets) + batch_predictions, batch_targets = average_predictions_and_targets_by_num_atoms( + predictions=batch_predictions, + targets=batch_targets, + systems=systems, + per_structure_targets=[], + ) + rmse_accumulator.update(batch_predictions, batch_targets) if return_predictions: all_predictions.append(batch_predictions) diff --git a/src/metatensor/models/experimental/alchemical_model/train.py b/src/metatensor/models/experimental/alchemical_model/train.py index 36d659c1a..dbc4b3e1a 100644 --- a/src/metatensor/models/experimental/alchemical_model/train.py +++ b/src/metatensor/models/experimental/alchemical_model/train.py @@ -23,7 +23,7 @@ from ...utils.loss import TensorMapDictLoss from ...utils.metrics import RMSEAccumulator from ...utils.neighbor_lists import get_system_with_neighbor_lists -from ...utils.per_atom import divide_by_num_atoms +from ...utils.per_atom import average_predictions_and_targets_by_num_atoms from . import DEFAULT_HYPERS from .model import Model from .utils.normalize import ( @@ -257,7 +257,7 @@ def train( ) # average by the number of atoms - predictions, targets = _average_by_num_atoms( + predictions, targets = average_predictions_and_targets_by_num_atoms( predictions, targets, systems, per_structure_targets ) @@ -284,7 +284,7 @@ def train( ) # average by the number of atoms - predictions, targets = _average_by_num_atoms( + predictions, targets = average_predictions_and_targets_by_num_atoms( predictions, targets, systems, per_structure_targets ) @@ -337,15 +337,3 @@ def train( break return model - - -def _average_by_num_atoms(predictions, targets, systems, per_structure_targets): - device = systems[0].device - num_atoms = torch.tensor([len(s) for s in systems], device=device) - for target in targets.keys(): - if target in per_structure_targets: - continue - predictions[target] = divide_by_num_atoms(predictions[target], num_atoms) - targets[target] = divide_by_num_atoms(targets[target], num_atoms) - - return predictions, targets diff --git a/src/metatensor/models/experimental/soap_bpnn/train.py b/src/metatensor/models/experimental/soap_bpnn/train.py index ae027b412..9e15258c5 100644 --- a/src/metatensor/models/experimental/soap_bpnn/train.py +++ b/src/metatensor/models/experimental/soap_bpnn/train.py @@ -24,7 +24,7 @@ from ...utils.loss import TensorMapDictLoss from ...utils.merge_capabilities import merge_capabilities from ...utils.metrics import RMSEAccumulator -from ...utils.per_atom import divide_by_num_atoms +from ...utils.per_atom import average_predictions_and_targets_by_num_atoms from . import DEFAULT_HYPERS from .model import Model @@ -262,7 +262,7 @@ def train( ) # average by the number of atoms - predictions, targets = _average_by_num_atoms( + predictions, targets = average_predictions_and_targets_by_num_atoms( predictions, targets, systems, per_structure_targets ) @@ -288,7 +288,7 @@ def train( ) # average by the number of atoms - predictions, targets = _average_by_num_atoms( + predictions, targets = average_predictions_and_targets_by_num_atoms( predictions, targets, systems, per_structure_targets ) @@ -341,15 +341,3 @@ def train( break return model - - -def _average_by_num_atoms(predictions, targets, systems, per_structure_targets): - device = systems[0].device - num_atoms = torch.tensor([len(s) for s in systems], device=device) - for target in targets.keys(): - if target in per_structure_targets: - continue - predictions[target] = divide_by_num_atoms(predictions[target], num_atoms) - targets[target] = divide_by_num_atoms(targets[target], num_atoms) - - return predictions, targets diff --git a/src/metatensor/models/utils/per_atom.py b/src/metatensor/models/utils/per_atom.py index 320e1bcc6..85af5ce54 100644 --- a/src/metatensor/models/utils/per_atom.py +++ b/src/metatensor/models/utils/per_atom.py @@ -1,5 +1,41 @@ +from typing import Dict, List + import torch from metatensor.torch import TensorBlock, TensorMap +from metatensor.torch.atomistic import System + + +def average_predictions_and_targets_by_num_atoms( + predictions: Dict[str, TensorMap], + targets: Dict[str, TensorMap], + systems: List[System], + per_structure_targets: List[str], +): + """Averages predictions and targets by the number of atoms in each system. + + This function averages predictions and targets by the number of atoms + in each system. Targets that are present in ``per_structure_targets`` will + not be averaged. + + :param predictions: A dictionary of predictions. + :param targets: A dictionary of targets. + :param systems: The systems used to compute the predictions. + :param per_structure_targets: A list of targets that should not be averaged. + """ + averaged_predictions = {} + averaged_targets = {} + device = systems[0].device + num_atoms = torch.tensor([len(s) for s in systems], device=device) + for target in targets.keys(): + if target in per_structure_targets: + averaged_predictions[target] = predictions[target] + averaged_targets[target] = targets[target] + averaged_predictions[target] = divide_by_num_atoms( + predictions[target], num_atoms + ) + averaged_targets[target] = divide_by_num_atoms(targets[target], num_atoms) + + return averaged_predictions, averaged_targets def divide_by_num_atoms(tensor_map: TensorMap, num_atoms: torch.Tensor) -> TensorMap: @@ -12,6 +48,11 @@ def divide_by_num_atoms(tensor_map: TensorMap, num_atoms: torch.Tensor) -> Tenso the majority of the cases, including energies, forces, and virials, where the energies and virials should be divided by the number of atoms, while the forces should not. + + :param tensor_map: The input tensor map. + :param num_atoms: The number of atoms in each system. + + :return: A new tensor map with the values divided by the number of atoms. """ blocks = [] diff --git a/tests/utils/test_per_atom.py b/tests/utils/test_per_atom.py index 916c6bf71..0db4ea774 100644 --- a/tests/utils/test_per_atom.py +++ b/tests/utils/test_per_atom.py @@ -1,7 +1,56 @@ import torch from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.atomistic import System -from metatensor.models.utils.per_atom import divide_by_num_atoms +from metatensor.models.utils.per_atom import ( + average_predictions_and_targets_by_num_atoms, + divide_by_num_atoms, +) + + +def test_average_predictions_and_targets_by_num_atoms(): + """Tests the average_predictions_and_targets_by_num_atoms function.""" + + systems = [ + System( + positions=torch.tensor([[0.0, 0.0, 0.0]]), + cell=torch.eye(3), + types=torch.tensor([0]), + ), + System( + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), + cell=torch.eye(3), + types=torch.tensor([0, 0]), + ), + System( + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), + cell=torch.eye(3), + types=torch.tensor([0, 0, 0]), + ), + ] + + block = TensorBlock( + values=torch.tensor([[1.0], [2.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + + tensor_map = TensorMap(keys=Labels.single(), blocks=[block]) + tensor_map_dict = {"energy": tensor_map} + + averaged_predictions, averaged_targets = ( + average_predictions_and_targets_by_num_atoms( + tensor_map_dict, tensor_map_dict, systems, per_structure_targets=[] + ) + ) + + assert torch.allclose( + averaged_predictions["energy"].block().values, torch.tensor([1.0, 1.0, 1.0]) + ) + assert torch.allclose( + averaged_targets["energy"].block().values, torch.tensor([1.0, 1.0, 1.0]) + ) def test_divide_by_num_atoms():