Skip to content

Commit

Permalink
Evaluate per-atom properties
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed May 18, 2024
1 parent 3fa7fbf commit 9b01b20
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 34 deletions.
1 change: 1 addition & 0 deletions docs/src/dev-docs/utils/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ This is the API for the ``utils`` module of ``metatensor-models``.
neighbor_lists
omegaconf
output_gradient
per_atom
7 changes: 7 additions & 0 deletions docs/src/dev-docs/utils/per_atom.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Averaging predictions per atom
##############################

.. automodule:: metatensor.models.utils.per_atom
:members:
:undoc-members:
:show-inheritance:
15 changes: 12 additions & 3 deletions src/metatensor/models/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ..utils.metrics import RMSEAccumulator
from ..utils.neighbor_lists import get_system_with_neighbor_lists
from ..utils.omegaconf import expand_dataset_config
from ..utils.per_atom import average_predictions_and_targets_by_num_atoms
from .formatter import CustomHelpFormatter


Expand Down Expand Up @@ -159,11 +160,19 @@ def _eval_targets(

# Evaluate the model
for batch in dataloader:
systems, targets = batch
systems, batch_targets = batch
systems = [system.to(device=device) for system in systems]
targets = {key: value.to(device=device) for key, value in targets.items()}
batch_targets = {
key: value.to(device=device) for key, value in batch_targets.items()
}
batch_predictions = evaluate_model(model, systems, options, is_training=False)
rmse_accumulator.update(batch_predictions, targets)
batch_predictions, batch_targets = average_predictions_and_targets_by_num_atoms(
predictions=batch_predictions,
targets=batch_targets,
systems=systems,
per_structure_targets=[],
)
rmse_accumulator.update(batch_predictions, batch_targets)
if return_predictions:
all_predictions.append(batch_predictions)

Expand Down
18 changes: 3 additions & 15 deletions src/metatensor/models/experimental/alchemical_model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ...utils.loss import TensorMapDictLoss
from ...utils.metrics import RMSEAccumulator
from ...utils.neighbor_lists import get_system_with_neighbor_lists
from ...utils.per_atom import divide_by_num_atoms
from ...utils.per_atom import average_predictions_and_targets_by_num_atoms
from . import DEFAULT_HYPERS
from .model import Model
from .utils.normalize import (
Expand Down Expand Up @@ -257,7 +257,7 @@ def train(
)

# average by the number of atoms
predictions, targets = _average_by_num_atoms(
predictions, targets = average_predictions_and_targets_by_num_atoms(
predictions, targets, systems, per_structure_targets
)

Expand All @@ -284,7 +284,7 @@ def train(
)

# average by the number of atoms
predictions, targets = _average_by_num_atoms(
predictions, targets = average_predictions_and_targets_by_num_atoms(
predictions, targets, systems, per_structure_targets
)

Expand Down Expand Up @@ -337,15 +337,3 @@ def train(
break

return model


def _average_by_num_atoms(predictions, targets, systems, per_structure_targets):
device = systems[0].device
num_atoms = torch.tensor([len(s) for s in systems], device=device)
for target in targets.keys():
if target in per_structure_targets:
continue
predictions[target] = divide_by_num_atoms(predictions[target], num_atoms)
targets[target] = divide_by_num_atoms(targets[target], num_atoms)

return predictions, targets
18 changes: 3 additions & 15 deletions src/metatensor/models/experimental/soap_bpnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ...utils.loss import TensorMapDictLoss
from ...utils.merge_capabilities import merge_capabilities
from ...utils.metrics import RMSEAccumulator
from ...utils.per_atom import divide_by_num_atoms
from ...utils.per_atom import average_predictions_and_targets_by_num_atoms
from . import DEFAULT_HYPERS
from .model import Model

Expand Down Expand Up @@ -262,7 +262,7 @@ def train(
)

# average by the number of atoms
predictions, targets = _average_by_num_atoms(
predictions, targets = average_predictions_and_targets_by_num_atoms(
predictions, targets, systems, per_structure_targets
)

Expand All @@ -288,7 +288,7 @@ def train(
)

# average by the number of atoms
predictions, targets = _average_by_num_atoms(
predictions, targets = average_predictions_and_targets_by_num_atoms(
predictions, targets, systems, per_structure_targets
)

Expand Down Expand Up @@ -341,15 +341,3 @@ def train(
break

return model


def _average_by_num_atoms(predictions, targets, systems, per_structure_targets):
device = systems[0].device
num_atoms = torch.tensor([len(s) for s in systems], device=device)
for target in targets.keys():
if target in per_structure_targets:
continue
predictions[target] = divide_by_num_atoms(predictions[target], num_atoms)
targets[target] = divide_by_num_atoms(targets[target], num_atoms)

return predictions, targets
41 changes: 41 additions & 0 deletions src/metatensor/models/utils/per_atom.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,41 @@
from typing import Dict, List

import torch
from metatensor.torch import TensorBlock, TensorMap
from metatensor.torch.atomistic import System


def average_predictions_and_targets_by_num_atoms(
predictions: Dict[str, TensorMap],
targets: Dict[str, TensorMap],
systems: List[System],
per_structure_targets: List[str],
):
"""Averages predictions and targets by the number of atoms in each system.
This function averages predictions and targets by the number of atoms
in each system. Targets that are present in ``per_structure_targets`` will
not be averaged.
:param predictions: A dictionary of predictions.
:param targets: A dictionary of targets.
:param systems: The systems used to compute the predictions.
:param per_structure_targets: A list of targets that should not be averaged.
"""
averaged_predictions = {}
averaged_targets = {}
device = systems[0].device
num_atoms = torch.tensor([len(s) for s in systems], device=device)
for target in targets.keys():
if target in per_structure_targets:
averaged_predictions[target] = predictions[target]
averaged_targets[target] = targets[target]
averaged_predictions[target] = divide_by_num_atoms(
predictions[target], num_atoms
)
averaged_targets[target] = divide_by_num_atoms(targets[target], num_atoms)

return averaged_predictions, averaged_targets


def divide_by_num_atoms(tensor_map: TensorMap, num_atoms: torch.Tensor) -> TensorMap:
Expand All @@ -12,6 +48,11 @@ def divide_by_num_atoms(tensor_map: TensorMap, num_atoms: torch.Tensor) -> Tenso
the majority of the cases, including energies, forces, and virials, where
the energies and virials should be divided by the number of atoms, while
the forces should not.
:param tensor_map: The input tensor map.
:param num_atoms: The number of atoms in each system.
:return: A new tensor map with the values divided by the number of atoms.
"""

blocks = []
Expand Down
51 changes: 50 additions & 1 deletion tests/utils/test_per_atom.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,56 @@
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import System

from metatensor.models.utils.per_atom import divide_by_num_atoms
from metatensor.models.utils.per_atom import (
average_predictions_and_targets_by_num_atoms,
divide_by_num_atoms,
)


def test_average_predictions_and_targets_by_num_atoms():
"""Tests the average_predictions_and_targets_by_num_atoms function."""

systems = [
System(
positions=torch.tensor([[0.0, 0.0, 0.0]]),
cell=torch.eye(3),
types=torch.tensor([0]),
),
System(
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]),
cell=torch.eye(3),
types=torch.tensor([0, 0]),
),
System(
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]),
cell=torch.eye(3),
types=torch.tensor([0, 0, 0]),
),
]

block = TensorBlock(
values=torch.tensor([[1.0], [2.0], [3.0]]),
samples=Labels.range("samples", 3),
components=[],
properties=Labels("energy", torch.tensor([[0]])),
)

tensor_map = TensorMap(keys=Labels.single(), blocks=[block])
tensor_map_dict = {"energy": tensor_map}

averaged_predictions, averaged_targets = (
average_predictions_and_targets_by_num_atoms(
tensor_map_dict, tensor_map_dict, systems, per_structure_targets=[]
)
)

assert torch.allclose(
averaged_predictions["energy"].block().values, torch.tensor([1.0, 1.0, 1.0])
)
assert torch.allclose(
averaged_targets["energy"].block().values, torch.tensor([1.0, 1.0, 1.0])
)


def test_divide_by_num_atoms():
Expand Down

0 comments on commit 9b01b20

Please sign in to comment.