diff --git a/.readthedocs.yml b/.readthedocs.yml index 2ebedc6ba..dc12989c7 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -15,6 +15,7 @@ build: pre_build: - set -e && cd examples/ase && bash train.sh - set -e && cd examples/programmatic/llpr && bash train.sh + - set -e && cd examples/zbl && bash train.sh # Build documentation in the docs/ directory with Sphinx sphinx: diff --git a/docs/generate_examples/conf.py b/docs/generate_examples/conf.py index 79715a5ea..40448f3e0 100644 --- a/docs/generate_examples/conf.py +++ b/docs/generate_examples/conf.py @@ -13,8 +13,16 @@ sphinx_gallery_conf = { "filename_pattern": "/*", "copyfile_regex": r".*\.(pt|sh|xyz|yaml)", - "examples_dirs": [os.path.join(ROOT, "examples", "ase"), os.path.join(ROOT, "examples", "programmatic", "llpr")], - "gallery_dirs": [os.path.join(ROOT, "docs", "src", "examples", "ase"), os.path.join(ROOT, "docs", "src", "examples", "programmatic", "llpr")], + "examples_dirs": [ + os.path.join(ROOT, "examples", "ase"), + os.path.join(ROOT, "examples", "programmatic", "llpr"), + os.path.join(ROOT, "examples", "zbl") + ], + "gallery_dirs": [ + os.path.join(ROOT, "docs", "src", "examples", "ase"), + os.path.join(ROOT, "docs", "src", "examples", "programmatic", "llpr"), + os.path.join(ROOT, "docs", "src", "examples", "zbl") + ], "min_reported_time": 5, "matplotlib_animations": True, } diff --git a/docs/src/dev-docs/utils/additive/composition.rst b/docs/src/dev-docs/utils/additive/composition.rst new file mode 100644 index 000000000..4499f0c97 --- /dev/null +++ b/docs/src/dev-docs/utils/additive/composition.rst @@ -0,0 +1,7 @@ +Composition model +################# + +.. 
automodule:: metatrain.utils.additive.composition + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/src/dev-docs/utils/additive/index.rst b/docs/src/dev-docs/utils/additive/index.rst new file mode 100644 index 000000000..f2aed2d26 --- /dev/null +++ b/docs/src/dev-docs/utils/additive/index.rst @@ -0,0 +1,12 @@ +Additive models +=============== + +API for handling additive models in ``metatrain``. These are models that +can be added to one or more architectures. + +.. toctree:: + :maxdepth: 1 + + remove_additive + composition + zbl diff --git a/docs/src/dev-docs/utils/additive/remove_additive.rst b/docs/src/dev-docs/utils/additive/remove_additive.rst new file mode 100644 index 000000000..6a115a471 --- /dev/null +++ b/docs/src/dev-docs/utils/additive/remove_additive.rst @@ -0,0 +1,7 @@ +Removing additive contributions +############################### + +.. automodule:: metatrain.utils.additive.remove + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/src/dev-docs/utils/additive/zbl.rst b/docs/src/dev-docs/utils/additive/zbl.rst new file mode 100644 index 000000000..ab0248bde --- /dev/null +++ b/docs/src/dev-docs/utils/additive/zbl.rst @@ -0,0 +1,7 @@ +ZBL short-range potential +######################### + +.. automodule:: metatrain.utils.additive.zbl + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/src/dev-docs/utils/composition.rst b/docs/src/dev-docs/utils/composition.rst deleted file mode 100644 index 0a6cb2a34..000000000 --- a/docs/src/dev-docs/utils/composition.rst +++ /dev/null @@ -1,7 +0,0 @@ -Composition -########### - -.. automodule:: metatrain.utils.composition - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/src/dev-docs/utils/index.rst b/docs/src/dev-docs/utils/index.rst index 3a2d12d07..312ee5342 100644 --- a/docs/src/dev-docs/utils/index.rst +++ b/docs/src/dev-docs/utils/index.rst @@ -6,9 +6,9 @@ This is the API for the ``utils`` module of ``metatrain``. .. 
toctree:: :maxdepth: 1 + additive/index data/index architectures - composition devices dtype errors diff --git a/docs/src/tutorials/index.rst b/docs/src/tutorials/index.rst index 3414fda6c..c8b0f5e4c 100644 --- a/docs/src/tutorials/index.rst +++ b/docs/src/tutorials/index.rst @@ -11,3 +11,4 @@ This sections includes some more advanced tutorials on the usage of the ../examples/ase/run_ase ../examples/programmatic/llpr/llpr + ../examples/zbl/dimers diff --git a/examples/ase/run_ase.py b/examples/ase/run_ase.py index 18a172d34..030f40dca 100644 --- a/examples/ase/run_ase.py +++ b/examples/ase/run_ase.py @@ -42,11 +42,6 @@ # %% # -# .. note:: -# We have to import ``rascaline.torch`` even though it is not used explicitly in this -# tutorial. The SOAP-BPNN model contains compiled extensions and therefore the import -# is required. -# # Setting up the simulation # ------------------------- # diff --git a/examples/programmatic/llpr/llpr.py b/examples/programmatic/llpr/llpr.py index 8db135c86..10857aaaf 100644 --- a/examples/programmatic/llpr/llpr.py +++ b/examples/programmatic/llpr/llpr.py @@ -48,7 +48,10 @@ # how to create a Dataset object from them. 
from metatrain.utils.data import Dataset, read_systems, read_targets # noqa: E402 -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists # noqa: E402 +from metatrain.utils.neighbor_lists import ( # noqa: E402 + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) qm9_systems = read_systems("qm9_reduced_100.xyz") @@ -67,7 +70,7 @@ } targets, _ = read_targets(target_config) -requested_neighbor_lists = model.requested_neighbor_lists() +requested_neighbor_lists = get_requested_neighbor_lists(model) qm9_systems = [ get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in qm9_systems ] diff --git a/examples/zbl/README.rst b/examples/zbl/README.rst new file mode 100644 index 000000000..50f7bdb41 --- /dev/null +++ b/examples/zbl/README.rst @@ -0,0 +1,2 @@ +Training a model with ZBL corrections +===================================== diff --git a/examples/zbl/dimers.py b/examples/zbl/dimers.py new file mode 100644 index 000000000..04069a5a5 --- /dev/null +++ b/examples/zbl/dimers.py @@ -0,0 +1,148 @@ +""" +Training a model with ZBL corrections +===================================== + +This tutorial demonstrates how to train a model with ZBL corrections. + +The training set for this example consists of a +subset of the ethanol molecules from the `rMD17 dataset +`_. + +The models are trained using the following training options, respectively: + +.. literalinclude:: options_no_zbl.yaml + :language: yaml + +.. literalinclude:: options_zbl.yaml + :language: yaml + +As you can see, they are identical, except for the ``zbl`` key in the +``model`` section. +You can train the same models yourself with + +.. literalinclude:: train.sh + :language: bash + +A detailed step-by-step introduction on how to train a model is provided in +the :ref:`label_basic_usage` tutorial. +""" + +# %% +# +# First, we start by importing the necessary libraries, including the integration of ASE
# calculators for metatensor atomistic models.
+ +import ase +import matplotlib.pyplot as plt +import numpy as np +import torch +from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator + + +# %% +# +# Setting up the dimers +# --------------------- +# +# We set up a series of dimers with different atom pairs and distances. We will +# calculate the energies of these dimers using the models trained with and without ZBL +# corrections. + +distances = np.linspace(0.5, 6.0, 200) +pairs = {} +for pair in [("H", "H"), ("H", "C"), ("C", "C"), ("C", "O"), ("O", "O"), ("H", "O")]: + structures = [] + for distance in distances: + atoms = ase.Atoms( + symbols=[pair[0], pair[1]], + positions=[[0, 0, 0], [0, 0, distance]], + ) + structures.append(atoms) + pairs[pair] = structures + +# %% +# +# We now load the two exported models, one with and one without ZBL corrections + +calc_no_zbl = MetatensorCalculator( + "model_no_zbl.pt", extensions_directory="extensions/" +) +calc_zbl = MetatensorCalculator("model_zbl.pt", extensions_directory="extensions/") + + +# %% +# +# Calculate and plot energies without ZBL +# --------------------------------------- +# +# We calculate the energies of the dimer curves for each pair of atoms and +# plot the results, using the non-ZBL-corrected model. + +for pair, structures_for_pair in pairs.items(): + energies = [] + for atoms in structures_for_pair: + atoms.set_calculator(calc_no_zbl) + with torch.jit.optimized_execution(False): + energies.append(atoms.get_potential_energy()) + energies = np.array(energies) - energies[-1] + plt.plot(distances, energies, label=f"{pair[0]}-{pair[1]}") +plt.title("Dimer curves - no ZBL") +plt.xlabel("Distance (Å)") +plt.ylabel("Energy (eV)") +plt.legend() +plt.tight_layout() +plt.show() + +# %% +# +# Calculate and plot energies from the ZBL-corrected model +# -------------------------------------------------------- +# +# We repeat the same procedure as above, but this time with the ZBL-corrected model. 
+ +for pair, structures_for_pair in pairs.items(): + energies = [] + for atoms in structures_for_pair: + atoms.set_calculator(calc_zbl) + with torch.jit.optimized_execution(False): + energies.append(atoms.get_potential_energy()) + energies = np.array(energies) - energies[-1] + plt.plot(distances, energies, label=f"{pair[0]}-{pair[1]}") +plt.title("Dimer curves - with ZBL") +plt.xlabel("Distance (Å)") +plt.ylabel("Energy (eV)") +plt.legend() +plt.tight_layout() +plt.show() + +# %% +# +# It can be seen that all the dimer curves include a strong repulsion +# at short distances, which is due to the ZBL contribution. Even the H-H dimer, +# whose ZBL correction is very weak due to the small covalent radii of hydrogen, +# would show a strong repulsion closer to the origin (here, we only plotted +# starting from a distance of 0.5 Å). Let's zoom in on the H-H dimer to see +# this effect more clearly. + +new_distances = np.linspace(0.1, 2.0, 200) + +structures = [] +for distance in new_distances: + atoms = ase.Atoms( + symbols=["H", "H"], + positions=[[0, 0, 0], [0, 0, distance]], + ) + structures.append(atoms) + +for atoms in structures: + atoms.set_calculator(calc_zbl) +with torch.jit.optimized_execution(False): + energies = [atoms.get_potential_energy() for atoms in structures] +energies = np.array(energies) - energies[-1] +plt.plot(new_distances, energies, label="H-H") +plt.title("Dimer curve - H-H with ZBL") +plt.xlabel("Distance (Å)") +plt.ylabel("Energy (eV)") +plt.legend() +plt.tight_layout() +plt.show() diff --git a/examples/zbl/ethanol_reduced_100.xyz b/examples/zbl/ethanol_reduced_100.xyz new file mode 120000 index 000000000..f01afa4c6 --- /dev/null +++ b/examples/zbl/ethanol_reduced_100.xyz @@ -0,0 +1 @@ +../ase/ethanol_reduced_100.xyz \ No newline at end of file diff --git a/examples/zbl/options_no_zbl.yaml b/examples/zbl/options_no_zbl.yaml new file mode 100644 index 000000000..e53e218ba --- /dev/null +++ b/examples/zbl/options_no_zbl.yaml @@ -0,0 +1,21 @@ 
+seed: 42 + +architecture: + name: experimental.soap_bpnn + model: + zbl: false + training: + num_epochs: 10 + +# training set section +training_set: + systems: + read_from: ethanol_reduced_100.xyz + length_unit: angstrom + targets: + energy: + key: "energy" + unit: "eV" # very important to run simulations + +validation_set: 0.1 +test_set: 0.0 diff --git a/examples/zbl/options_zbl.yaml b/examples/zbl/options_zbl.yaml new file mode 100644 index 000000000..56fe80642 --- /dev/null +++ b/examples/zbl/options_zbl.yaml @@ -0,0 +1,21 @@ +seed: 42 + +architecture: + name: experimental.soap_bpnn + model: + zbl: true + training: + num_epochs: 10 + +# training set section +training_set: + systems: + read_from: ethanol_reduced_100.xyz + length_unit: angstrom + targets: + energy: + key: "energy" + unit: "eV" # very important to run simulations + +validation_set: 0.1 +test_set: 0.0 diff --git a/examples/zbl/train.sh b/examples/zbl/train.sh new file mode 100755 index 000000000..03b6baab2 --- /dev/null +++ b/examples/zbl/train.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +mtt train options_no_zbl.yaml -o model_no_zbl.pt +mtt train options_zbl.yaml -o model_zbl.pt diff --git a/src/metatrain/cli/eval.py b/src/metatrain/cli/eval.py index 0caf1adbc..c796b83c1 100644 --- a/src/metatrain/cli/eval.py +++ b/src/metatrain/cli/eval.py @@ -25,7 +25,10 @@ from ..utils.evaluate_model import evaluate_model from ..utils.logging import MetricLogger from ..utils.metrics import RMSEAccumulator -from ..utils.neighbor_lists import get_system_with_neighbor_lists +from ..utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from ..utils.omegaconf import expand_dataset_config from ..utils.per_atom import average_by_num_atoms from .formatter import CustomHelpFormatter @@ -177,7 +180,7 @@ def _eval_targets( # if already present (e.g. 
if this function is called after training) for sample in dataset: system = sample["system"] - get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + get_system_with_neighbor_lists(system, get_requested_neighbor_lists(model)) # Infer the device and dtype from the model model_tensor = next(itertools.chain(model.parameters(), model.buffers())) diff --git a/src/metatrain/experimental/alchemical_model/default-hypers.yaml b/src/metatrain/experimental/alchemical_model/default-hypers.yaml index d41f53afb..4d3f91eb2 100644 --- a/src/metatrain/experimental/alchemical_model/default-hypers.yaml +++ b/src/metatrain/experimental/alchemical_model/default-hypers.yaml @@ -13,6 +13,7 @@ model: bpnn: hidden_sizes: [32, 32] output_size: 1 + zbl: false training: batch_size: 8 diff --git a/src/metatrain/experimental/alchemical_model/model.py b/src/metatrain/experimental/alchemical_model/model.py index 52b9af5fb..8ffbd8dc7 100644 --- a/src/metatrain/experimental/alchemical_model/model.py +++ b/src/metatrain/experimental/alchemical_model/model.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import Dict, List, Optional, Union +import metatensor.torch import torch from metatensor.torch import Labels, TensorBlock, TensorMap from metatensor.torch.atomistic import ( @@ -12,6 +13,7 @@ ) from torch_alchemical.models import AlchemicalModel as AlchemicalModelUpstream +from ...utils.additive import ZBL from ...utils.data.dataset import DatasetInfo from ...utils.dtype import dtype_to_str from ...utils.export import export @@ -54,6 +56,11 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: **self.hypers["bpnn"], ) + additive_models = [] + if self.hypers["zbl"]: + additive_models.append(ZBL(model_hypers, dataset_info)) + self.additive_models = torch.nn.ModuleList(additive_models) + self.cutoff = self.hypers["soap"]["cutoff"] self.is_restarted = False @@ -123,6 +130,18 @@ def forward( keys=keys, blocks=[block], ) + + if not self.training: + # at 
evaluation, we also add the additive contributions + for additive_model in self.additive_models: + additive_contributions = additive_model( + systems, outputs, selected_atoms + ) + total_energies[output_name] = metatensor.torch.add( + total_energies[output_name], + additive_contributions[output_name], + ) + return total_energies @classmethod @@ -145,10 +164,21 @@ def export(self) -> MetatensorAtomisticModel: if dtype not in self.__supported_dtypes__: raise ValueError(f"unsupported dtype {dtype} for AlchemicalModel") + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + interaction_ranges = [self.hypers["soap"]["cutoff"]] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + capabilities = ModelCapabilities( outputs=self.outputs, atomic_types=self.atomic_types, - interaction_range=self.hypers["soap"]["cutoff"], + interaction_range=interaction_range, length_unit=self.dataset_info.length_unit, supported_devices=self.__supported_devices__, dtype=dtype_to_str(dtype), diff --git a/src/metatrain/experimental/alchemical_model/schema-hypers.json b/src/metatrain/experimental/alchemical_model/schema-hypers.json index b5901c318..4e9e141ed 100644 --- a/src/metatrain/experimental/alchemical_model/schema-hypers.json +++ b/src/metatrain/experimental/alchemical_model/schema-hypers.json @@ -53,6 +53,9 @@ } }, "additionalProperties": false + }, + "zbl": { + "type": "boolean" } }, "additionalProperties": false diff --git a/src/metatrain/experimental/alchemical_model/tests/test_exported.py b/src/metatrain/experimental/alchemical_model/tests/test_exported.py index 3be002445..891983693 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_exported.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_exported.py @@ -4,7 
+4,10 @@ from metatrain.experimental.alchemical_model import AlchemicalModel from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import MODEL_HYPERS @@ -31,7 +34,8 @@ def test_to(device, dtype): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, exported.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(exported) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) system = system.to(device=device, dtype=dtype) evaluation_options = ModelEvaluationOptions( diff --git a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py index b3c42d81f..7ee3331af 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py @@ -3,7 +3,10 @@ from metatrain.experimental.alchemical_model import AlchemicalModel from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . 
import MODEL_HYPERS @@ -25,7 +28,8 @@ def test_prediction_subset_elements(): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) evaluation_options = ModelEvaluationOptions( length_unit=dataset_info.length_unit, diff --git a/src/metatrain/experimental/alchemical_model/tests/test_invariance.py b/src/metatrain/experimental/alchemical_model/tests/test_invariance.py index f64925848..9d9a84dd9 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_invariance.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_invariance.py @@ -6,7 +6,10 @@ from metatrain.experimental.alchemical_model import AlchemicalModel from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . 
import DATASET_PATH, MODEL_HYPERS @@ -24,13 +27,15 @@ def test_rotational_invariance(): system = ase.io.read(DATASET_PATH) original_system = copy.deepcopy(system) original_system = systems_to_torch(original_system) + requested_neighbor_lists = get_requested_neighbor_lists(model) original_system = get_system_with_neighbor_lists( - original_system, model.requested_neighbor_lists() + original_system, requested_neighbor_lists ) system.rotate(48, "y") system = systems_to_torch(system) - system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) evaluation_options = ModelEvaluationOptions( length_unit=dataset_info.length_unit, diff --git a/src/metatrain/experimental/alchemical_model/tests/test_regression.py b/src/metatrain/experimental/alchemical_model/tests/test_regression.py index 4dbc6ed0b..648c91c18 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_regression.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_regression.py @@ -14,7 +14,10 @@ read_targets, ) from metatrain.utils.data.dataset import TargetInfoDict -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . 
import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -38,8 +41,9 @@ def test_regression_init(): # Predict on the first five systems systems = read_systems(DATASET_PATH)[:5] + requested_neighbor_lists = get_requested_neighbor_lists(model) systems = [ - get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in systems ] @@ -101,8 +105,9 @@ def test_regression_train(): ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) + requested_neighbor_lists = get_requested_neighbor_lists(model) systems = [ - get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in systems ] diff --git a/src/metatrain/experimental/alchemical_model/trainer.py b/src/metatrain/experimental/alchemical_model/trainer.py index dddbc9839..0dd13a2b5 100644 --- a/src/metatrain/experimental/alchemical_model/trainer.py +++ b/src/metatrain/experimental/alchemical_model/trainer.py @@ -5,6 +5,7 @@ import torch from metatensor.learn.data import DataLoader +from ...utils.additive import remove_additive from ...utils.data import ( CombinedDataLoader, Dataset, @@ -19,7 +20,10 @@ from ...utils.logging import MetricLogger from ...utils.loss import TensorMapDictLoss from ...utils.metrics import RMSEAccumulator -from ...utils.neighbor_lists import get_system_with_neighbor_lists +from ...utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from ...utils.per_atom import average_by_num_atoms from . 
import AlchemicalModel from .utils.composition import calculate_composition_weights @@ -67,7 +71,7 @@ def train( # Calculating the neighbor lists for the training and validation datasets: logger.info("Calculating neighbor lists for the datasets") - requested_neighbor_lists = model.requested_neighbor_lists() + requested_neighbor_lists = get_requested_neighbor_lists(model) for dataset in train_datasets + val_datasets: for i in range(len(dataset)): system = dataset[i]["system"] @@ -223,6 +227,10 @@ def train( key: value.to(dtype=dtype, device=device) for key, value in targets.items() } + for additive_model in model.additive_models: + targets = remove_additive( + systems, targets, additive_model, model.dataset_info.targets + ) predictions = evaluate_model( model, systems, @@ -259,6 +267,10 @@ def train( key: value.to(dtype=dtype, device=device) for key, value in targets.items() } + for additive_model in model.additive_models: + targets = remove_additive( + systems, targets, additive_model, model.dataset_info.targets + ) predictions = evaluate_model( model, systems, diff --git a/src/metatrain/experimental/gap/default-hypers.yaml b/src/metatrain/experimental/gap/default-hypers.yaml index d3b911e83..2c7f192fe 100644 --- a/src/metatrain/experimental/gap/default-hypers.yaml +++ b/src/metatrain/experimental/gap/default-hypers.yaml @@ -17,10 +17,10 @@ model: rate: 1.0 scale: 2.0 exponent: 7.0 - krr: degree: 2 num_sparse_points: 500 + zbl: false training: regularizer: 0.001 diff --git a/src/metatrain/experimental/gap/model.py b/src/metatrain/experimental/gap/model.py index 1ce0a9cb5..833325d11 100644 --- a/src/metatrain/experimental/gap/model.py +++ b/src/metatrain/experimental/gap/model.py @@ -21,7 +21,7 @@ from metatrain.utils.data.dataset import DatasetInfo -from ...utils.composition import CompositionModel +from ...utils.additive import ZBL, CompositionModel from ...utils.export import export @@ -95,7 +95,6 @@ def __init__(self, model_hypers: Dict, dataset_info: 
DatasetInfo) -> None: self._soap_torch_calculator = rascaline.torch.SoapPowerSpectrum( **model_hypers["soap"] ) - self._soap_calculator = rascaline.SoapPowerSpectrum(**model_hypers["soap"]) kernel_kwargs = { "degree": model_hypers["krr"]["degree"], @@ -128,10 +127,16 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: ) self._species_labels: TorchLabels = TorchLabels.empty("_") - self.composition_model = CompositionModel( + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( model_hypers={}, dataset_info=dataset_info, ) + additive_models = [composition_model] + if self.hypers["zbl"]: + additive_models.append(ZBL(model_hypers, dataset_info)) + self.additive_models = torch.nn.ModuleList(additive_models) def restart(self, dataset_info: DatasetInfo) -> "GAP": raise ValueError("GAP does not allow restarting training") @@ -209,24 +214,34 @@ def forward( energies = self._subset_of_regressors_torch(soap_features) return_dict = {output_key: energies} - # apply composition model - composition_energies = self.composition_model( - systems, {output_key: ModelOutput("energy", per_atom=True)}, selected_atoms - ) - composition_energies[output_key] = metatensor.torch.sum_over_samples( - composition_energies[output_key], "atom" - ) - return_dict[output_key] = metatensor.torch.add( - return_dict[output_key], composition_energies[output_key] - ) + if not self.training: + # at evaluation, we also add the additive contributions + for additive_model in self.additive_models: + additive_contributions = additive_model( + systems, outputs, selected_atoms + ) + for name in return_dict: + if name.startswith("mtt::aux::"): + continue # skip auxiliary outputs (not targets) + return_dict[name] = metatensor.torch.add( + return_dict[name], + additive_contributions[name], + ) return return_dict def export(self) -> MetatensorAtomisticModel: + + interaction_ranges 
= [self.hypers["soap"]["cutoff"]] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + capabilities = ModelCapabilities( outputs=self.outputs, atomic_types=sorted(self.dataset_info.atomic_types), - interaction_range=self.hypers["soap"]["cutoff"], + interaction_range=interaction_range, length_unit=self.dataset_info.length_unit, supported_devices=["cuda", "cpu"], dtype="float64", diff --git a/src/metatrain/experimental/gap/schema-hypers.json b/src/metatrain/experimental/gap/schema-hypers.json index a756bc590..e793f0315 100644 --- a/src/metatrain/experimental/gap/schema-hypers.json +++ b/src/metatrain/experimental/gap/schema-hypers.json @@ -87,6 +87,9 @@ } }, "additionalProperties": false + }, + "zbl": { + "type": "boolean" } }, "additionalProperties": false diff --git a/src/metatrain/experimental/gap/tests/test_errors.py b/src/metatrain/experimental/gap/tests/test_errors.py index 24e5c03b3..b0359bd80 100644 --- a/src/metatrain/experimental/gap/tests/test_errors.py +++ b/src/metatrain/experimental/gap/tests/test_errors.py @@ -62,7 +62,7 @@ def test_ethanol_regression_train_and_invariance(): hypers["model"]["krr"]["num_sparse_points"] = 30 target_info_dict = TargetInfoDict( - energy=TargetInfo(quantity="energy", unit="kcal/mol") + energy=TargetInfo(quantity="energy", unit="kcal/mol", gradients=["positions"]) ) dataset_info = DatasetInfo( diff --git a/src/metatrain/experimental/gap/tests/test_regression.py b/src/metatrain/experimental/gap/tests/test_regression.py index e2a2ee72c..81212353c 100644 --- a/src/metatrain/experimental/gap/tests/test_regression.py +++ b/src/metatrain/experimental/gap/tests/test_regression.py @@ -74,6 +74,7 @@ def test_regression_train_and_invariance(): val_datasets=[dataset], checkpoint_dir=".", ) + gap.eval() # Predict on the first five systems output = gap(systems[:5], {"mtt::U0": 
gap.outputs["mtt::U0"]}) @@ -138,7 +139,7 @@ def test_ethanol_regression_train_and_invariance(): hypers["model"]["krr"]["num_sparse_points"] = 900 target_info_dict = TargetInfoDict( - energy=TargetInfo(quantity="energy", unit="kcal/mol") + energy=TargetInfo(quantity="energy", unit="kcal/mol", gradients=["positions"]) ) dataset_info = DatasetInfo( @@ -155,6 +156,7 @@ def test_ethanol_regression_train_and_invariance(): val_datasets=[dataset], checkpoint_dir=".", ) + gap.eval() # Predict on the first five systems output = gap(systems[:5], {"energy": gap.outputs["energy"]}) diff --git a/src/metatrain/experimental/gap/trainer.py b/src/metatrain/experimental/gap/trainer.py index ffeb40951..5daef8369 100644 --- a/src/metatrain/experimental/gap/trainer.py +++ b/src/metatrain/experimental/gap/trainer.py @@ -9,8 +9,12 @@ from metatrain.utils.data import Dataset -from ...utils.composition import remove_composition +from ...utils.additive import remove_additive from ...utils.data import check_datasets +from ...utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . 
import GAP from .model import torch_tensor_map_to_core @@ -52,7 +56,8 @@ def train( # Calculate and set the composition weights: logger.info("Calculating composition weights") - model.composition_model.train_model(train_datasets) + # model.additive_models[0] is the composition model + model.additive_models[0].train_model(train_datasets) logger.info("Setting up data loaders") if len(train_datasets[0][0][output_name].keys) > 1: @@ -69,11 +74,25 @@ def train( model._keys = train_y.keys train_structures = [sample["system"] for sample in train_dataset] - logger.info("Subtracting composition energies") - # this acts in-place on train_y - remove_composition( - train_structures, {target_name: train_y}, model.composition_model - ) + logger.info("Calculating neighbor lists for the datasets") + requested_neighbor_lists = get_requested_neighbor_lists(model) + for dataset in train_datasets + val_datasets: + for i in range(len(dataset)): + system = dataset[i]["system"] + # The following line attaches the neighbors lists to the system, + # and doesn't require to reassign the system to the dataset: + _ = get_system_with_neighbor_lists(system, requested_neighbor_lists) + + logger.info("Subtracting composition energies") # and potentially ZBL + train_targets = {target_name: train_y} + for additive_model in model.additive_models: + train_targets = remove_additive( + train_structures, + train_targets, + additive_model, + model.dataset_info.targets, + ) + train_y = train_targets[target_name] logger.info("Calculating SOAP features") if len(train_y[0].gradients_list()) > 0: diff --git a/src/metatrain/experimental/pet/default-hypers.yaml b/src/metatrain/experimental/pet/default-hypers.yaml index 98839a13a..ad6befbbb 100644 --- a/src/metatrain/experimental/pet/default-hypers.yaml +++ b/src/metatrain/experimental/pet/default-hypers.yaml @@ -32,6 +32,7 @@ model: N_TARGETS: 1 TARGET_INDEX_KEY: target_index RESIDUAL_FACTOR: 0.5 + USE_ZBL: False training: INITIAL_LR: 1e-4 diff --git 
a/src/metatrain/experimental/pet/model.py b/src/metatrain/experimental/pet/model.py index 0bb67b19c..bf567dee1 100644 --- a/src/metatrain/experimental/pet/model.py +++ b/src/metatrain/experimental/pet/model.py @@ -18,6 +18,7 @@ from metatrain.utils.data import DatasetInfo +from ...utils.additive import ZBL from ...utils.dtype import dtype_to_str from ...utils.export import export from .utils import systems_to_batch_dict @@ -48,6 +49,13 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: self.pet = None self.checkpoint_path: Optional[str] = None + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + additive_models = [] + if self.hypers["USE_ZBL"]: + additive_models.append(ZBL(model_hypers, dataset_info)) + self.additive_models = torch.nn.ModuleList(additive_models) + def restart(self, dataset_info: DatasetInfo) -> "PET": if dataset_info != self.dataset_info: raise ValueError( @@ -110,6 +118,21 @@ def forward( if not outputs[output_name].per_atom: output_tmap = metatensor.torch.sum_over_samples(output_tmap, "atom") output_quantities[output_name] = output_tmap + + if not self.training: + # at evaluation, we also add the additive contributions + for additive_model in self.additive_models: + additive_contributions = additive_model( + systems, outputs, selected_atoms + ) + for output_name in output_quantities: + if output_name.startswith("mtt::aux::"): + continue # skip auxiliary outputs (not targets) + output_quantities[output_name] = metatensor.torch.add( + output_quantities[output_name], + additive_contributions[output_name], + ) + return output_quantities @classmethod @@ -148,6 +171,17 @@ def export(self) -> MetatensorAtomisticModel: if dtype not in self.__supported_dtypes__: raise ValueError(f"Unsupported dtype {self.dtype} for PET") + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + 
self.to(dtype) + + interaction_ranges = [self.hypers["N_GNN_LAYERS"] * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + capabilities = ModelCapabilities( outputs={ self.target_name: ModelOutput( @@ -157,7 +191,7 @@ def export(self) -> MetatensorAtomisticModel: ) }, atomic_types=self.atomic_types, - interaction_range=self.cutoff, + interaction_range=interaction_range, length_unit=self.dataset_info.length_unit, supported_devices=["cpu", "cuda"], # and not __supported_devices__ dtype=dtype_to_str(dtype), diff --git a/src/metatrain/experimental/pet/schema-hypers.json b/src/metatrain/experimental/pet/schema-hypers.json index e68c5bd95..9b00a1927 100644 --- a/src/metatrain/experimental/pet/schema-hypers.json +++ b/src/metatrain/experimental/pet/schema-hypers.json @@ -123,6 +123,9 @@ }, "RESIDUAL_FACTOR": { "type": "number" + }, + "USE_ZBL": { + "type": "boolean" } }, "additionalProperties": false diff --git a/src/metatrain/experimental/pet/tests/test_exported.py b/src/metatrain/experimental/pet/tests/test_exported.py index a72eb88dd..f67a15e4c 100644 --- a/src/metatrain/experimental/pet/tests/test_exported.py +++ b/src/metatrain/experimental/pet/tests/test_exported.py @@ -13,7 +13,10 @@ from metatrain.utils.architectures import get_default_hypers from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict from metatrain.utils.export import export -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) DEFAULT_HYPERS = get_default_hypers("experimental.pet") @@ -59,7 +62,8 @@ def test_to(device): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, exported.requested_neighbor_lists()) + 
requested_neighbor_lists = get_requested_neighbor_lists(exported) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) system = system.to(device=device, dtype=dtype) evaluation_options = ModelEvaluationOptions( diff --git a/src/metatrain/experimental/pet/tests/test_functionality.py b/src/metatrain/experimental/pet/tests/test_functionality.py index 74a47b075..ddf527603 100644 --- a/src/metatrain/experimental/pet/tests/test_functionality.py +++ b/src/metatrain/experimental/pet/tests/test_functionality.py @@ -20,7 +20,10 @@ from metatrain.utils.architectures import get_default_hypers from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict from metatrain.utils.jsonschema import validate -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) DEFAULT_HYPERS = get_default_hypers("experimental.pet") @@ -74,7 +77,8 @@ def test_prediction(): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) evaluation_options = ModelEvaluationOptions( length_unit=dataset_info.length_unit, @@ -123,7 +127,8 @@ def test_per_atom_predictions_functionality(): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) evaluation_options = ModelEvaluationOptions( length_unit=dataset_info.length_unit, @@ -173,7 +178,8 @@ def test_selected_atoms_functionality(): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), 
cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) evaluation_options = ModelEvaluationOptions( length_unit=dataset_info.length_unit, diff --git a/src/metatrain/experimental/pet/trainer.py b/src/metatrain/experimental/pet/trainer.py index e91ccaf24..63c06afd9 100644 --- a/src/metatrain/experimental/pet/trainer.py +++ b/src/metatrain/experimental/pet/trainer.py @@ -95,10 +95,10 @@ def train( ) ase_train_dataset = dataset_to_ase( - train_dataset, do_forces=do_forces, target_name=target_name + train_dataset, model, do_forces=do_forces, target_name=target_name ) ase_val_dataset = dataset_to_ase( - val_dataset, do_forces=do_forces, target_name=target_name + val_dataset, model, do_forces=do_forces, target_name=target_name ) self.hypers = update_hypers(self.hypers, model.hypers, do_forces) diff --git a/src/metatrain/experimental/pet/utils/dataset_to_ase.py b/src/metatrain/experimental/pet/utils/dataset_to_ase.py index 3a35a63bf..e7111a20f 100644 --- a/src/metatrain/experimental/pet/utils/dataset_to_ase.py +++ b/src/metatrain/experimental/pet/utils/dataset_to_ase.py @@ -1,11 +1,16 @@ from metatensor.learn.data import DataLoader +from ....utils.additive import remove_additive from ....utils.data import collate_fn from ....utils.data.system_to_ase import system_to_ase +from ....utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) # dummy dataloaders due to https://github.com/lab-cosmo/metatensor/issues/521 -def dataset_to_ase(dataset, do_forces=True, target_name="energy"): +def dataset_to_ase(dataset, model, do_forces=True, target_name="energy"): dataloader = DataLoader( dataset, batch_size=1, @@ -14,6 +19,14 @@ def dataset_to_ase(dataset, do_forces=True, target_name="energy"): ) ase_dataset = [] for (system,), targets in dataloader: 
+ # remove additive model (e.g. ZBL) contributions + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) + for additive_model in model.additive_models: + targets = remove_additive( + [system], targets, additive_model, model.dataset_info.targets + ) + # transform to ase atoms ase_atoms = system_to_ase(system) ase_atoms.info["energy"] = float( targets[target_name].block().values.squeeze(-1).detach().cpu().numpy() diff --git a/src/metatrain/experimental/soap_bpnn/default-hypers.yaml b/src/metatrain/experimental/soap_bpnn/default-hypers.yaml index a86bc9ded..1c7fe1c66 100644 --- a/src/metatrain/experimental/soap_bpnn/default-hypers.yaml +++ b/src/metatrain/experimental/soap_bpnn/default-hypers.yaml @@ -15,11 +15,11 @@ model: rate: 1.0 scale: 2.0 exponent: 7.0 - bpnn: layernorm: true num_hidden_layers: 2 num_neurons_per_layer: 32 + zbl: false training: distributed: False diff --git a/src/metatrain/experimental/soap_bpnn/model.py b/src/metatrain/experimental/soap_bpnn/model.py index 4b99f9229..556f3ef52 100644 --- a/src/metatrain/experimental/soap_bpnn/model.py +++ b/src/metatrain/experimental/soap_bpnn/model.py @@ -16,7 +16,7 @@ from metatrain.utils.data.dataset import DatasetInfo -from ...utils.composition import CompositionModel +from ...utils.additive import ZBL, CompositionModel from ...utils.dtype import dtype_to_str from ...utils.export import export @@ -187,10 +187,16 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: } ) - self.composition_model = CompositionModel( + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( model_hypers={}, dataset_info=dataset_info, ) + additive_models = [composition_model] + if self.hypers["zbl"]: + additive_models.append(ZBL(model_hypers, dataset_info)) + self.additive_models = 
torch.nn.ModuleList(additive_models) def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn": # merge old and new dataset info @@ -274,17 +280,18 @@ def forward( ) if not self.training: - # at evaluation, we also add the composition contributions - composition_contributions = self.composition_model( - systems, outputs, selected_atoms - ) - for name in return_dict: - if name.startswith("mtt::aux::"): - continue # skip auxiliary outputs (not targets) - return_dict[name] = metatensor.torch.add( - return_dict[name], - composition_contributions[name], + # at evaluation, we also add the additive contributions + for additive_model in self.additive_models: + additive_contributions = additive_model( + systems, outputs, selected_atoms ) + for name in return_dict: + if name.startswith("mtt::aux::"): + continue # skip auxiliary outputs (not targets) + return_dict[name] = metatensor.torch.add( + return_dict[name], + additive_contributions[name], + ) return return_dict @@ -309,14 +316,20 @@ def export(self) -> MetatensorAtomisticModel: raise ValueError(f"unsupported dtype {self.dtype} for SoapBpnn") # Make sure the model is all in the same dtype - # For example, at this point, the composition model within the SOAP-BPNN is - # still float64 + # For example, after training, the additive models could still be in + # float64 self.to(dtype) + interaction_ranges = [self.hypers["soap"]["cutoff"]] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + capabilities = ModelCapabilities( outputs=self.outputs, atomic_types=self.atomic_types, - interaction_range=self.hypers["soap"]["cutoff"], + interaction_range=interaction_range, length_unit=self.dataset_info.length_unit, supported_devices=self.__supported_devices__, dtype=dtype_to_str(dtype), diff --git a/src/metatrain/experimental/soap_bpnn/schema-hypers.json 
b/src/metatrain/experimental/soap_bpnn/schema-hypers.json index 570931d49..b2ca893b0 100644 --- a/src/metatrain/experimental/soap_bpnn/schema-hypers.json +++ b/src/metatrain/experimental/soap_bpnn/schema-hypers.json @@ -80,6 +80,9 @@ } }, "additionalProperties": false + }, + "zbl": { + "type": "boolean" } }, "additionalProperties": false diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py index cc41a360c..63242161e 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py @@ -4,7 +4,10 @@ from metatrain.experimental.soap_bpnn import SoapBpnn from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import MODEL_HYPERS @@ -31,7 +34,8 @@ def test_to(device, dtype): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, exported.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(exported) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) system = system.to(device=device, dtype=dtype) evaluation_options = ModelEvaluationOptions( diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py index 47cc09174..b67d9dd84 100644 --- a/src/metatrain/experimental/soap_bpnn/trainer.py +++ b/src/metatrain/experimental/soap_bpnn/trainer.py @@ -7,7 +7,7 @@ import torch.distributed from torch.utils.data import DataLoader, DistributedSampler -from ...utils.composition import remove_composition +from ...utils.additive import remove_additive from ...utils.data import CombinedDataLoader, Dataset, TargetInfoDict, collate_fn from 
...utils.data.extract_targets import get_targets_dict from ...utils.distributed.distributed_data_parallel import DistributedDataParallel @@ -18,6 +18,10 @@ from ...utils.logging import MetricLogger from ...utils.loss import TensorMapDictLoss from ...utils.metrics import RMSEAccumulator +from ...utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from ...utils.per_atom import average_by_num_atoms from .model import SoapBpnn @@ -85,15 +89,28 @@ def train( else: logger.info(f"Training on device {device} with dtype {dtype}") + # Calculate the neighbor lists in advance (in particular, this + # needs to happen before the additive models are trained, as they + # might need them): + logger.info("Calculating neighbor lists for the datasets") + requested_neighbor_lists = get_requested_neighbor_lists(model) + for dataset in train_datasets + val_datasets: + for i in range(len(dataset)): + system = dataset[i]["system"] + # The following line attaches the neighbors lists to the system, + # and doesn't require to reassign the system to the dataset: + _ = get_system_with_neighbor_lists(system, requested_neighbor_lists) + # Move the model to the device and dtype: model.to(device=device, dtype=dtype) - # The composition model of the SOAP-BPNN is always on CPU (to avoid OOM + # The additive models of the SOAP-BPNN are always on CPU (to avoid OOM # errors during the linear algebra training) and in float64 (to avoid # numerical errors in the composition weights, which can be very large). 
- model.composition_model.to(device=torch.device("cpu"), dtype=torch.float64) + for additive_model in model.additive_models: + additive_model.to(device=torch.device("cpu"), dtype=torch.float64) logger.info("Calculating composition weights") - model.composition_model.train_model( + model.additive_models[0].train_model( # this is the composition model train_datasets, self.hypers["fixed_composition_weights"] ) @@ -230,7 +247,10 @@ def train( optimizer.zero_grad() systems, targets = batch - remove_composition(systems, targets, model.composition_model) + for additive_model in model.additive_models: + targets = remove_additive( + systems, targets, additive_model, train_targets + ) systems = [system.to(dtype=dtype, device=device) for system in systems] targets = { key: value.to(dtype=dtype, device=device) @@ -270,7 +290,10 @@ def train( val_loss = 0.0 for batch in val_dataloader: systems, targets = batch - remove_composition(systems, targets, model.composition_model) + for additive_model in model.additive_models: + targets = remove_additive( + systems, targets, additive_model, train_targets + ) systems = [system.to(dtype=dtype, device=device) for system in systems] targets = { key: value.to(dtype=dtype, device=device) diff --git a/src/metatrain/utils/additive/__init__.py b/src/metatrain/utils/additive/__init__.py new file mode 100644 index 000000000..bab1a5de3 --- /dev/null +++ b/src/metatrain/utils/additive/__init__.py @@ -0,0 +1,3 @@ +from .composition import CompositionModel # noqa: F401 +from .zbl import ZBL # noqa: F401 +from .remove import remove_additive # noqa: F401 diff --git a/src/metatrain/utils/composition.py b/src/metatrain/utils/additive/composition.py similarity index 88% rename from src/metatrain/utils/composition.py rename to src/metatrain/utils/additive/composition.py index 1f96e0549..7a79c660e 100644 --- a/src/metatrain/utils/composition.py +++ b/src/metatrain/utils/additive/composition.py @@ -6,8 +6,8 @@ from metatensor.torch import Labels, 
TensorBlock, TensorMap from metatensor.torch.atomistic import ModelOutput, System -from .data import Dataset, DatasetInfo, get_all_targets, get_atomic_types -from .jsonschema import validate +from ..data import Dataset, DatasetInfo, get_all_targets, get_atomic_types +from ..jsonschema import validate class CompositionModel(torch.nn.Module): @@ -42,6 +42,15 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo): self.dataset_info = dataset_info self.atomic_types = sorted(dataset_info.atomic_types) + self.outputs = { + key: ModelOutput( + quantity=value.quantity, + unit=value.unit, + per_atom=True, + ) + for key, value in dataset_info.targets.items() + } + n_types = len(self.atomic_types) n_targets = len(dataset_info.targets) @@ -81,14 +90,14 @@ def train_model( raise ValueError( "Provided `datasets` contains unknown " f"atomic types {additional_types}. " - f"Known types from initilaization are {self.atomic_types}." + f"Known types from initialization are {self.atomic_types}." ) missing_types = sorted(set(self.atomic_types) - set(get_atomic_types(datasets))) if missing_types: warnings.warn( f"Provided `datasets` do not contain atomic types {missing_types}. " - f"Known types from initilaization are {self.atomic_types}.", + f"Known types from initialization are {self.atomic_types}.", stacklevel=2, ) @@ -192,15 +201,14 @@ def forward( ) -> Dict[str, TensorMap]: """Compute the targets for each system based on the composition weights. - :param systems: List of systems to calculate the energy per atom. + :param systems: List of systems to calculate the energy. :param outputs: Dictionary containing the model outputs. :param selected_atoms: Optional selection of atoms for which to compute the - targets. - :returns: A dictionary with the computed targets for each system. + predictions. + :returns: A dictionary with the computed predictions for each system. :raises ValueError: If no weights have been computed or if `outputs` keys contain unsupported keys. 
- :raises NotImplementedError: If `selected_atoms` is provided (not implemented). """ dtype = systems[0].positions.dtype device = systems[0].positions.device @@ -263,28 +271,3 @@ def forward( ) return targets_out - - -def remove_composition( - systems: List[System], - targets: Dict[str, TensorMap], - composition_model: torch.nn.Module, -): - """Remove the composition contribution from the training targets. - - The targets are changed in place. - - :param systems: List of systems. - :param targets: Dictionary containing the targets corresponding to the systems. - :param composition_model: The composition model used to calculate the composition - contribution. - """ - output_options = {} - for target_key in targets: - output_options[target_key] = ModelOutput(per_atom=False) - - composition_targets = composition_model(systems, output_options) - for target_key in targets: - targets[target_key].block().values[:] -= ( - composition_targets[target_key].block().values - ) diff --git a/src/metatrain/utils/additive/remove.py b/src/metatrain/utils/additive/remove.py new file mode 100644 index 000000000..4235af1fc --- /dev/null +++ b/src/metatrain/utils/additive/remove.py @@ -0,0 +1,75 @@ +import warnings +from typing import Dict, List, Union + +import metatensor.torch +import torch +from metatensor.torch import TensorMap +from metatensor.torch.atomistic import System + +from ..data import TargetInfo, TargetInfoDict +from ..evaluate_model import evaluate_model + + +def remove_additive( + systems: List[System], + targets: Dict[str, TensorMap], + additive_model: torch.nn.Module, + target_info_dict: Union[Dict[str, TargetInfo], TargetInfoDict], +): + """Remove an additive contribution from the training targets. + + :param systems: List of systems. + :param targets: Dictionary containing the targets corresponding to the systems. + :param additive_model: The model used to calculate the additive + contribution to be removed. 
+ :param target_info_dict: Dictionary containing information about the targets.
+ """
+ warnings.filterwarnings(
+ "ignore",
+ category=RuntimeWarning,
+ message=(
+ "GRADIENT WARNING: element 0 of tensors does not "
+ "require grad and does not have a grad_fn"
+ ),
+ )
+ additive_contribution = evaluate_model(
+ additive_model,
+ systems,
+ TargetInfoDict(**{key: target_info_dict[key] for key in targets.keys()}),
+ is_training=False, # we don't need any gradients w.r.t. any parameters
+ )
+
+ for target_key in targets:
+ # make the samples the same so we can use metatensor.torch.subtract
+ # we also need to detach the values to avoid backpropagating through the
+ # subtraction
+ block = metatensor.torch.TensorBlock(
+ values=additive_contribution[target_key].block().values.detach(),
+ samples=targets[target_key].block().samples,
+ components=additive_contribution[target_key].block().components,
+ properties=additive_contribution[target_key].block().properties,
+ )
+ for gradient_name, gradient in (
+ additive_contribution[target_key].block().gradients()
+ ):
+ block.add_gradient(
+ gradient_name,
+ metatensor.torch.TensorBlock(
+ values=gradient.values.detach(),
+ samples=targets[target_key].block().gradient(gradient_name).samples,
+ components=gradient.components,
+ properties=gradient.properties,
+ ),
+ )
+ additive_contribution[target_key] = TensorMap(
+ keys=targets[target_key].keys,
+ blocks=[
+ block,
+ ],
+ )
+ # subtract the additive contribution from the target
+ targets[target_key] = metatensor.torch.subtract(
+ targets[target_key], additive_contribution[target_key]
+ )
+
+ return targets
diff --git a/src/metatrain/utils/additive/zbl.py b/src/metatrain/utils/additive/zbl.py
new file mode 100644
index 000000000..edf2ec7b2
--- /dev/null
+++ b/src/metatrain/utils/additive/zbl.py
@@ -0,0 +1,292 @@
+import warnings
+from typing import Dict, List, Optional
+
+import metatensor.torch
+import torch
+from ase.data import covalent_radii
+from metatensor.torch import
Labels, TensorBlock, TensorMap
+from metatensor.torch.atomistic import ModelOutput, NeighborListOptions, System
+
+from ..data import DatasetInfo
+
+
+class ZBL(torch.nn.Module):
+ """
+ A simple model for short-range repulsive interactions.
+
+ The implementation here is equivalent to its
+ `LAMMPS counterpart <https://docs.lammps.org/pair_zbl.html>`_, where we set the
+ inner cutoff to 0 and the outer cutoff to the sum of the covalent radii of the
+ two atoms as tabulated in ASE. Covalent radii that are not available in ASE are
+ set to 0.2 Å (and a warning is issued).
+
+ :param model_hypers: A dictionary of model hyperparameters. This model does not
+ read any hyperparameters: the inner and outer cutoffs of the ZBL potential
+ are fixed as described above.
+ :param dataset_info: An object containing information about the dataset, including
+ target quantities and atomic types.
+ """
+
+ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo):
+ super().__init__()
+
+ # Check capabilities
+ if dataset_info.length_unit != "angstrom":
+ raise ValueError(
+ "ZBL only supports angstrom units, but a "
+ f"{dataset_info.length_unit} unit was provided."
+ )
+ for target in dataset_info.targets.values():
+ if target.quantity != "energy":
+ raise ValueError(
+ "ZBL only supports energy-like outputs, but a "
+ f"{target.quantity} output was provided."
+ )
+ if target.unit != "eV":
+ raise ValueError(
+ "ZBL only supports eV units, but a "
+ f"{target.unit} output was provided."
+ ) + + self.dataset_info = dataset_info + self.atomic_types = sorted(dataset_info.atomic_types) + + self.outputs = { + key: ModelOutput( + quantity=value.quantity, + unit=value.unit, + per_atom=True, + ) + for key, value in dataset_info.targets.items() + } + + n_types = len(self.atomic_types) + + self.output_to_output_index = { + target: i for i, target in enumerate(sorted(dataset_info.targets.keys())) + } + + self.register_buffer( + "species_to_index", + torch.full((max(self.atomic_types) + 1,), -1, dtype=torch.int), + ) + for i, t in enumerate(self.atomic_types): + self.species_to_index[t] = i + + self.register_buffer( + "covalent_radii", torch.empty((n_types,), dtype=torch.float64) + ) + for i, t in enumerate(self.atomic_types): + ase_covalent_radius = covalent_radii[t] + if ase_covalent_radius == 0.2: + # 0.2 seems to be the default value when the covalent radius + # is not known/available + warnings.warn( + f"Covalent radius for element {t} is not available in ASE. " + "Using a default value of 0.2 Å.", + stacklevel=2, + ) + self.covalent_radii[i] = ase_covalent_radius + + largest_covalent_radius = float(torch.max(self.covalent_radii)) + self.cutoff_radius = 2.0 * largest_covalent_radius + + def restart(self, dataset_info: DatasetInfo) -> "ZBL": + """Restart the model with a new dataset info. + + :param dataset_info: New dataset information to be used. + """ + return self({}, self.dataset_info.union(dataset_info)) + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """Compute the energies of a system solely based on a ZBL repulsive + potential. + + :param systems: List of systems to calculate the ZBL energy. + :param outputs: Dictionary containing the model outputs. + :param selected_atoms: Optional selection of atoms for which to compute the + predictions. + :returns: A dictionary with the computed predictions for each system. 
+ + :raises ValueError: If the `outputs` contain unsupported keys. + """ + + # Assert only one neighbor list for all systems + neighbor_lists: List[TensorBlock] = [] + for system in systems: + nl_options = self.requested_neighbor_lists()[0] + nl = system.get_neighbor_list(nl_options) + neighbor_lists.append(nl) + + # Find the elements of all i and j atoms + zi = torch.concatenate( + [ + system.types[nl.samples.column("first_atom")] + for nl, system in zip(neighbor_lists, systems) + ] + ) + zj = torch.concatenate( + [ + system.types[nl.samples.column("second_atom")] + for nl, system in zip(neighbor_lists, systems) + ] + ) + + # Find the interatomic distances + rij = torch.concatenate( + [torch.sqrt(torch.sum(nl.values**2, dim=(1, 2))) for nl in neighbor_lists] + ) + + # Find the ZBL energies + e_zbl = self.get_pairwise_zbl(zi, zj, rij) + + # Sum over edges to get node energies + indices_for_sum_list = [] + sum = 0 + for system, nl in zip(systems, neighbor_lists): + indices_for_sum_list.append(nl.samples.column("first_atom") + sum) + sum += system.positions.shape[0] + + e_zbl_nodes = torch.zeros(sum, dtype=e_zbl.dtype, device=e_zbl.device) + e_zbl_nodes.index_add_(0, torch.cat(indices_for_sum_list), e_zbl) + + device = systems[0].positions.device + + # Set the outputs as the ZBL energies + targets_out: Dict[str, TensorMap] = {} + for target_key, target in outputs.items(): + if target_key.startswith("mtt::aux::"): + continue + sample_values: List[List[int]] = [] + + for i_system, system in enumerate(systems): + sample_values += [[i_system, i_atom] for i_atom in range(len(system))] + + block = TensorBlock( + values=e_zbl_nodes.reshape(-1, 1), + samples=Labels( + ["system", "atom"], torch.tensor(sample_values, device=device) + ), + components=[], + properties=Labels( + names=["energy"], values=torch.tensor([[0]], device=device) + ), + ) + + targets_out[target_key] = TensorMap( + keys=Labels(names=["_"], values=torch.tensor([[0]], device=device)), + blocks=[block], + ) + 
+ # apply selected_atoms to the composition if needed + if selected_atoms is not None: + targets_out[target_key] = metatensor.torch.slice( + targets_out[target_key], "samples", selected_atoms + ) + + if not target.per_atom: + targets_out[target_key] = metatensor.torch.sum_over_samples( + targets_out[target_key], sample_names="atom" + ) + + return targets_out + + def get_pairwise_zbl(self, zi, zj, rij): + """ + Ziegler-Biersack-Littmark (ZBL) potential. + + Inputs are the atomic numbers (zi, zj) of the two atoms of interest + and their distance rij. + """ + # set cutoff from covalent radii of the elements + rc = ( + self.covalent_radii[self.species_to_index[zi]] + + self.covalent_radii[self.species_to_index[zj]] + ) + + r1 = 0.0 + p = 0.23 + # angstrom + a0 = 0.46850 + c = torch.tensor( + [0.02817, 0.28022, 0.50986, 0.18175], dtype=rij.dtype, device=rij.device + ) + d = torch.tensor( + [0.20162, 0.40290, 0.94229, 3.19980], dtype=rij.dtype, device=rij.device + ) + + a = a0 / (zi**p + zj**p) + + da = d.unsqueeze(-1) / a + + # e * e / (4 * pi * epsilon_0) / electron_volt / angstrom + factor = 14.399645478425668 * zi * zj + e = _e_zbl(factor, rij, c, da) # eV.angstrom + + # switching function + ec = _e_zbl(factor, rc, c, da) + dec = _dedr(factor, rc, c, da) + d2ec = _d2edr2(factor, rc, c, da) + + # coefficients are determined such that E(rc) = 0, E'(rc) = 0, and E''(rc) = 0 + A = (-3 * dec + (rc - r1) * d2ec) / ((rc - r1) ** 2) + B = (2 * dec - (rc - r1) * d2ec) / ((rc - r1) ** 3) + C = -ec + (rc - r1) * dec / 2 - (rc - r1) * (rc - r1) * d2ec / 12 + + e += A / 3 * ((rij - r1) ** 3) + B / 4 * ((rij - r1) ** 4) + C + e = e / 2.0 # divide by 2 to fix double counting of edges + + # set all contributions past the cutoff to zero + e[rij > rc] = 0.0 + + return e + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [ + NeighborListOptions( + cutoff=self.cutoff_radius, + full_list=True, + ) + ] + + +def _phi(r, c, da): + phi = torch.sum(c.unsqueeze(-1) * 
torch.exp(-r * da), dim=0) + return phi + + +def _dphi(r, c, da): + dphi = torch.sum(-c.unsqueeze(-1) * da * torch.exp(-r * da), dim=0) + return dphi + + +def _d2phi(r, c, da): + d2phi = torch.sum(c.unsqueeze(-1) * (da**2) * torch.exp(-r * da), dim=0) + return d2phi + + +def _e_zbl(factor, r, c, da): + phi = _phi(r, c, da) + ret = factor / r * phi + return ret + + +def _dedr(factor, r, c, da): + phi = _phi(r, c, da) + dphi = _dphi(r, c, da) + ret = factor / r * (-phi / r + dphi) + return ret + + +def _d2edr2(factor, r, c, da): + phi = _phi(r, c, da) + dphi = _dphi(r, c, da) + d2phi = _d2phi(r, c, da) + + ret = factor / r * (d2phi - 2 / r * dphi + 2 * phi / (r**2)) + return ret diff --git a/src/metatrain/utils/neighbor_lists.py b/src/metatrain/utils/neighbor_lists.py index b76d836f6..91f9d9641 100644 --- a/src/metatrain/utils/neighbor_lists.py +++ b/src/metatrain/utils/neighbor_lists.py @@ -15,6 +15,55 @@ from .data.system_to_ase import system_to_ase +def get_requested_neighbor_lists( + module: torch.nn.Module, +) -> List[NeighborListOptions]: + """Get the neighbor lists requested by a module and its children. + + :param module: The module for which to get the requested neighbor lists. + + :return: A list of `NeighborListOptions` objects requested by the module. 
+ """ + requested: List[NeighborListOptions] = [] + _get_requested_neighbor_lists_in_place( + module=module, + module_name="", + requested=requested, + ) + return requested + + +def _get_requested_neighbor_lists_in_place( + module: torch.nn.Module, + module_name: str, + requested: List[NeighborListOptions], +): + # copied from + # metatensor/python/metatensor-torch/metatensor/torch/atomistic/model.py + # and just removed the length units + + if hasattr(module, "requested_neighbor_lists"): + for new_options in module.requested_neighbor_lists(): + new_options.add_requestor(module_name) + + already_requested = False + for existing in requested: + if existing == new_options: + already_requested = True + for requestor in new_options.requestors(): + existing.add_requestor(requestor) + + if not already_requested: + requested.append(new_options) + + for child_name, child in module.named_children(): + _get_requested_neighbor_lists_in_place( + module=child, + module_name=module_name + "." + child_name, + requested=requested, + ) + + def get_system_with_neighbor_lists( system: System, neighbor_lists: List[NeighborListOptions] ) -> System: diff --git a/src/metatrain/utils/output_gradient.py b/src/metatrain/utils/output_gradient.py index dda6888d2..d7cc3664e 100644 --- a/src/metatrain/utils/output_gradient.py +++ b/src/metatrain/utils/output_gradient.py @@ -1,3 +1,4 @@ +import warnings from typing import List, Optional import torch @@ -14,13 +15,29 @@ def compute_gradient( """ grad_outputs: Optional[List[Optional[torch.Tensor]]] = [torch.ones_like(target)] - gradient = torch.autograd.grad( - outputs=[target], - inputs=inputs, - grad_outputs=grad_outputs, - retain_graph=is_training, - create_graph=is_training, - ) + try: + gradient = torch.autograd.grad( + outputs=[target], + inputs=inputs, + grad_outputs=grad_outputs, + retain_graph=is_training, + create_graph=is_training, + ) + except RuntimeError as e: + # Torch raises an error if the target tensor does not require grad, + # 
but this could just mean that the target is a constant tensor, like in + # the case of composition models. In this case, we can safely ignore the error + # and we raise a warning instead. The warning can be caught and silenced in the + # appropriate places. + if ( + "element 0 of tensors does not require grad and does not have a grad_fn" + in str(e) + ): + warnings.warn(f"GRADIENT WARNING: {e}", RuntimeWarning, stacklevel=2) + gradient = [torch.zeros_like(i) for i in inputs] + else: + # Re-raise the error if it's not the one above + raise if gradient is None: raise ValueError( "Unexpected None value for computed gradient. " diff --git a/tests/utils/test_composition.py b/tests/utils/test_additive.py similarity index 82% rename from tests/utils/test_composition.py rename to tests/utils/test_additive.py index 780744664..fd2179e5e 100644 --- a/tests/utils/test_composition.py +++ b/tests/utils/test_additive.py @@ -7,9 +7,13 @@ from metatensor.torch.atomistic import ModelOutput, System from omegaconf import OmegaConf -from metatrain.utils.composition import CompositionModel, remove_composition +from metatrain.utils.additive import ZBL, CompositionModel, remove_additive from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo, TargetInfoDict from metatrain.utils.data.readers import read_systems, read_targets +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) RESOURCES_PATH = Path(__file__).parents[1] / "resources" @@ -224,8 +228,8 @@ def test_composition_model_torchscript(tmpdir): ) -def test_remove_composition(): - """Tests the remove_composition function.""" +def test_remove_additive(): + """Tests the remove_additive function.""" dataset_path = RESOURCES_PATH / "qm9_reduced_100.xyz" systems = read_systems(dataset_path) @@ -260,7 +264,7 @@ def test_remove_composition(): targets["mtt::U0"] = metatensor.torch.join(targets["mtt::U0"], axis="samples") std_before = 
targets["mtt::U0"].block().values.std().item() - remove_composition(systems, targets, composition_model) + remove_additive(systems, targets, composition_model, target_info) std_after = targets["mtt::U0"].block().values.std().item() # In QM9 the composition contribution is very large: the standard deviation @@ -393,3 +397,81 @@ def test_composition_model_wrong_target(): ), ), ) + + +def test_zbl(): + """Test the ZBL model.""" + + dataset_path = RESOURCES_PATH / "qm9_reduced_100.xyz" + + systems = read_systems(dataset_path)[:5] + + conf = { + "mtt::U0": { + "quantity": "energy", + "read_from": dataset_path, + "file_format": ".xyz", + "reader": "ase", + "key": "U0", + "unit": "eV", + "forces": False, + "stress": False, + "virial": False, + } + } + _, target_info = read_targets(OmegaConf.create(conf)) + + zbl = ZBL( + model_hypers={}, + dataset_info=DatasetInfo( + length_unit="angstrom", + atomic_types=[1, 6, 7, 8], + targets=target_info, + ), + ) + + requested_neighbor_lists = get_requested_neighbor_lists(zbl) + for system in systems: + get_system_with_neighbor_lists(system, requested_neighbor_lists) + + # per_atom = True + output = zbl( + systems, + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=True)}, + ) + assert "mtt::U0" in output + assert output["mtt::U0"].block().samples.names == ["system", "atom"] + assert output["mtt::U0"].block().values.shape != (5, 1) + + # with selected_atoms + selected_atoms = metatensor.torch.Labels( + names=["system", "atom"], + values=torch.tensor([[0, 0]]), + ) + + output = zbl( + systems, + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=True)}, + selected_atoms=selected_atoms, + ) + assert "mtt::U0" in output + assert output["mtt::U0"].block().samples.names == ["system", "atom"] + assert output["mtt::U0"].block().values.shape == (1, 1) + + # per_atom = False + output = zbl( + systems, + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=False)}, + ) + assert "mtt::U0" in output + assert 
output["mtt::U0"].block().samples.names == ["system"] + assert output["mtt::U0"].block().values.shape == (5, 1) + + # check that the result is the same without batching + expected = output["mtt::U0"].block().values[3] + system = systems[3] + output = zbl( + [system], + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=False)}, + ) + assert torch.allclose(output["mtt::U0"].block().values[0], expected) diff --git a/tests/utils/test_evaluate_model.py b/tests/utils/test_evaluate_model.py index 72826cd7a..e2bd81eca 100644 --- a/tests/utils/test_evaluate_model.py +++ b/tests/utils/test_evaluate_model.py @@ -6,7 +6,10 @@ from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems from metatrain.utils.evaluate_model import evaluate_model from metatrain.utils.export import export -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import MODEL_HYPERS, RESOURCES_PATH @@ -45,8 +48,9 @@ def test_evaluate_model(training, exported): ) model = export(model, capabilities) + requested_neighbor_lists = get_requested_neighbor_lists(model) systems = [ - get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in systems ] diff --git a/tests/utils/test_llpr.py b/tests/utils/test_llpr.py index f1887ef5b..189e7c2ac 100644 --- a/tests/utils/test_llpr.py +++ b/tests/utils/test_llpr.py @@ -9,7 +9,10 @@ from metatrain.utils.data import Dataset, collate_fn, read_systems, read_targets from metatrain.utils.llpr import LLPRUncertaintyModel -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . 
import RESOURCES_PATH @@ -37,7 +40,7 @@ def test_llpr(tmpdir): }, } targets, _ = read_targets(target_config) - requested_neighbor_lists = model.requested_neighbor_lists() + requested_neighbor_lists = get_requested_neighbor_lists(model) qm9_systems = [ get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in qm9_systems diff --git a/tox.ini b/tox.ini index a38fe4949..a8ef4bdd3 100644 --- a/tox.ini +++ b/tox.ini @@ -143,6 +143,7 @@ commands_pre = bash -c "set -e && cd {toxinidir}/examples/basic_usage && bash usage.sh" bash -c "set -e && cd {toxinidir}/examples/ase && bash train.sh" bash -c "set -e && cd {toxinidir}/examples/programmatic/llpr && bash train.sh" + bash -c "set -e && cd {toxinidir}/examples/zbl && bash train.sh" sphinx-build \ {posargs:-E} \ --builder html \