Skip to content

Commit

Permalink
Add a general torch CompositionModel (#280)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: frostedoyster <[email protected]>
Co-authored-by: Guillaume Fraux <[email protected]>
  • Loading branch information
3 people authored Sep 7, 2024
1 parent bf2c742 commit e6d3927
Show file tree
Hide file tree
Showing 12 changed files with 774 additions and 230 deletions.
5 changes: 3 additions & 2 deletions docs/src/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@

import tomli # Replace by tomllib from std library once docs are build with Python 3.11

import metatrain


# When importing metatensor-torch, this will change the definition of the classes
# to include the documentation
os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1"
os.environ["RASCALINE_IMPORT_FOR_SPHINX"] = "1"

import metatrain # noqa: E402


ROOT = os.path.abspath(os.path.join("..", ".."))

# We use a second (pseudo) sphinx project located in `docs/generate_examples` to run the
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/alchemical_model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
from metatensor.learn.data import DataLoader

from ...utils.composition import calculate_composition_weights
from ...utils.data import (
CombinedDataLoader,
Dataset,
Expand All @@ -23,6 +22,7 @@
from ...utils.neighbor_lists import get_system_with_neighbor_lists
from ...utils.per_atom import average_by_num_atoms
from . import AlchemicalModel
from .utils.composition import calculate_composition_weights
from .utils.normalize import (
get_average_number_of_atoms,
get_average_number_of_neighbors,
Expand Down
69 changes: 69 additions & 0 deletions src/metatrain/experimental/alchemical_model/utils/composition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import List, Tuple, Union

import torch

from ....utils.data.dataset import Dataset, get_atomic_types


def calculate_composition_weights(
datasets: Union[Dataset, List[Dataset]], property: str
) -> Tuple[torch.Tensor, List[int]]:
"""Calculate the composition weights for a dataset.
It assumes per-system properties.
:param dataset: Dataset to calculate the composition weights for.
:returns: Composition weights for the dataset, as well as the
list of species that the weights correspond to.
"""
if not isinstance(datasets, list):
datasets = [datasets]

# Note: `atomic_types` are sorted, and the composition weights are sorted as
# well, because the species are sorted in the composition features.
atomic_types = sorted(get_atomic_types(datasets))

targets = torch.stack(
[sample[property].block().values for dataset in datasets for sample in dataset]
)
targets = targets.squeeze(dim=(1, 2)) # remove component and property dimensions

total_num_structures = sum([len(dataset) for dataset in datasets])
dtype = datasets[0][0]["system"].positions.dtype
composition_features = torch.empty(
(total_num_structures, len(atomic_types)), dtype=dtype
)
structure_index = 0
for dataset in datasets:
for sample in dataset:
structure = sample["system"]
for j, s in enumerate(atomic_types):
composition_features[structure_index, j] = torch.sum(
structure.types == s
)
structure_index += 1

regularizer = 1e-20
while regularizer:
if regularizer > 1e5:
raise RuntimeError(
"Failed to solve the linear system to calculate the "
"composition weights. The dataset is probably too small "
"or ill-conditioned."
)
try:
solution = torch.linalg.solve(
composition_features.T @ composition_features
+ regularizer
* torch.eye(
composition_features.shape[1],
dtype=composition_features.dtype,
device=composition_features.device,
),
composition_features.T @ targets,
)
break
except torch._C._LinAlgError:
regularizer *= 10.0

return solution, atomic_types
22 changes: 20 additions & 2 deletions src/metatrain/experimental/gap/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from metatrain.utils.data.dataset import DatasetInfo

from ...utils.composition import CompositionModel
from ...utils.export import export


Expand Down Expand Up @@ -127,6 +128,11 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
)
self._species_labels: TorchLabels = TorchLabels.empty("_")

self.composition_model = CompositionModel(
model_hypers={},
dataset_info=dataset_info,
)

def restart(self, dataset_info: DatasetInfo) -> "GAP":
raise ValueError("GAP does not allow restarting training")

Expand Down Expand Up @@ -201,8 +207,20 @@ def forward(
soap_features = TorchTensorMap(self._keys, soap_features.blocks())
output_key = list(outputs.keys())[0]
energies = self._subset_of_regressors_torch(soap_features)
out_tensor = self.apply_composition_weights(systems, energies)
return {output_key: out_tensor}
return_dict = {output_key: energies}

# apply composition model
composition_energies = self.composition_model(
systems, {output_key: ModelOutput("energy", per_atom=True)}, selected_atoms
)
composition_energies[output_key] = metatensor.torch.sum_over_samples(
composition_energies[output_key], "atom"
)
return_dict[output_key] = metatensor.torch.add(
return_dict[output_key], composition_energies[output_key]
)

return return_dict

def export(self) -> MetatensorAtomisticModel:
capabilities = ModelCapabilities(
Expand Down
31 changes: 6 additions & 25 deletions src/metatrain/experimental/gap/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from metatrain.utils.data import Dataset

from ...utils.composition import calculate_composition_weights
from ...utils.composition import remove_composition
from ...utils.data import check_datasets
from . import GAP
from .model import torch_tensor_map_to_core
Expand Down Expand Up @@ -52,10 +52,7 @@ def train(

# Calculate and set the composition weights:
logger.info("Calculating composition weights")
composition_weights, species = calculate_composition_weights(
train_datasets, target_name
)
model.set_composition_weights(target_name, composition_weights, species)
model.composition_model.train_model(train_datasets)

logger.info("Setting up data loaders")
if len(train_datasets[0][0][output_name].keys) > 1:
Expand All @@ -72,26 +69,10 @@ def train(
model._keys = train_y.keys
train_structures = [sample["system"] for sample in train_dataset]

logger.info("Fitting composition energies")
composition_energies = torch.zeros(len(train_y.block().values), dtype=dtype)
for i, structure in enumerate(train_structures):
for j, s in enumerate(species):
composition_energies[i] += (
torch.sum(structure.types == s) * composition_weights[j]
)
train_y_values = train_y.block().values
train_y_values = train_y_values - composition_energies.reshape(-1, 1)
train_block = metatensor.torch.TensorBlock(
values=train_y_values,
samples=train_y.block().samples,
components=train_y.block().components,
properties=train_y.block().properties,
)
if len(train_y[0].gradients_list()) > 0:
train_block.add_gradient("positions", train_y[0].gradient("positions"))
train_y = metatensor.torch.TensorMap(
train_y.keys,
[train_block],
logger.info("Subtracting composition energies")
# this acts in-place on train_y
remove_composition(
train_structures, {target_name: train_y}, model.composition_model
)

logger.info("Calculating SOAP features")
Expand Down
55 changes: 25 additions & 30 deletions src/metatrain/experimental/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from metatrain.utils.data.dataset import DatasetInfo

from ...utils.composition import apply_composition_contribution
from ...utils.composition import CompositionModel
from ...utils.dtype import dtype_to_str
from ...utils.export import export

Expand Down Expand Up @@ -123,14 +123,6 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
unit="unitless", per_atom=True
)

# creates a composition weight tensor that can be directly indexed by species,
# this can be left as a tensor of zero or set from the outside using
# set_composition_weights (recommended for better accuracy)
n_outputs = len(self.outputs)
self.register_buffer(
"composition_weights",
torch.zeros((n_outputs, max(self.atomic_types) + 1)),
)
# buffers cannot be indexed by strings (torchscript), so we create a single
# tensor for all output. Due to this, we need to slice the tensor when we use
# it and use the output name to select the correct slice via a dictionary
Expand Down Expand Up @@ -195,6 +187,11 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
}
)

self.composition_model = CompositionModel(
model_hypers={},
dataset_info=dataset_info,
)

def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn":
# merge old and new dataset info
merged_info = self.dataset_info.union(dataset_info)
Expand Down Expand Up @@ -261,12 +258,7 @@ def forward(
atomic_energies: Dict[str, TensorMap] = {}
for output_name, output_layer in self.last_layers.items():
if output_name in outputs:
atomic_energies[output_name] = apply_composition_contribution(
output_layer(last_layer_features),
self.composition_weights[ # type: ignore
self.output_to_index[output_name]
],
)
atomic_energies[output_name] = output_layer(last_layer_features)

# Sum the atomic energies coming from the BPNN to get the total energy
for output_name, atomic_energy in atomic_energies.items():
Expand All @@ -281,6 +273,19 @@ def forward(
atomic_energy, ["atom", "center_type"]
)

if not self.training:
# at evaluation, we also add the composition contributions
composition_contributions = self.composition_model(
systems, outputs, selected_atoms
)
for name in return_dict:
if name.startswith("mtt::aux::"):
continue # skip auxiliary outputs (not targets)
return_dict[name] = metatensor.torch.add(
return_dict[name],
composition_contributions[name],
)

return return_dict

@classmethod
Expand All @@ -303,6 +308,11 @@ def export(self) -> MetatensorAtomisticModel:
if dtype not in self.__supported_dtypes__:
raise ValueError(f"unsupported dtype {self.dtype} for SoapBpnn")

# Make sure the model is all in the same dtype
# For example, at this point, the composition model within the SOAP-BPNN is
# still float64
self.to(dtype)

capabilities = ModelCapabilities(
outputs=self.outputs,
atomic_types=self.atomic_types,
Expand All @@ -314,21 +324,6 @@ def export(self) -> MetatensorAtomisticModel:

return export(model=self, model_capabilities=capabilities)

def set_composition_weights(
self,
output_name: str,
input_composition_weights: torch.Tensor,
atomic_types: List[int],
) -> None:
"""Set the composition weights for a given output."""
# all species that are not present retain their weight of zero
self.composition_weights[self.output_to_index[output_name]][ # type: ignore
atomic_types
] = input_composition_weights.to(
dtype=self.composition_weights.dtype, # type: ignore
device=self.composition_weights.device, # type: ignore
)

def add_output(self, output_name: str) -> None:
"""Add a new output to the self."""
# add a new row to the composition weights tensor
Expand Down
6 changes: 6 additions & 0 deletions src/metatrain/experimental/soap_bpnn/tests/test_continue.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def test_continue(monkeypatch, tmp_path):
}
}
targets, _ = read_targets(OmegaConf.create(conf))

# systems in float64 are required for training
systems = [system.to(torch.float64) for system in systems]
dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

hypers = DEFAULT_HYPERS.copy()
Expand All @@ -63,6 +66,9 @@ def test_continue(monkeypatch, tmp_path):
checkpoint_dir=".",
)

# evaluation
systems = [system.to(torch.float32) for system in systems]

# Predict on the first five systems
output_before = model_before(
systems[:5], {"mtt::U0": model_before.outputs["mtt::U0"]}
Expand Down
26 changes: 16 additions & 10 deletions src/metatrain/experimental/soap_bpnn/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,18 @@ def test_regression_init():
)

expected_output = torch.tensor(
[[-0.03860], [0.11137], [0.09112], [-0.05634], [-0.02549]]
[
[-0.038599025458],
[0.111374437809],
[0.091115802526],
[-0.056339077652],
[-0.025491207838],
]
)

# if you need to change the hardcoded values:
# torch.set_printoptions(precision=12)
# print(output["mtt::U0"].block().values)
torch.set_printoptions(precision=12)
print(output["mtt::U0"].block().values)

torch.testing.assert_close(
output["mtt::U0"].block().values, expected_output, rtol=1e-5, atol=1e-5
Expand Down Expand Up @@ -100,17 +106,17 @@ def test_regression_train():

expected_output = torch.tensor(
[
[-40.592571258545],
[-56.522350311279],
[-76.571365356445],
[-77.384849548340],
[-93.445365905762],
[-0.106249026954],
[0.039981484413],
[-0.142682999372],
[-0.031701669097],
[-0.016210660338],
]
)

# if you need to change the hardcoded values:
# torch.set_printoptions(precision=12)
# print(output["mtt::U0"].block().values)
torch.set_printoptions(precision=12)
print(output["mtt::U0"].block().values)

torch.testing.assert_close(
output["mtt::U0"].block().values, expected_output, rtol=1e-5, atol=1e-5
Expand Down
Loading

0 comments on commit e6d3927

Please sign in to comment.