diff --git a/examples/programmatic/llpr_forces/ethanol_reduced_100.xyz b/examples/programmatic/llpr_forces/ethanol_reduced_100.xyz new file mode 120000 index 000000000..d9345ce62 --- /dev/null +++ b/examples/programmatic/llpr_forces/ethanol_reduced_100.xyz @@ -0,0 +1 @@ +../../ase/ethanol_reduced_100.xyz \ No newline at end of file diff --git a/examples/programmatic/llpr_forces/force_llpr.py b/examples/programmatic/llpr_forces/force_llpr.py new file mode 100644 index 000000000..155e33527 --- /dev/null +++ b/examples/programmatic/llpr_forces/force_llpr.py @@ -0,0 +1,229 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch +from metatensor.torch.atomistic import ( + MetatensorAtomisticModel, + ModelEvaluationOptions, + ModelMetadata, + ModelOutput, + load_atomistic_model, +) + +from metatrain.utils.data import Dataset, collate_fn, read_systems, read_targets +from metatrain.utils.llpr import LLPRUncertaintyModel +from metatrain.utils.loss import TensorMapDictLoss +from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists + + +model = load_atomistic_model("model.pt", extensions_directory="extensions/") +model = model.to("cuda") + +train_systems = read_systems("train.xyz") +train_target_config = { + "energy": { + "quantity": "energy", + "read_from": "train.xyz", + "file_format": ".xyz", + "reader": "ase", + "key": "energy", + "unit": "kcal/mol", + "forces": { + "read_from": "train.xyz", + "file_format": ".xyz", + "reader": "ase", + "key": "forces", + }, + "stress": { + "read_from": "train.xyz", + "file_format": ".xyz", + "reader": "ase", + "key": "stress", + }, + "virial": False, + }, +} +train_targets, _ = read_targets(train_target_config) + +valid_systems = read_systems("valid.xyz") +valid_target_config = { + "energy": { + "quantity": "energy", + "read_from": "valid.xyz", + "file_format": ".xyz", + "reader": "ase", + "key": "energy", + "unit": "kcal/mol", + "forces": { + "read_from": "valid.xyz", + "file_format": ".xyz", + "reader": "ase", + "key": "forces", + }, + "stress": { + "read_from": "valid.xyz", + "file_format": ".xyz", + "reader": "ase", + "key": "stress", + }, + "virial": False, + }, +} +valid_targets, _ = read_targets(valid_target_config) + +test_systems = read_systems("test.xyz") +test_target_config = { + "energy": { + "quantity": "energy", + "read_from": "test.xyz", + "file_format": ".xyz", + "reader": "ase", + "key": "energy", + "unit": "kcal/mol", + "forces": { + "read_from": "test.xyz", + "file_format": ".xyz", + "reader": "ase", + "key": "forces", + }, + "stress": { + "read_from": "test.xyz", + "file_format": ".xyz", + "reader": "ase", + "key": "stress", + }, + "virial": False, + }, +} +test_targets, target_info = read_targets(test_target_config) + +requested_neighbor_lists = model.requested_neighbor_lists() +train_systems = [ + get_system_with_neighbor_lists(system, requested_neighbor_lists) + for system in train_systems +] +train_dataset = Dataset({"system": train_systems, **train_targets}) +valid_systems = [ + get_system_with_neighbor_lists(system, requested_neighbor_lists) + for system in valid_systems +] +valid_dataset = Dataset({"system": valid_systems, **valid_targets}) +test_systems = [ + get_system_with_neighbor_lists(system, requested_neighbor_lists) + for system in test_systems +] +test_dataset = Dataset({"system": test_systems, **test_targets}) + +train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=4, + shuffle=False, + collate_fn=collate_fn, +) +valid_dataloader = torch.utils.data.DataLoader( + 
valid_dataset, + batch_size=4, + shuffle=False, + collate_fn=collate_fn, +) +test_dataloader = torch.utils.data.DataLoader( + test_dataset, + batch_size=4, + shuffle=False, + collate_fn=collate_fn, +) + +loss_weight_dict = { + "energy": 1.0, + "energy_positions_grad": 1.0, + "energy_strain_grad": 1.0, +} +loss_fn = TensorMapDictLoss(loss_weight_dict) + +llpr_model = LLPRUncertaintyModel(model) + +print("Last layer parameters:") +parameters = [] +for name, param in llpr_model.named_parameters(): + if "last_layers" in name: + parameters.append(param) + print(name) + +llpr_model.compute_covariance_as_pseudo_hessian( + train_dataloader, target_info, loss_fn, parameters +) +llpr_model.compute_inverse_covariance() +llpr_model.calibrate(valid_dataloader) + +exported_model = MetatensorAtomisticModel( + llpr_model.eval(), + ModelMetadata(), + llpr_model.capabilities, +) + +evaluation_options = ModelEvaluationOptions( + length_unit="angstrom", + outputs={ + "mtt::aux::last_layer_features": ModelOutput(per_atom=False), + "mtt::aux::energy_uncertainty": ModelOutput(per_atom=False), + "energy": ModelOutput(per_atom=False), + }, + selected_atoms=None, +) + +force_errors = [] +force_uncertainties = [] + +for batch in test_dataloader: + systems, targets = batch + systems = [system.to("cuda", torch.float64) for system in systems] + for system in systems: + system.positions.requires_grad = True + targets = {name: tmap.to("cuda", torch.float64) for name, tmap in targets.items()} + + outputs = exported_model(systems, evaluation_options, check_consistency=True) + energy = outputs["energy"].block().values + energy_sum = torch.sum(energy) + energy_sum.backward(retain_graph=True) + + predicted_forces = -torch.concatenate( + [system.positions.grad.flatten() for system in systems] + ) + true_forces = -targets["energy"].block().gradient("positions").values.flatten() + + force_error = (predicted_forces - true_forces) ** 2 + force_errors.append(force_error.detach().clone().cpu().numpy()) + + last_layer_features = outputs["mtt::aux::last_layer_features"].block().values + last_layer_features = torch.sum(last_layer_features, dim=0) + ll_feature_grads = [] + for ll_feature in last_layer_features.reshape((-1,)): + ll_feature_grad = torch.autograd.grad( + ll_feature.reshape(()), + [system.positions for system in systems], + retain_graph=True, + ) + ll_feature_grad = torch.concatenate( + [ll_feature_g.flatten() for ll_feature_g in ll_feature_grad] + ) + ll_feature_grads.append(ll_feature_grad) + ll_feature_grads = torch.stack(ll_feature_grads, dim=1) + + force_uncertainty = torch.einsum( + "if, fg, ig -> i", + ll_feature_grads, + exported_model._module.inv_covariance, + ll_feature_grads, + ) + force_uncertainties.append(force_uncertainty.detach().clone().cpu().numpy()) + +force_errors = np.concatenate(force_errors) +force_uncertainties = np.concatenate(force_uncertainties) + + +plt.scatter(force_uncertainties, force_errors, s=1) +plt.xscale("log") +plt.yscale("log") +plt.xlabel("Predicted variance") +plt.ylabel("Squared error") + +plt.savefig("figure.pdf") diff --git a/examples/programmatic/llpr_forces/options.yaml b/examples/programmatic/llpr_forces/options.yaml new file mode 100644 index 000000000..492cfaa0b --- /dev/null +++ b/examples/programmatic/llpr_forces/options.yaml @@ -0,0 +1,35 @@ +seed: 42 + +architecture: + name: experimental.soap_bpnn + training: + batch_size: 8 + num_epochs: 100 + log_interval: 1 + +training_set: + systems: + read_from: train.xyz + length_unit: angstrom + targets: + energy: + key: energy +
unit: kcal/mol + +validation_set: + systems: + read_from: valid.xyz + length_unit: angstrom + targets: + energy: + key: energy + unit: kcal/mol + +test_set: + systems: + read_from: test.xyz + length_unit: angstrom + targets: + energy: + key: energy + unit: kcal/mol diff --git a/examples/programmatic/llpr_forces/readme.txt b/examples/programmatic/llpr_forces/readme.txt new file mode 100644 index 000000000..3dce78c15 --- /dev/null +++ b/examples/programmatic/llpr_forces/readme.txt @@ -0,0 +1,4 @@ +This is a small example of how to calculate force uncertainties with the LLPR. +To run it, first split the ethanol dataset with `python split.py`, then train a +model with `mtt train options.yaml`, and finally run the example with +`python force_llpr.py`. diff --git a/examples/programmatic/llpr_forces/split.py b/examples/programmatic/llpr_forces/split.py new file mode 100644 index 000000000..4c5902b62 --- /dev/null +++ b/examples/programmatic/llpr_forces/split.py @@ -0,0 +1,13 @@ +import ase.io +import numpy as np + + +structures = ase.io.read("ethanol_reduced_100.xyz", ":") +np.random.shuffle(structures) +train = structures[:50] +valid = structures[50:60] +test = structures[60:] + +ase.io.write("train.xyz", train) +ase.io.write("valid.xyz", valid) +ase.io.write("test.xyz", test) diff --git a/src/metatrain/utils/llpr.py b/src/metatrain/utils/llpr.py index 7f23fda9b..164dd4267 100644 --- a/src/metatrain/utils/llpr.py +++ b/src/metatrain/utils/llpr.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Callable, Dict, List, Optional import metatensor.torch import numpy as np @@ -12,6 +12,11 @@ ) from torch.utils.data import DataLoader +from .data import DatasetInfo, TargetInfoDict, get_atomic_types +from .data.extract_targets import get_targets_dict +from .evaluate_model import evaluate_model +from .per_atom import average_by_num_atoms + class LLPRUncertaintyModel(torch.nn.Module): """A wrapper that adds LLPR uncertainties to a model. @@ -229,12 +234,13 @@ def compute_covariance(self, train_loader: DataLoader) -> None: class in ``metatrain``. """ device = self.covariance.device + dtype = self.covariance.dtype for batch in train_loader: systems, _ = batch n_atoms = torch.tensor( [len(system.positions) for system in systems], device=device ) - systems = [system.to(device=device) for system in systems] + systems = [system.to(device=device, dtype=dtype) for system in systems] outputs = { "mtt::aux::last_layer_features": ModelOutput( quantity="", @@ -248,10 +254,100 @@ class in ``metatrain``. ) output = self.model(systems, options, check_consistency=False) ll_feat_tmap = output["mtt::aux::last_layer_features"] - ll_feats = ll_feat_tmap.block().values / n_atoms.unsqueeze(1) + ll_feats = ll_feat_tmap.block().values.detach() / n_atoms.unsqueeze(1) self.covariance += ll_feats.T @ ll_feats self.covariance_computed = True + def compute_covariance_as_pseudo_hessian( + self, + train_loader: DataLoader, + target_infos: TargetInfoDict, + loss_fn: Callable, + parameters: List[torch.nn.Parameter], + ) -> None: + """A function to compute the covariance matrix for a training set + as the pseudo-Hessian of the loss function. + + The covariance/pseudo-Hessian is stored as a buffer in the model. The + loss function must be compatible with the DataLoader (i.e., it should + have the same structure as the outputs of the model and as the targets + in the dataset).
All contributions to the loss function are converted to + per-atom quantities (by dividing by the number of atoms in each + structure), except for quantities that are already per-atom (e.g., + forces). + + :param train_loader: A PyTorch DataLoader with the training data. + The individual samples need to be compatible with the ``Dataset`` + class in ``metatrain``. + :param target_infos: A ``TargetInfoDict`` describing the targets present + in the training set. + :param loss_fn: A loss function that takes the model outputs and the + targets and returns a scalar loss. + :param parameters: A list of model parameters for which the pseudo-Hessian + should be computed. This is often necessary as models can have very + large numbers of parameters, and the pseudo-Hessian's number of + elements grows quadratically with the number of parameters. For this + reason, only a subset of the parameters of the model is usually used + in the calculation. This list allows the user to feed the parameters + of interest directly to the function. In order to function correctly, + the provided parameters should be those corresponding to the last + layer(s) of the model, such that their concatenation corresponds to the + last-layer features, in the same order as they are returned by the + base model. + """ + self.model = self.model.train() # we need gradients w.r.t. parameters + # disable gradients for all parameters that are not in the list + for parameter in self.model.parameters(): + parameter.requires_grad = False + for parameter in parameters: + parameter.requires_grad = True + + dataset = train_loader.dataset + dataset_info = DatasetInfo( + length_unit=self.capabilities.length_unit, # TODO: check + atomic_types=get_atomic_types(dataset), + targets=target_infos, + ) + train_targets = get_targets_dict([train_loader.dataset], dataset_info) + device = self.covariance.device + dtype = self.covariance.dtype + for batch in train_loader: + systems, targets = batch + systems = [system.to(device=device, dtype=dtype) for system in systems] + targets = { + name: tmap.to(device=device, dtype=dtype) + for name, tmap in targets.items() + } + predictions = evaluate_model( + self.model, + systems, + TargetInfoDict(**{key: train_targets[key] for key in targets.keys()}), + is_training=True, # keep the computational graph + ) + + # average by the number of atoms + predictions = average_by_num_atoms(predictions, systems, []) + targets = average_by_num_atoms(targets, systems, []) + + loss = loss_fn(predictions, targets) + + grads = torch.autograd.grad( + loss, + parameters, + create_graph=False, + retain_graph=False, + allow_unused=True, # if there are multiple last layers + materialize_grads=True, # avoid Nones + ) + + grads = torch.cat(grads, dim=1) + self.covariance += grads.T @ grads + for parameter in parameters: + parameter.grad = None # reset the gradients + + self.covariance_computed = True + + for parameter in self.model.parameters(): + parameter.requires_grad = True + self.model = self.model.eval() # restore the model to evaluation mode + def compute_inverse_covariance(self, regularizer: Optional[float] = None): """A function to compute the inverse covariance matrix.
@@ -307,14 +403,16 @@ def calibrate(self, valid_loader: DataLoader): """ # calibrate the LLPR device = self.covariance.device + dtype = self.covariance.dtype all_predictions = {} # type: ignore all_targets = {} # type: ignore all_uncertainties = {} # type: ignore for batch in valid_loader: systems, targets = batch - systems = [system.to(device=device) for system in systems] + systems = [system.to(device=device, dtype=dtype) for system in systems] targets = { - name: target.to(device=device) for name, target in targets.items() + name: target.to(device=device, dtype=dtype) + for name, target in targets.items() } # evaluate the targets and their uncertainties, not per atom requested_outputs = {} @@ -337,10 +435,10 @@ def calibrate(self, valid_loader: DataLoader): all_predictions[name] = [] all_targets[name] = [] all_uncertainties[uncertainty_name] = [] - all_predictions[name].append(outputs[name].block().values) + all_predictions[name].append(outputs[name].block().values.detach()) all_targets[name].append(target.block().values) all_uncertainties[uncertainty_name].append( - outputs[uncertainty_name].block().values + outputs[uncertainty_name].block().values.detach() ) for name in all_predictions: diff --git a/tests/utils/test_llpr.py b/tests/utils/test_llpr.py index 189e7c2ac..4c97c79b1 100644 --- a/tests/utils/test_llpr.py +++ b/tests/utils/test_llpr.py @@ -9,6 +9,7 @@ from metatrain.utils.data import Dataset, collate_fn, read_systems, read_targets from metatrain.utils.llpr import LLPRUncertaintyModel +from metatrain.utils.loss import TensorMapDictLoss from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, @@ -140,3 +141,140 @@ def test_llpr(tmpdir): torch.testing.assert_close( analytical_uncertainty, ensemble_uncertainty, rtol=1e-2, atol=1e-2 ) + + +def test_llpr_covariance_as_pseudo_hessian(tmpdir): + + model = load_atomistic_model( + str(RESOURCES_PATH / "model-64-bit.pt"), + extensions_directory=str(RESOURCES_PATH / "extensions/"), + ) + qm9_systems = read_systems(RESOURCES_PATH / "qm9_reduced_100.xyz") + target_config = { + "energy": { + "quantity": "energy", + "read_from": str(RESOURCES_PATH / "qm9_reduced_100.xyz"), + "reader": "ase", + "key": "U0", + "unit": "kcal/mol", + "forces": False, + "stress": False, + "virial": False, + }, + } + targets, target_info = read_targets(target_config) + requested_neighbor_lists = model.requested_neighbor_lists() + qm9_systems = [ + get_system_with_neighbor_lists(system, requested_neighbor_lists) + for system in qm9_systems + ] + dataset = Dataset.from_dict({"system": qm9_systems, **targets}) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=10, + shuffle=False, + collate_fn=collate_fn, + ) + + llpr_model = LLPRUncertaintyModel(model) + + parameters = [] + for name, param in llpr_model.named_parameters(): + if "last_layers" in name: + parameters.append(param) + + loss_weight_dict = { + "energy": 1.0, + "energy_positions_grad": 1.0, + "energy_strain_grad": 1.0, + } + loss_fn = TensorMapDictLoss(loss_weight_dict) + + llpr_model.compute_covariance_as_pseudo_hessian( + dataloader, target_info, loss_fn, parameters + ) + llpr_model.compute_inverse_covariance() + + exported_model = MetatensorAtomisticModel( + llpr_model.eval(), + ModelMetadata(), + llpr_model.capabilities, + ) + + evaluation_options = ModelEvaluationOptions( + length_unit="angstrom", + outputs={ + "mtt::aux::energy_uncertainty": ModelOutput(per_atom=True), + "energy": ModelOutput(per_atom=True), +
"mtt::aux::last_layer_features": ModelOutput(per_atom=True), + }, + selected_atoms=None, + ) + + outputs = exported_model( + qm9_systems[:5], evaluation_options, check_consistency=True + ) + + assert "mtt::aux::energy_uncertainty" in outputs + assert "energy" in outputs + assert "mtt::aux::last_layer_features" in outputs + + assert outputs["mtt::aux::energy_uncertainty"].block().samples.names == [ + "system", + "atom", + ] + assert outputs["energy"].block().samples.names == ["system", "atom"] + assert outputs["mtt::aux::last_layer_features"].block().samples.names == [ + "system", + "atom", + ] + + # Now test the ensemble approach + params = [] # One per element, SOAP-BPNN + for name, param in llpr_model.model.named_parameters(): + if "last_layers" in name and "energy" in name: + params.append(param.squeeze()) + weights = torch.cat(params) + + n_ensemble_members = 10000 + llpr_model.calibrate(dataloader) + llpr_model.generate_ensemble({"energy": weights}, n_ensemble_members) + assert "mtt::energy_ensemble" in llpr_model.capabilities.outputs + + exported_model = MetatensorAtomisticModel( + llpr_model.eval(), + ModelMetadata(), + llpr_model.capabilities, + ) + + exported_model.save( + file=str(tmpdir / "llpr_model.pt"), + collect_extensions=str(tmpdir / "extensions"), + ) + llpr_model = load_atomistic_model( + str(tmpdir / "llpr_model.pt"), extensions_directory=str(tmpdir / "extensions") + ) + + evaluation_options = ModelEvaluationOptions( + length_unit="angstrom", + outputs={ + "mtt::aux::energy_uncertainty": ModelOutput(per_atom=False), + "mtt::energy_ensemble": ModelOutput(per_atom=False), + }, + selected_atoms=None, + ) + outputs = exported_model( + qm9_systems[:5], evaluation_options, check_consistency=True + ) + + assert "mtt::aux::energy_uncertainty" in outputs + assert "mtt::energy_ensemble" in outputs + + analytical_uncertainty = outputs["mtt::aux::energy_uncertainty"].block().values + ensemble_uncertainty = torch.var( + outputs["mtt::energy_ensemble"].block().values, dim=1, keepdim=True + ) + + torch.testing.assert_close( + analytical_uncertainty, ensemble_uncertainty, rtol=1e-2, atol=1e-2 + )